AOJ 2378 SolveMe

最近記事書いてなかったので書きます。
雰囲気です(数式の書き方がわからない...)。


A^ x B A^  y B A^ z= I

を満たす  A B の組み合わせを求める問題
まあ、まずこの定式化をちゃんとできるかみたいな話からありますが...。

基本的に  A B全単射である必要があります。( x = 0, y = 0, z = 0 のときはこれに限りません)
まあ、部屋1と部屋2から右耳に乗ると同じ部屋に行く、みたいなことがあったら困るし、いけない部屋あっても当然困るので。
変形させると


\begin{align}
A^ x B A^  y B A^ z &= I  \\
A^ {x+z} B A^  y B &= I  \\
A^ {x-y+z} A^ y B A^  y B &= I  \\
A^ {x-y+z} (A^ y B)^ 2 &= I  \\
A^ {x-y+z} &= (A^ y B)^ {-2}  \\
A^ {x-y+z} &= C^ 2 \\
\end{align}

 C = (A^ y B)^ {-1} としました。
 C B によって適当に決められるので、ある変換の x-y+z 乗がある変換の 2 乗で表せられるような組み合わせを考えることになります。


さて、この変換はいくつかのサイクルからなるので、それを元にDPを考えます。
 A^ {x-y+z} = C^ 2 = D とすると

dp[i][j] :=  Dについて大きさ i のサイクルまで考えた、j 個決まっている時の組み合わせ

また、このDPを計算するために

cycleA[i][j] := サイズj写像x-y+z乗したとき全て大きさiのサイクルになる組み合わせ
cycleB[i][j] := サイズj写像2乗したとき全て大きさiのサイクルになる組み合わせ

を定義します

遷移は

\begin{align}
dp \lbrack i+1 \rbrack \lbrack j+k \rbrack &+ = dp\lbrack i\rbrack \lbrack j \rbrack \\
&× combination(j+k, k) \\
&× cycleA \lbrack i+1 \rbrack \lbrack k \rbrack \\
&× cycleB \lbrack i+1 \rbrack \lbrack k \rbrack \\
&÷ 大きさi+1のサイクルをk/(i+1)個作る組み合わせ
\end{align}

みたいな感じ
あるサイクルの大きさを s とするとそれを  t 乗したとき  gcd(s, t) 個に分かれることを考えながらうまくcycleA、cycleBを計算します。
 x = 0, y = 0, z = 0 のときは  A全単射である必要がないのでその分をかけてあげます。

頭壊れる。

#include <bits/stdc++.h>
  
using namespace std;
  
#define REP(i,n) for(ll (i) = (0);(i) < (n);++i)
#define REV(i,n) for(ll (i) = (n) - 1;(i) >= 0;--i)
#define PB push_back
#define EB emplace_back
#define MP make_pair
#define FI first
#define SE second
#define SHOW1d(v,n) {REP(WW,n)cerr << v[WW] << ' ';cerr << endl << endl;}
#define SHOW2d(v,WW,HH) {REP(W_,WW){REP(H_,HH)cerr << v[W_][H_] << ' ';cerr << endl;}cerr << endl;}
#define ALL(v) v.begin(),v.end()
#define Decimal fixed<<setprecision(20)
#define INF 1000000000
#define LLINF 1000000000000000000LL
#define MOD 1000000007
  
typedef long long ll;
typedef pair<ll,ll> P;
  
const ll N_MAX = 1001;
  
ll dp[N_MAX][N_MAX];
ll cycle[2][N_MAX][N_MAX];
ll fact[N_MAX];
ll invfact[N_MAX];
ll factPow[N_MAX][N_MAX];
ll invfactPow[N_MAX][N_MAX];
ll comb[N_MAX][N_MAX];
  
//a^b % MOD
ll mod_pow(ll a,ll b){
    ll ret = 1;
    ll c = a;
    for(int i = 0;i <= 60;i++){
        if(b & (1LL << i))ret = (ret * c) % MOD;
        c = (c * c) % MOD;
    }
    return ret;
}
//aをbで割る
ll mod_div(ll a,ll b){
    ll tmp = MOD - 2,c = b,ret = 1;
    while(tmp > 0){
        if(tmp & 1){
            ret *= c;ret %= MOD;
        }
        c *= c;c %= MOD;tmp >>= 1;
    }
    return a*ret%MOD;
}
  
void mod_add(ll &a, ll b){
    a += b;
    if(a >= MOD)a -= MOD;
}
  
ll gcd(ll a, ll b){
    return (b == 0 ? a : gcd(b, a % b));
}
  
void calComb() {
    REP(i, N_MAX){
        REP(j, i + 1){
            if(j == 0 || j == i){
                comb[i][j] = 1;
            }
            else {
                comb[i][j] = (comb[i-1][j-1] + comb[i-1][j]) % MOD;
            }
        }
    }
}
  
void calFact() {
    fact[0] = 1;
    invfact[0] = 1;
    for(ll i = 0;i < N_MAX ;i++){
        if(i < N_MAX - 1) {
            fact[i+1] = fact[i] * (i + 1) % MOD;
            invfact[i+1] = mod_div(1, fact[i+1]);
        }
  
        factPow[i][0] = 1;
        invfactPow[i][0] = 1;
        for(int j = 1;j < N_MAX;j++){
            factPow[i][j] = factPow[i][j-1] * fact[i] % MOD;
            invfactPow[i][j] = invfactPow[i][j-1] * invfact[i] % MOD;
        }
    }
}
  
void calCycle(ll type, ll t) {
    for(int i = 0;i < N_MAX;i++){
        cycle[type][i][0] = 1;
    }
    for(ll i = 1;i < N_MAX;i++) {
        ll now = i / gcd(i, t);
        vector<ll> tmp(N_MAX);
        for(ll j = 0;i * j < N_MAX;j++){
            ll m_kake = invfactPow[i][j] * factPow[i-1][j] % MOD;
            for(ll k = 0;k + i * j < N_MAX;k++){
                ll kake = fact[k + i * j] * invfact[k] % MOD;
                kake = kake * m_kake % MOD;
                kake = kake * invfact[j] % MOD;
                mod_add(tmp[k + i * j], cycle[type][now][k] * kake % MOD);
            }
        }
  
        for(ll j = 0;j < N_MAX;j++){
            cycle[type][now][j] = tmp[j];
        }
    }
}
  
int main(){
    cin.tie(0);cout.tie(0);ios::sync_with_stdio(false);
  
    ll n, x, y, z;cin >> n >> x >> y >> z;
    ll t = abs(x - y + z);
  
    calFact();
    calComb();
    calCycle(0, t);
    calCycle(1, 2);
  
    dp[0][0] = 1;
    for(ll i = 1;i <= n;i++){
        for(ll j = n;j >= 0;j--){
            for(ll k = 0;k * i <= j;k++){
                ll pre = j - k * i;
                ll kake = cycle[0][i][k * i] * cycle[1][i][k * i] % MOD;
                kake = kake * comb[j][k * i] % MOD;
                ll wari = fact[k * i] * invfactPow[i][k] % MOD;
                wari = wari * invfact[k] % MOD;
                wari = wari * factPow[i-1][k] % MOD;
                mod_add(dp[i][j], dp[i-1][pre] * mod_div(kake, wari) % MOD);
            }
        }
    }
  
    if(x == 0 && y == 0 && z == 0)cout << (dp[n][n] * invfact[n] % MOD)* mod_pow(n, n) % MOD << endl;
    else cout << dp[n][n] << endl;
  
    return 0;
}