BZOJ4987 Tree

[BZOJ4987] Tree

Meaning

给定一棵 nn 个点的边带权的树,找出 kk 个点 a1,a2,,aka_1,a_2,\cdots,a_k,使得 disai,ai+1\sum dis_{ a_i , a_{i+1} } 最小。

Sol

我们考虑简单特殊情况。当 k=nk=n 时,我们必然会按照一条主链的路径向下遍历,当遇到子树时我们向内遍历并返回主链,此时我们发现除主链以外的路径仅被遍历 11 次,其余的子树内的点则会被遍历 22 次。

此时我们将这个统计答案的方式拓展到一般的情况上,易知选择的 kk 条边必定存在相邻关系,那么我们只需要找到一棵大小为 kk 的连通子树使得 2×wedgelen2 \times \sum w_{edge} - len 最小,其中 lenlen 表示所选连通子树中直径长度。

那么我们分类讨论下直径可能的转移状态和转移方程:

  1. dpi,j,0dp_{i,j,0} 表示以点 ii 为根的子树内选取 jj 个点的最小边权和;
  2. dpi,j,1dp_{i,j,1} 表示以点 ii 为根的子树内选取 jj 个点使得 2×wedgelen2 \times \sum w_{edge} - len 的最小值,且直径中存在一个端点为点 ii
  3. dpi,j,2dp_{i,j,2} 表示以点 ii 为根的子树内选取 jj 个点使得 2×wedgelen2 \times \sum w_{edge} - len 的最小值。

此时我们可以考虑鉴于一种类似线段树 Pushup 的方式更新答案(仙人所言),相互维护以上的信息,并考虑以树上背包枚举已知点和未知点的方式转移。如下

dpu,i+j,0=min{dpu,i,0+dpv,j,0+w}dpu,i+j,1=min{dpv,j,1+2×dpu,i,0+w,dpu,i,1+2×(dpv,j,0+w)}dpu,i+j,1=min{dpu,i,2+2×(dpv,j,0+w),dpv,j,2+2×(dpu,i,0+w),dpu,i,1+dpv,j,1+w}dp_{u,i + j,0} = \min \{ dp_{u,i,0} + dp_{v,j,0} + w \} \\ dp_{u,i + j,1} = \min \{ dp_{v,j,1} + 2 \times dp_{u,i,0} + w, dp_{u,i,1} + 2\times (dp_{v,j,0} + w) \} \\ dp_{u,i+j,1} = \min \{ dp_{u,i,2} + 2 \times (dp_{v,j,0} + w), dp_{v,j,2} + 2 \times (dp_{u,i,0} + w), dp_{u,i,1} + dp_{v,j,1} + w \}

Detail

  1. 写树形背包 dp 时尤其注意状态设计是否合理,状态转移方程是否正确;
  2. 注意如何处理边界条件很重要。

Code

#include <iostream>
#include <cstdio>
#include <cctype>
using namespace std;

template <typename T>
inline T read() {
    T x = 0, f = 1; char c = getchar();
    while (!isdigit(c)) { if (c == '-') f = - f; c = getchar(); }
    while (isdigit(c)) x = x * 10 + (c ^ 48), c = getchar();
    return x * f;
}

#define lint long long int
#define ulint unsigned lint
#define readint read <int> ()
#define readlint read <lint> ()
const int inf = 1e9 + 1e7, MAXN = 3e3 + 1e1;
const lint INF = 1e18 + 1e9;

struct Edge { int nxt, to; lint w; } e[MAXN << 1];
lint dp[MAXN][MAXN][3], Ans = INF;
int head[MAXN], fa[MAXN], sz[MAXN], n, k, tot;

void Addedge(int u, int v, lint w) { e[++ tot] = (Edge){head[u], v, w}, head[u] = tot; return ; }

void Dp(int u, int v, lint w) {
    for (int i = min(sz[u], k); i >= 0; i --) for (int j = min(sz[v], k); j >= 0; j --) if (i + j <= k) {
        dp[u][i + j][0] = min(dp[u][i][0] + dp[v][j][0] + w, dp[u][i + j][0]);
        dp[u][i + j][1] = min(dp[v][j][1] + 2 * dp[u][i][0] + w, dp[u][i + j][1]);
        dp[u][i + j][1] = min(dp[u][i][1] + 2 * (dp[v][j][0] + w), dp[u][i + j][1]);
        dp[u][i + j][2] = min(dp[u][i][2] + 2 * (dp[v][j][0] + w), dp[u][i + j][2]);
        dp[u][i + j][2] = min(dp[v][j][2] + 2 * (dp[u][i][0] + w), dp[u][i + j][2]);
        dp[u][i + j][2] = min(dp[u][i][1] + dp[v][j][1] + w, dp[u][i + j][2]);
    }
    return ;
}

void Dfs(int u) {
    sz[u] = 1;
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].to; lint w = e[i].w;
        if (fa[u] == v) continue ; fa[v] = u;
        Dfs(v), Dp(u, v, w);
        sz[u] += sz[v];
    }
    return ;
}

int main(void) {

    n = readint, k = readint;
    for (int i = 1; i <= n; i ++) for (int j = 2; j <= k; j ++) {
        for (int l = 0; l < 3; l ++) dp[i][j][l] = INF;
    }
    for (int i = 1; i < n; i ++) {
        int u = readint, v = readint; lint w = readlint;
        Addedge(u, v, w), Addedge(v, u, w);
    }
    fa[1] = 1, Dfs(1);
    for (int i = 1; i <= n; i ++) Ans = min(dp[i][k][2], Ans);
    printf("%lld\n", Ans);

    return 0;
}