AtCoder Grand Contest 043 D - Merge Triplets 定数倍改善

editorial では最後のパートを DP で解いているが,組み合わせ的に処理すると定数倍の軽い  O(N ^ 2) で解くことが出来る.

集合 \displaystyle \left\{i\;\middle|\; i=0\text{ または } i \gt 0 \text{ かつ }P _ {i} \gt \max _ {0\leq j\lt i} P _ j\right\} を昇順にソートしてできる列を  T とする.さらに, T の長さを  L C _ j = \#\{ 1 \leq i \lt L \mid T _ i - T _ {i - 1} = j \} とする.

結論から言えば,順列  P を作ることが出来るための必要十分条件は以下が成り立つことである.

  •  C _ j = 0\quad (j\geq 4)
  •  C _ 2 + C _ 3 \leq N

以下, X=C _ 1 Y=C _ 2 Z=C _ 3 とする.

 0\leq Y + Z\leq N を満たす  Y,  Z を全探索して  P を数え上げる.まず, X Y,  Z を用いて  X=3N-2Y-3Z と表せる.

 X 個の長さ  1 の列に含まれる要素の選び方は  \dfrac{1}{X!}\begin{pmatrix}3N\\ 1\end{pmatrix}\times\cdots\times\begin{pmatrix}3N-X+1\\ 1\end{pmatrix} 通り, Y 個の長さ  2 の列に含まれる要素の選び方は  \dfrac{1}{Y!}\begin{pmatrix}3N-X\\ 2\end{pmatrix}\times\cdots\times\begin{pmatrix}3N-X-2Y+2\\ 2\end{pmatrix} 通り, Z 個の長さ  3 の列に含まれる要素の選び方は  \dfrac{1}{Z!}\begin{pmatrix}3N-X-2Y\\ 3\end{pmatrix}\times\cdots\times\begin{pmatrix}3N-X-2Y-3Z+3\\ 3\end{pmatrix} 通りである.

順列  P において  X+Y+Z 個の列は最大値の小さい順に現れるので,列の並べ方は  1 通りに定まる.

各列はその最大値が先頭に来るように並べる必要があるので,長さ  j に対して  (j-1)! 通りの並べ方が存在する.

以上より,

 \displaystyle
\begin{align}
&(0!)^X\times\dfrac{1}{X!}\begin{pmatrix}3N\\ 1\end{pmatrix}\times\cdots\times\begin{pmatrix}3N-X+1\\ 1\end{pmatrix}\\
\times
&(1!)^Y\times\dfrac{1}{Y!}\begin{pmatrix}3N-X\\ 2\end{pmatrix}\times\cdots\times\begin{pmatrix}3N-X-2Y+2\\ 2\end{pmatrix}\\
\times
&(2!)^Z\times\dfrac{1}{Z!}\begin{pmatrix}3N-X-2Y\\ 3\end{pmatrix}\times\cdots\times\begin{pmatrix}3N-X-2Y-3Z+3\\ 3\end{pmatrix}\\
=&\frac{(3N)!}{X!\times Y!\times Z!\times 1^X\times 2^Y\times 3^Z}
\end{align}

 Y,  Z を固定した場合の寄与である.以上より,求める答えは次で表される.

 \displaystyle
\sum_{Y=0}^{N}\sum_{Z=0}^{N-Y}\frac{(3N)!}{(3N-2Y-3Z)!\times Y!\times Z!\times 2^Y\times 3^Z}

実装 (Python)

PyPy3 (7.3.0) で以下のコードを提出して 273 ms で AC が得られた.

N, M = map(int, input().split())
T = 3 * N

def mul_mod(*iterable):
    res = 1
    for v in iterable:
        res = res * v % M
    return res

facT = mul_mod(*range(1, T + 1))

fac_inv = [0] * T + [pow(facT, M - 2, M)]
for i in range(T, 0, -1):
    fac_inv[i - 1] = fac_inv[i] * i % M

pow2_inv = [1] + [0] * N
inv_2 = pow(2, M - 2, M)
for i in range(N):
    pow2_inv[i + 1] = pow2_inv[i] * inv_2 % M

pow3_inv = [1] + [0] * N
inv_3 = pow(3, M - 2, M)
for i in range(N):
    pow3_inv[i + 1] = pow3_inv[i] * inv_3 % M

print(sum(sum(mul_mod(facT, fac_inv[T - 2 * y - 3 * z], fac_inv[y], fac_inv[z], pow2_inv[y], pow3_inv[z]) for z in range(N + 1 - y)) for y in range(N + 1)) % M)