[BZOJ4987] Tree
Meaning
给定一棵 个点的边带权的树,找出 个点 ,使得 最小。
Sol
我们考虑简单特殊情况。当 时,我们必然会按照一条主链的路径向下遍历,当遇到子树时我们向内遍历并返回主链,此时我们发现除主链以外的路径仅被遍历 次,其余的子树内的点则会被遍历 次。
此时我们将这个统计答案的方式拓展到一般的情况上,易知选择的 条边必定存在相邻关系,那么我们只需要找到一棵大小为 的连通子树使得 最小,其中 表示所选连通子树中直径长度。
那么我们分类讨论下直径可能的转移状态和转移方程:
- 令 表示以点 为根的子树内选取 个点的最小边权和;
- 令 表示以点 为根的子树内选取 个点使得 的最小值,且直径中存在一个端点为点 ;
- 令 表示以点 为根的子树内选取 个点使得 的最小值。
此时我们可以考虑鉴于一种类似线段树 Pushup 的方式更新答案(仙人所言),相互维护以上的信息,并考虑以树上背包枚举已知点和未知点的方式转移。如下
Detail
- 写树形背包 dp 时尤其注意状态设计是否合理,状态转移方程是否正确;
- 注意如何处理边界条件很重要。
Code
#include <iostream>
#include <cstdio>
#include <cctype>
using namespace std;
template <typename T>
inline T read() {
T x = 0, f = 1; char c = getchar();
while (!isdigit(c)) { if (c == '-') f = - f; c = getchar(); }
while (isdigit(c)) x = x * 10 + (c ^ 48), c = getchar();
return x * f;
}
#define lint long long int
#define ulint unsigned lint
#define readint read <int> ()
#define readlint read <lint> ()
const int inf = 1e9 + 1e7, MAXN = 3e3 + 1e1;
const lint INF = 1e18 + 1e9;
struct Edge { int nxt, to; lint w; } e[MAXN << 1];
lint dp[MAXN][MAXN][3], Ans = INF;
int head[MAXN], fa[MAXN], sz[MAXN], n, k, tot;
void Addedge(int u, int v, lint w) { e[++ tot] = (Edge){head[u], v, w}, head[u] = tot; return ; }
void Dp(int u, int v, lint w) {
for (int i = min(sz[u], k); i >= 0; i --) for (int j = min(sz[v], k); j >= 0; j --) if (i + j <= k) {
dp[u][i + j][0] = min(dp[u][i][0] + dp[v][j][0] + w, dp[u][i + j][0]);
dp[u][i + j][1] = min(dp[v][j][1] + 2 * dp[u][i][0] + w, dp[u][i + j][1]);
dp[u][i + j][1] = min(dp[u][i][1] + 2 * (dp[v][j][0] + w), dp[u][i + j][1]);
dp[u][i + j][2] = min(dp[u][i][2] + 2 * (dp[v][j][0] + w), dp[u][i + j][2]);
dp[u][i + j][2] = min(dp[v][j][2] + 2 * (dp[u][i][0] + w), dp[u][i + j][2]);
dp[u][i + j][2] = min(dp[u][i][1] + dp[v][j][1] + w, dp[u][i + j][2]);
}
return ;
}
void Dfs(int u) {
sz[u] = 1;
for (int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to; lint w = e[i].w;
if (fa[u] == v) continue ; fa[v] = u;
Dfs(v), Dp(u, v, w);
sz[u] += sz[v];
}
return ;
}
int main(void) {
n = readint, k = readint;
for (int i = 1; i <= n; i ++) for (int j = 2; j <= k; j ++) {
for (int l = 0; l < 3; l ++) dp[i][j][l] = INF;
}
for (int i = 1; i < n; i ++) {
int u = readint, v = readint; lint w = readlint;
Addedge(u, v, w), Addedge(v, u, w);
}
fa[1] = 1, Dfs(1);
for (int i = 1; i <= n; i ++) Ans = min(dp[i][k][2], Ans);
printf("%lld\n", Ans);
return 0;
}