ARC 067 F - Yakiniku Restaurants 別解

editorial で紹介されている解法は  O(N ^ 2+MN\log N) だが,別解として  O(N ^ 2 + MN) O(MN\log N) の 2 通りで解けたので ( N\leq 5000,\ M\leq 200 の制約だとあまり差は出なさそう).

問題

atcoder.jp

解法

基本的な考え方は,初めに訪れる焼肉店  l を降順に動かしながら差分を計算するというもの.

焼肉店  i-1 で終了する場合の最適値と焼肉店  i で終了する場合の最適値の差分  d_i を管理する. l を動かしたときに, d_i を高速に更新することが出来れば, l を固定した場合の最適値は  \displaystyle\max\left\{\sum_{i=l} ^ r d_i \;\middle|\;l\leq r\leq N\right\} として求まる.これは毎回愚直に求めても全体で  O(N ^ 2) である.

 \displaystyle
\begin{aligned}
X_j(l, r)&=\begin{cases} \displaystyle
\max _ {l \leq i \leq r} B _ {i, j} & \text{(if $\:1\leq l\leq r\leq N$)} \\
0 & \text{(otherwise)}
\end{cases},\\
Y_j(l, r)&=X_j(l, r)-X_j(l, r - 1)
\end{aligned}

とする.

 d の更新には, M 本の stack を用いる. j\;(1\leq j\leq M) 番目の stack を  S_j とし, S_j には (i,\;Y_j(l,i)) のペアを積んでいく.ただし, Y_j(l,i)=0 の場合は積まないものとする.

 l を動かすごとに,各  j に対して  S_j を以下の手続きにより更新する.

  1.  v:=B _ {l, j} とする.
  2.  S_j が空でなく,かつ  v\gt 0 であるならば,3 へ.そうでなければ,4 へ.
  3.  S_j の先頭  (x,\;y) を pop する. v\geq y ならば, v:=v-y として 2 へ.そうでなければ, S_j (x,\;y-v) を push して 4 へ.
  4.  S_j (l,\;B _ {l, j}) を push する.

つまるところ, B _ {i, j}\leq B _ {l, j} であるような  i に対応するペアを  S_j から取り除き,初めて  B _ {i, j} \gt B _ {l, j} となる  i に対してその差分を更新している.

 d に関しては, (x,\;y) が pop される度に  d _ x := d _ x - y と更新し,push される度に  d _ x := d _ x + y と更新する.また,移動距離を考慮して  d _ {l+1}:=d _ {l+1}-A _ l とする必要がある.

 S_j および  d の更新にかかる計算量は一見すると  \Theta(MN ^ 2 ) であるが,実は  O(MN) となっている.

[tex: O(MN)] の理由

 j に対して更新回数が全体で  O(N) となっていることが言えればよい.

手続きの 3. において push が起こる回数は,各  l に対して高々  1*1 で,合計  O(N) 回である.また, 4. において push が起こる回数も各  l に対して高々  1 回である.従って, S_j への push が起こる回数は  O(N) 回である. S_j から pop する回数は  S_j への push の回数以下なので,pop の回数も  O(N) である.

以上より,各  j に対して,手続き全体にかかる計算量は  O(N) である.

 S_j および  d の更新に  O(MN) \displaystyle\max\left\{\sum_{i=l} ^ r d_i \;\middle|\;l\leq r\leq N\right\} を求めるのに  O(N ^ 2) かかるので,結局全体  O(N ^ 2 + MN) でこの問題が解けた.

上の解法において, d の定義を累積和に変更して区間 add,区間 max の遅延セグ木に載せれば  O(MN\log N) となる.

実装 (Python)

提出リンク (AC)

def solve():
    N, M = map(int, input().split())
    A = [*map(int, input().split())]
    B = [[*map(int, input().split())] for _ in range(N)]
    d = [0] * N
    S = [[] for _ in range(M)]
    ans = 0
    for l in reversed(range(N)):
        if l < N - 1:
            d[l + 1] -= A[l]
        for j in range(M):
            v = B[l][j]
            while S[j] and v:
                x, y = S[j].pop()
                d[x] -= y
                if v >= y:
                    v -= y
                else:
                    S[j].append((x, y - v))
                    d[x] += y - v
                    break
            S[j].append((l, B[l][j]))
            d[l] += B[l][j]
        d_sum = 0
        for r in range(l, N):
            d_sum += d[r]
            ans = max(ans, d_sum)
    return ans

if __name__ == '__main__':
    print(solve())

*1:push が起こると即座に 4. へと移り,更新が終了するため