1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
| #include <bits/stdc++.h> #define next ___________________________________________________________________ using namespace std;
using pii = pair<int, int>;
// Array capacity. Edge ids run from 1 to 2*(n-1) (each tree edge is stored in both
// directions), so this presumably assumes n <= 200000 -- confirm against the input limits.
constexpr int MAXN = 400005;

// Input header: n vertices, budget k of flagged vertices per path, m flagged vertices listed.
int n, k, m;
// flag[v] is 1 when vertex v was listed as flagged, else 0.
int flag[MAXN];
// Adjacency lists stored as linked edge records: head[v] is v's first edge id
// (0 terminates a list); to[id], w[id] and next[id] are the edge's target, weight
// and the next edge id with the same source vertex. (`next` is renamed by the macro
// at the top of the file so it does not collide with std::next.)
int head[MAXN], to[MAXN], next[MAXN], w[MAXN];
// Next free edge id; starts at 1 so that id 0 can mark the end of a list.
int tot = 1;
// dist[c]: within the subtree currently being scanned, the best d[] among vertices with num[] == c.
int dist[MAXN];
// Weight of the first edge read (used by main's n == 2 shortcut).
int tmp;
// Centroid search state: f[v] = largest remaining piece if v were removed,
// root = best vertex found so far, size[v] = subtree size, sum = assumed component size.
// NOTE: with `using namespace std;` under C++17, an unqualified `size` is ambiguous
// with std::size; refer to this array as `::size`.
int f[MAXN], root, size[MAXN], sum;
// vis[v] = 1 once v has been used as a centroid (excluded from later components).
int vis[MAXN];
// Best path weight found so far.
int ans;
// d[v]: summed edge weight from the current centroid to v;
// num[v]: flagged vertices on that path, not counting the centroid itself.
int d[MAXN], num[MAXN];
// maxn: largest num[] seen during the current getdep walk;
// maxv[c]: dist[] merged (by max) over the subtrees of the current centroid processed so far.
int maxn, maxv[MAXN];
// (largest num[] in the subtree, subtree's root vertex) for each child of the current centroid.
vector<pii> e;
// Appends the directed edge x -> y with weight z to x's adjacency list.
// The new record takes edge id `tot` (which is then advanced) and becomes
// the new head of x's list.
inline void add(int x, int y, int z) {
    const int id = tot++;
    to[id] = y;
    w[id] = z;
    next[id] = head[x];
    head[x] = id;
}
// Centroid search over the current component (vertices not marked in vis),
// walking from `now` without stepping back to `fa`.
// Fills ::size[] with subtree sizes rooted at the starting vertex and sets
// f[now] to the largest piece that would remain if `now` were removed, using
// the global `sum` as the component size. The global `root` is moved to the
// vertex with the smallest f[] seen; callers seed root = 0 and f[0] = sum.
//
// `::size` is qualified on purpose: with `using namespace std;` under C++17,
// an unqualified `size` is ambiguous with std::size and does not compile.
inline void getroot(int now, int fa) {
    ::size[now] = 1;
    f[now] = 0;
    for (int i = head[now]; i; i = next[i]) {
        int v = to[i];
        if (v == fa || vis[v]) continue;
        getroot(v, now);
        ::size[now] += ::size[v];
        f[now] = max(f[now], ::size[v]);
    }
    // The piece on the far side of `fa` (everything outside now's subtree).
    f[now] = max(f[now], sum - ::size[now]);
    if (f[now] < f[root]) root = now;
}
// Walks the subtree entered at `now` (never stepping back to `fa` or into vertices
// marked in vis) and propagates to every child:
//   d[child]   = d[parent] + edge weight,
//   num[child] = num[parent] + flag[child].
// Raises the global maxn to the largest num[] encountered.
// The caller seeds d[now] and num[now] (and resets maxn) before the first call.
inline void getdep(int now, int fa) {
    if (num[now] > maxn) maxn = num[now];
    for (int edge = head[now]; edge != 0; edge = next[edge]) {
        const int child = to[edge];
        if (child != fa && !vis[child]) {
            num[child] = num[now] + flag[child];
            d[child] = d[now] + w[edge];
            getdep(child, now);
        }
    }
}
// Visits every vertex reachable from `now` without crossing `fa` or vertices marked
// in vis, and for each one raises dist[num[v]] to at least d[v]. Afterwards dist[c]
// holds the best d[] among visited vertices whose flagged count num[] equals c.
inline void getmax(int now, int fa) {
    int &slot = dist[num[now]];
    if (d[now] > slot) slot = d[now];
    for (int edge = head[now]; edge != 0; edge = next[edge]) {
        const int child = to[edge];
        if (child == fa || vis[child]) continue;
        getmax(child, now);
    }
}
inline void solve(int now) { vis[now] = 1, e.clear(); if (flag[now]) k--; for (int i = head[now]; i; i = next[i]) { int v = to[i]; if (vis[v]) continue; num[v] = flag[v], d[v] = w[i], maxn = 0, getdep(v, now), e.push_back(make_pair(maxn, v)); } sort(e.begin(), e.end()); for (int i = 0; i < e.size(); i++) { getmax(e[i].second, now); int res = 0; if (i != 0) for (int j = e[i].first; j >= 0; j--) { while (res + 1 + j <= k && res + 1 <= e[i - 1].first) res++, maxv[res] = max(maxv[res], maxv[res - 1]); if (res + j <= k) ans = max(ans, maxv[res] + dist[j]); } if (i != e.size() - 1) for (int j = 0; j <= e[i].first; j++) maxv[j] = max(maxv[j], dist[j]), dist[j] = 0; else for (int j = 0; j <= e[i].first; j++) maxv[j] = dist[j] = 0; } if (flag[now]) k++; for (int i = head[now]; i; i = next[i]) { int v = to[i]; if (vis[v]) continue; root = 0, f[0] = sum = size[v]; getroot(v, 0), solve(root); } }
// Reads the tree, runs the centroid decomposition starting from vertex 1 and
// prints the best path weight found.
// Input as read below: "n k m", then m flagged vertex ids, then n-1 lines "u v w".
int main() {
    scanf("%d%d%d", &n, &k, &m);
    for (int left = m; left > 0; --left) {
        int id;
        scanf("%d", &id);
        flag[id] = 1;
    }

    int a, b, len;
    for (int edgeNo = 1; edgeNo < n; ++edgeNo) {
        scanf("%d%d%d", &a, &b, &len);
        if (edgeNo == 1) tmp = len;  // first edge weight, used by the n == 2 shortcut below
        add(a, b, len);
        add(b, a, len);
    }

    // Shortcut for a single-edge tree whose flagged count fits the budget:
    // print that edge's weight directly.
    // NOTE(review): this prints the weight even when it is negative, while the
    // general path below never reports less than ans's initial 0 -- confirm this
    // is what the judge expects.
    if (n == 2 && k >= m) {
        printf("%d\n", tmp);
        return 0;
    }

    sum = n;
    f[0] = sum;
    getroot(1, 0);
    solve(root);
    printf("%d\n", ans);
    return 0;
}
|