P5664 [CSP2019]Emiya 家今天的饭

P5664 [CSP2019]Emiya 家今天的饭

Meaning

考虑一个 n×mn \times m 的矩形,位置 (i,j)(i,j) 处权值为 ai,ja_{i,j}

对于所有的 nn 行,每一行至多可以选出 11 个数;

对于所有的 mm 列,如果一共选出了 kk 个数,每一列至多可以选出 k2\lfloor\frac{k}{2}\rfloor 个数。

一种方案的贡献为其选取出的所有权值的乘积,即 ai,j\prod{a_{i,j}}

Sol

考虑容斥,所有的选取方案数为 i(jai,j+1)1\prod_{i} ({ \sum_{j} a_{i,j} } + 1) - 1

不合法的方案一定为存在某一种列中选取超过 k2\lfloor\frac{k}{2}\rfloor 次,此时一个简单的观察是,同一种不合法的方案中,有且仅有一列选取超过了限制。那么我们枚举某一列强制超出限制,此时考虑使用 dp 实现不合法方案计数。

假设此时强制第 ll 列将超出限制,令 dpi,j,kdp_{i,j,k} 表示选取到第 ii 行,第 ll 列选取有 jj 个元素,其他列选取有 kk 个元素。那么显然有转移方程为

dpi,j,k=dpi1,j1,k×ai,l+dpi1,j,k1×(jai,jai,l)+dpi1,j,kdp_{i,j,k} = dp_{i-1,j-1,k} \times a_{i,l} + dp_{i-1,j,k-1} \times \Big( \sum_{j} a_{i,j} - a_{i,l} \Big) + dp_{i-1,j,k}

那么最后统计不合法方案数的答案为 j>kdpn,j,k\sum_{ j > k} {dp_{n,j,k} },此时时间复杂度为 O(n3m)\text{O}(n^3 \cdot m)

n = readint, m = readint, Ans = 1ll;
for (int i = 1; i <= n; i ++) for (int j = 1; j <= m; j ++) a[i][j] = readlint;
for (int i = 1; i <= n; i ++) for (int j = 1; j <= m; j ++) (sum[i] += a[i][j]) %= Mod;
for (int i = 1; i <= n; i ++) (Ans *= (sum[i] + 1) % Mod) %= Mod; (Ans += (Mod - 1)) %= Mod;
for (int l = 1; l <= m; l ++) {
    dp[0][0][0] = 1ll;
    for (int i = 1; i <= n; i ++) for (int j = 0; j <= n; j ++) for (int k = 0; j + k <= i; k ++) {
        dp[i][j][k] = dp[i - 1][j][k];
        if (j) (dp[i][j][k] += dp[i - 1][j - 1][k] * a[i][l] % Mod) %= Mod;
        if (k) (dp[i][j][k] += (sum[i] - a[i][l] + Mod) % Mod * dp[i - 1][j][k - 1] % Mod) %= Mod;
    }
    for (int i = 1; i <= n; i ++) for (int j = 0; j < i; j ++) (Ans += (Mod - dp[n][i][j]) % Mod) %= Mod;
    memset(dp, 0, sizeof(dp));
}
printf("%lld\n", Ans);

考虑如何优化,我们发现最后统计答案时,我们只关心 i,ji,j 之间的大小关系是否满足 i>ji>j,那么考虑优化 dp 转移状态,令 dpi,jdp_{i,j} 表示选取到第 ii 行时,第 ll 列选取的元素比其他列多 jj 个(注意可能为负)。那么最后统一答案仅需考虑 jj 值是否为正。那么新的转移方程为

dpi,j=dpi1,j1×ai,l+dpi1,j+1×(jai,jai,l)+dpi1,jdp_{i,j} = dp_{i-1,j-1} \times a_{i,l} + dp_{i-1,j+1} \times \Big( \sum_{j} a_{i,j} - a_{i,l} \Big) + dp_{i-1,j}

那么最后统计不合法方案数的答案为 i>0dpn,i\sum_{i>0} dp_{n,i},此时的时间复杂度为 O(n2m)\text{O}(n^2 \cdot m)

Detail

  1. 注意下边界处理和数组下标为负时整体上移;
  2. 计数 dp 一定要整理好思路。

Code

#include <iostream>
#include <cstring>
#include <cstdio>
#include <cctype>
using namespace std;

template <typename T>
inline T read() {
    T x = 0, f = 1; char c = getchar();
    while (!isdigit(c)) { if (c == '-') f = - f; c = getchar(); }
    while (isdigit(c)) x = x * 10 + (c ^ 48), c = getchar();
    return x * f;
}

#define lint long long int
#define ulint unsigned lint
#define readint read <int> ()
#define readlint read <lint> ()
const int inf = 1e9 + 1e7, MAXN = 1e2 + 1e1, MAXM = 2e3 + 1e1, Tp = 1e2 + 1e0, Mod = 998244353;
const lint INF = 1e18 + 1e9;

lint dp[MAXN][MAXN << 1], a[MAXN][MAXM], sum[MAXN], Ans;
int n, m;

int main(void) {

    n = readint, m = readint, Ans = 1ll;
    for (int i = 1; i <= n; i ++) for (int j = 1; j <= m; j ++) a[i][j] = readlint;
    for (int i = 1; i <= n; i ++) for (int j = 1; j <= m; j ++) (sum[i] += a[i][j]) %= Mod;
    for (int i = 1; i <= n; i ++) (Ans *= (sum[i] + 1) % Mod) %= Mod; (Ans += (Mod - 1)) %= Mod;
    for (int l = 1; l <= m; l ++) {
        dp[0][Tp] = 1ll;
        for (int i = 1; i <= n; i ++) for (int j = Tp - i; j <= Tp + i; j ++) {
            if (i + j > Tp && j - i < Tp) dp[i][j] = dp[i - 1][j];
            if (i + j > Tp) (dp[i][j] += dp[i - 1][j - 1] * a[i][l] % Mod) %= Mod;
            if (j - i < Tp) (dp[i][j] += (sum[i] - a[i][l] + Mod) % Mod * dp[i - 1][j + 1] % Mod) %= Mod;
        }
        for (int i = 1; i <= n; i ++) (Ans += (Mod - dp[n][i + Tp]) % Mod) %= Mod;
        memset(dp, 0, sizeof(dp));
    }
    printf("%lld\n", Ans);

    return 0;
}