P5664 [CSP2019]Emiya 家今天的饭
Meaning
考虑一个 的矩形,位置 处权值为 。
对于所有的 行,每一行至多可以选出 个数;
对于所有的 列,如果一共选出了 个数,每一列至多可以选出 个数。
一种方案的贡献为其选取出的所有权值的乘积,即 。
Sol
考虑容斥,所有的选取方案数为 。
不合法的方案一定为存在某一种列中选取超过 次,此时一个简单的观察是,同一种不合法的方案中,有且仅有一列选取超过了限制。那么我们枚举某一列强制超出限制,此时考虑使用 dp 实现不合法方案计数。
假设此时强制第 列将超出限制,令 表示选取到第 行,第 列选取有 个元素,其他列选取有 个元素。那么显然有转移方程为
那么最后统计不合法方案数的答案为 ,此时时间复杂度为 。
n = readint, m = readint, Ans = 1ll;
for (int i = 1; i <= n; i ++) for (int j = 1; j <= m; j ++) a[i][j] = readlint;
for (int i = 1; i <= n; i ++) for (int j = 1; j <= m; j ++) (sum[i] += a[i][j]) %= Mod;
for (int i = 1; i <= n; i ++) (Ans *= (sum[i] + 1) % Mod) %= Mod; (Ans += (Mod - 1)) %= Mod;
for (int l = 1; l <= m; l ++) {
dp[0][0][0] = 1ll;
for (int i = 1; i <= n; i ++) for (int j = 0; j <= n; j ++) for (int k = 0; j + k <= i; k ++) {
dp[i][j][k] = dp[i - 1][j][k];
if (j) (dp[i][j][k] += dp[i - 1][j - 1][k] * a[i][l] % Mod) %= Mod;
if (k) (dp[i][j][k] += (sum[i] - a[i][l] + Mod) % Mod * dp[i - 1][j][k - 1] % Mod) %= Mod;
}
for (int i = 1; i <= n; i ++) for (int j = 0; j < i; j ++) (Ans += (Mod - dp[n][i][j]) % Mod) %= Mod;
memset(dp, 0, sizeof(dp));
}
printf("%lld\n", Ans);
考虑如何优化,我们发现最后统计答案时,我们只关心 之间的大小关系是否满足 ,那么考虑优化 dp 转移状态,令 表示选取到第 行时,第 列选取的元素比其他列多 个(注意可能为负)。那么最后统一答案仅需考虑 值是否为正。那么新的转移方程为
那么最后统计不合法方案数的答案为 ,此时的时间复杂度为 。
Detail
- 注意下边界处理和数组下标为负时整体上移;
- 计数 dp 一定要整理好思路。
Code
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cctype>
using namespace std;
template <typename T>
inline T read() {
T x = 0, f = 1; char c = getchar();
while (!isdigit(c)) { if (c == '-') f = - f; c = getchar(); }
while (isdigit(c)) x = x * 10 + (c ^ 48), c = getchar();
return x * f;
}
#define lint long long int
#define ulint unsigned lint
#define readint read <int> ()
#define readlint read <lint> ()
const int inf = 1e9 + 1e7, MAXN = 1e2 + 1e1, MAXM = 2e3 + 1e1, Tp = 1e2 + 1e0, Mod = 998244353;
const lint INF = 1e18 + 1e9;
lint dp[MAXN][MAXN << 1], a[MAXN][MAXM], sum[MAXN], Ans;
int n, m;
int main(void) {
n = readint, m = readint, Ans = 1ll;
for (int i = 1; i <= n; i ++) for (int j = 1; j <= m; j ++) a[i][j] = readlint;
for (int i = 1; i <= n; i ++) for (int j = 1; j <= m; j ++) (sum[i] += a[i][j]) %= Mod;
for (int i = 1; i <= n; i ++) (Ans *= (sum[i] + 1) % Mod) %= Mod; (Ans += (Mod - 1)) %= Mod;
for (int l = 1; l <= m; l ++) {
dp[0][Tp] = 1ll;
for (int i = 1; i <= n; i ++) for (int j = Tp - i; j <= Tp + i; j ++) {
if (i + j > Tp && j - i < Tp) dp[i][j] = dp[i - 1][j];
if (i + j > Tp) (dp[i][j] += dp[i - 1][j - 1] * a[i][l] % Mod) %= Mod;
if (j - i < Tp) (dp[i][j] += (sum[i] - a[i][l] + Mod) % Mod * dp[i - 1][j + 1] % Mod) %= Mod;
}
for (int i = 1; i <= n; i ++) (Ans += (Mod - dp[n][i + Tp]) % Mod) %= Mod;
memset(dp, 0, sizeof(dp));
}
printf("%lld\n", Ans);
return 0;
}