P3177 [HAOI2015]树上染色

P3177 [HAOI2015]树上染色

Meaning

给定一棵 nn 个结点的树,此时需要选出 kk 个结点染成黑色,其余染成白色。

求使得所有黑点两两之间的距离和白点两两之间距离之和的最大值。

Sol

我们考虑如何简单记算贡献:由于一条边连接两棵子树,那么统计白点数分别为 w1,w2w_1,w_2,黑点数分别为 b1,b2b_1,b_2,则此边的贡献为 w1×w2+b1×b2w_1 \times w_2 + b_1 \times b_2

那么我们考虑一个树上背包的模型。令 dpi,jdp_{i,j} 表示以点 ii 为根的结点的子树内已有 jj 个点被染成黑色,那么我们直接统计子树后枚举黑色点数分配方案,此时时间复杂度为 O(n3)\text{O}(n^3)

考虑一个比较经典的优化:

我们枚举已经统计答案的子树内的点和未统计答案的子树内的点,他们之间的贡献值我们能够方便求出。同时由于每一对点我们有且仅有一次计算,那么最终会被计算 n2n^2 次,此时时间复杂度为 O(n2)\text{O}(n^2)

此时的状态转移方程为

dpu,i+j=max{dpu,i+dpv,j+(kj)×j×w+(szvj)×(nkszv+j)×w}dp_{u,i+j} = \max \{ dp_{u,i} + dp_{v,j} + (k - j) \times j \times w + (sz_v - j) \times (n - k - sz_v + j) \times w \}

Detail

  1. 开 long long。

Code

#include <iostream>
#include <cstdio>
#include <cctype>
using namespace std;

template <typename T>
inline T read() {
    T x = 0, f = 1; char c = getchar();
    while (!isdigit(c)) { if (c == '-') f = - f; c = getchar(); }
    while (isdigit(c)) x = x * 10 + (c ^ 48), c = getchar();
    return x * f;
}

#define lint long long int
#define ulint unsigned lint
#define readint read <int> ()
#define readlint read <lint> ()
const int inf = 1e9 + 1e7, MAXN = 2e3 + 1e1;
const lint INF = 1e18 + 1e9;

struct Edge {
    int nxt, to, w;
} e[MAXN << 1];
int head[MAXN], fa[MAXN], sz[MAXN], n, cnt, tot;
lint dp[MAXN][MAXN];

void Addedge(int u, int v, int w) {
    e[++ tot] = (Edge){head[u], v, w};
    head[u] = tot;
    return ;
}

void Dfs(int u) {
    sz[u] = 1;
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to, w = e[i].w;
        if (fa[u] == v) continue;
        fa[v] = u, Dfs(v);
        for (int j = min(cnt, sz[u]); j >= 0; j --) for (int k = min(cnt - j, sz[v]); k >= 0; k --) {
            dp[u][j + k] = max(dp[u][j] + dp[v][k] + 1ll * (cnt - k) * k * w + 1ll * (sz[v] - k) * (n - cnt - sz[v] + k) * w, dp[u][j + k]);
        }
        sz[u] += sz[v];
    }
    return ;
}

int main(void) {

    scanf("%d%d", &n, &cnt);
    for (int i = 1; i < n; i ++) {
        int u = readint, v = readint, w = readint;
        Addedge(u, v, w), Addedge(v, u, w);
    }
    Dfs(1);
    printf("%lld\n", dp[1][cnt]);

    return 0;
}