ABC273 G - Row Column Sums 2

やっていること自体は公式解説と同じなのですが、状態の持ち方が割と素直かも?と思ったので

問題

atcoder.jp

 N 次正方行列であって、以下を共に満たすものの個数を  998244353 で割った余りで求めよ。
・すべての  1 \le i \le N について、 i 行目の総和が  R_i に一致
・すべての  1 \le i \le N について、 i 列目の総和が  C_i に一致

制約

 1 \le N \le 5000
 0 \le R_i, C_i \le 2

解法

行の総和と列の総和が一致しない場合は明らかに条件を満たしえないので、以下は行の総和と列の総和が一致しているものとして考える。

 R_i の昇順に行を、 C_i の昇順に列を並び替えたとしても求める値は変わらないので、そのように並び替えたものを考える。すると行列の内、 R_i > 0, C_i > 0 を満たすような右下の長方形領域だけを考えれば良いことが分かる。また、 R_i = 1 の行(以下  1 の行と定め、他についても同様に定める)についてどの列に  1 を書き込むかが決まればその行はもう以後考える必要がなくなり、前の長方形領域よりも小さい長方形領域について考えれば十分になる。同様に  2 の行についても書き込み方を定めていくことで、最終的な答えが求められる。以上のことを踏まえると再帰的な dp をしたくなり、以下のような状態を定めることができる。

 dp_{r1, r2, c1, c2} 1 の行が  r1 行、 2 の行が  r2 行、 1 の列が  c1 列、 2 の列が  c2 列ある場合についての条件を満たす行列の書き込み方の個数

さて、あとはこの dp の遷移(詳細は下記実装を参照)を丁寧に一つずつ実装すれば正しい答えの求まるアルゴリズムになるが、問題は計算量である。一見状態数が  O(N^{4}) で到底間に合わなさそうだが、よくよく考えてみると遷移先の状態において常に行の総和と列の総和の一致、すなわち  r1 + 2 \times r2 = c1 + 2 \times c2 が成り立つので、  r1 r2 が決まれば、 c1 c2 が決まることで一意に定まることが分かる。また  1 の行から処理していくことを考えると、 r1 + r2 O(N) であるので、実はとりうる状態数は  O(N^{2}) で抑えられる。遷移も  O(1) なので実はこれを素直にメモ化再帰で実装するだけで AC できる。

なお、メモ化の際に素直に std::map<std::tuple<int, int, int, int>, mint> とかすると  log がついて TLE してしまうので、適切に状態を int 型で計算して、std::vectorstd::unordered_map などでメモ化することで TL に間に合うようになる。詳細は実装例を参照。

atcoder.jp

#include <bits/stdc++.h>
#include <atcoder/modint>
using namespace std;
using mint = atcoder::modint998244353;

const mint inv2 = mint(1) / 2;
int R1, R2, C1, C2;
vector<int> memo; // メモ化再帰用の配列

int calcState(int r1, int r2, int c2) { // 状態を計算
  return (r1 + r2) * (C2 + 1) + c2;
}

mint dp(int r1, int r2, int c1, int c2) { // メモ化再帰
  if (r1 == 0 and r2 == 0 and c1 == 0 and c2 == 0) return 1;
  const int state = calcState(r1, r2, c2);
  if (memo[state] != -1) return memo[state];
  
  mint res = 0;
  if (r1 > 0) { // 1 の行から処理
    if (c1 > 0) res += dp(r1 - 1, r2, c1 - 1, c2) * c1; // 1 の列を 1 つ選んで 1 を書き込む
    if (c2 > 0) res += dp(r1 - 1, r2, c1 + 1, c2 - 1) * c2; // 2 の列を 1 つ選んで 1 を書き込む(1 の列に変化する)
  } else { // 2 の行の処理
    if (c1 > 1) res += dp(r1, r2 - 1, c1 - 2, c2) * c1 * (c1 - 1) * inv2; // 1 の列を 2 つ選んで 1 を書き込む
    if (c2 > 0) res += dp(r1, r2 - 1, c1, c2 - 1) * c2; // 2 の列を 1 つ選んで 2 を書き込む
    if (c1 > 0 and c2 > 0) res += dp(r1, r2 - 1, c1, c2 - 1) * c1 * c2; // 1 の列と 2 の列を 1 つずつ選んでそれぞれ 1 を書き込む
    if (c2 > 1) res += dp(r1, r2 - 1, c1 + 2, c2 - 2) * c2 * (c2 - 1) * inv2; // 2 の列を 2 つ選んで 1 を書き込む
  }
  return memo[state] = res.val();
}

int main() {
  // 入力 配列の要素をとりながら R1, R2, C1, C2 を count
  int n;
  cin >> n;
  for (int i = 0; i < n; i++) {
    int r;
    cin >> r;
    if (r == 1) R1++;
    if (r == 2) R2++;
  }
  for (int i = 0; i < n; i++) {
    int c;
    cin >> c;
    if (c == 1) C1++;
    if (c == 2) C2++;
  }
  
  // 答えを出力 行の総和と列の総和が一致していない場合は 0 であることに注意
  if (R1 + 2 * R2 != C1 + 2 * C2) {
    cout << 0 << '\n';
  } else {
    memo = vector<int>((R1 + R2) * (C2 + 1) + C2 + 1, -1);
    cout << dp(R1, R2, C1, C2).val() << '\n';
  }
}

感想

ぱっと見で  O(N^{2}) に収まるような状態の持ち方に見えなかったので、びっくりする。