| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
 100
 101
 102
 103
 104
 105
 106
 
 // Shared state for a centroid-decomposition search over a weighted tree:
 // find the heaviest path (sum of edge weights, never reported below 0) that
 // contains at most k flagged vertices.
 #include <algorithm>
 #include <cstdio>
 #include <utility>
 #include <vector>

 // Import only the std names this file uses. A blanket `using namespace std;`
 // makes the global arrays `next` and `size` below ambiguous with std::next
 // and std::size (C++17), so their unqualified uses stop compiling.
 using std::make_pair;
 using std::max;
 using std::pair;
 using std::sort;
 using std::vector;

 typedef pair<int, int> pii;
 // Capacity of every global array. The edge arrays receive 2 * (n - 1)
 // directed slots starting at index 1, so this presumably bounds n at about
 // 200000 -- confirm against the input limits.
 const int MAXN = 400005;

 // Input: vertex count, budget of flagged vertices per path, flagged count.
 int n, k, m;
 // flag[v] == 1 when vertex v is flagged.
 int flag[MAXN];
 // Adjacency lists: head[v] is v's first edge slot, next[] chains slots and
 // to[]/w[] give each slot's target and weight. Slot 0 terminates a list,
 // which is why slot allocation through `tot` starts at 1.
 int head[MAXN], to[MAXN], next[MAXN], w[MAXN], tot = 1;
 // Weight of the first input edge (used for the two-vertex case).
 int tmp;
 // Centroid search: f[v] is the largest piece left by removing v, size[v] a
 // DFS subtree size, sum the current component size, root the best vertex.
 int f[MAXN], root, size[MAXN], sum;
 // vis[v] == 1 once v has been used as a centroid; ans is the best path found.
 int vis[MAXN], ans;
 // Per-centroid tables: d[]/num[] are distance and flagged count measured
 // from the centroid, maxn the largest count in the current branch, dist[c]
 // the best distance with c flagged vertices in the current branch, and
 // maxv[c] the same merged over earlier branches.
 int d[MAXN], num[MAXN], maxn, dist[MAXN], maxv[MAXN];
 // One (largest flagged count, branch root) pair per branch of a centroid.
 vector<pii> e;
 
 // Pushes the directed edge x -> y (weight z) onto the front of x's
 // adjacency list, consuming the next free slot from `tot`.
 inline void add(int x, int y, int z) {
     const int slot = tot++;
     to[slot] = y;
     w[slot] = z;
     next[slot] = head[x];
     head[x] = slot;
 }
 
 // Centroid search. DFS from `now` (parent `fa`, 0 at the start) over the
 // vertices not yet marked in vis[]. Sets size[x] to x's subtree size in this
 // DFS and f[x] to the largest piece that would remain if x were removed,
 // taking the global `sum` as the component size. Whenever f[x] beats
 // f[root], x becomes the new `root`, so the caller must seed root / f[root]
 // first (this file uses root = 0 with f[0] = sum).
 // NOTE(review): recursive -- depth grows with the component's height.
 inline void getroot(int now, int fa) {
     size[now] = 1;
     f[now] = 0;
     for (int edge = head[now]; edge != 0; edge = next[edge]) {
         const int child = to[edge];
         if (child != fa && !vis[child]) {
             getroot(child, now);
             size[now] += size[child];
             if (size[child] > f[now])
                 f[now] = size[child];
         }
     }
     const int above = sum - size[now];  // piece on the parent side
     if (above > f[now])
         f[now] = above;
     if (f[now] < f[root])
         root = now;
 }
 
 // Walks the branch below `now` (entered from `fa`), skipping vertices marked
 // in vis[]. For each descendant v it sets d[v] = d[now] + edge weight and
 // num[v] = num[now] + flag[v]. It also raises `maxn` to the largest num[]
 // seen, including num[now]. The caller seeds d[now], num[now] and maxn.
 inline void getdep(int now, int fa) {
     if (num[now] > maxn)
         maxn = num[now];
     for (int edge = head[now]; edge; edge = next[edge]) {
         const int child = to[edge];
         if (child == fa || vis[child])
             continue;
         d[child] = d[now] + w[edge];
         num[child] = num[now] + flag[child];
         getdep(child, now);
     }
 }
 
 // For every vertex v in the branch below `now` (entered from `fa`, skipping
 // vertices marked in vis[]), raises dist[num[v]] to at least d[v]. It reads
 // the d[] / num[] values left by a prior getdep over the same branch.
 inline void getmax(int now, int fa) {
     int &best = dist[num[now]];
     if (d[now] > best)
         best = d[now];
     for (int edge = head[now]; edge; edge = next[edge]) {
         const int child = to[edge];
         if (child != fa && !vis[child])
             getmax(child, now);
     }
 }
 
 // Scores every path through the centroid `now`, then recurses into the
 // pieces left after removing it.
 //
 // Each unvisited neighbour starts a branch. For a branch, getdep/getmax fill
 // dist[c]: the best distance from `now` to a branch vertex with exactly c
 // flagged vertices (not counting `now`), floored at 0. Branches are handled
 // in increasing order of their largest count. Each branch is paired with the
 // prefix maxima of the earlier branches kept in maxv[], so no table is
 // scanned past the current branch's largest count. A flagged centroid
 // consumes one unit of the budget k while its paths are scored.
 //
 // Expects `sum` to hold this component's size and size[] the DFS subtree
 // sizes from the getroot call that selected `now`.
 inline void solve(int now) {
     const int total = sum;  // component size; recursion overwrites `sum`
     vis[now] = 1;
     e.clear();
     if (flag[now])
         k--;

     // One (largest flagged count, branch root) entry per branch.
     for (int i = head[now]; i; i = next[i]) {
         const int v = to[i];
         if (vis[v])
             continue;
         num[v] = flag[v];
         d[v] = w[i];
         maxn = 0;
         getdep(v, now);
         e.push_back(make_pair(maxn, v));
     }
     sort(e.begin(), e.end());

     const int branches = static_cast<int>(e.size());
     for (int i = 0; i < branches; i++) {
         const int top = e[i].first;
         getmax(e[i].second, now);

         // Paths with the centroid as an endpoint. Without this a centroid
         // with a single branch (any two-vertex component) never scores the
         // edge to its neighbour, since no second branch exists to merge with.
         for (int j = 0; j <= top && j <= k; j++)
             ans = max(ans, dist[j]);

         // Pair j flagged vertices in this branch with the best earlier branch
         // using at most k - j; maxv[] becomes a prefix maximum as res grows.
         if (i != 0) {
             const int prevTop = e[i - 1].first;
             int res = 0;
             for (int j = top; j >= 0; j--) {
                 while (res + 1 + j <= k && res + 1 <= prevTop) {
                     res++;
                     maxv[res] = max(maxv[res], maxv[res - 1]);
                 }
                 if (res + j <= k)
                     ans = max(ans, maxv[res] + dist[j]);
             }
         }

         // Fold this branch into maxv[] and clear dist[]. After the last
         // (largest) branch, clear both tables for the next centroid.
         for (int j = 0; j <= top; j++) {
             if (i + 1 < branches)
                 maxv[j] = max(maxv[j], dist[j]);
             else
                 maxv[j] = 0;
             dist[j] = 0;
         }
     }
     if (flag[now])
         k++;

     for (int i = head[now]; i; i = next[i]) {
         const int v = to[i];
         if (vis[v])
             continue;
         // size[v] is v's piece only if v lay below `now` in the getroot DFS;
         // otherwise v's piece is everything outside now's DFS subtree.
         const int piece = size[v] < size[now] ? size[v] : total - size[now];
         root = 0;
         f[0] = sum = piece;
         getroot(v, 0);
         solve(root);
     }
 }
 
 int main() {
 scanf("%d%d%d", &n, &k, &m);
 for (int i = 1; i <= m; i++) {
 int x;
 scanf("%d", &x), flag[x] = 1;
 }
 for (int i = 1, u, v, w; i < n; i++) {
 scanf("%d%d%d", &u, &v, &w);
 if (i == 1)
 tmp = w;
 add(u, v, w), add(v, u, w);
 }
 if (n == 2 && k >= m)
 return printf("%d\n", tmp), 0;
 f[0] = sum = n, getroot(1, 0), solve(root), printf("%d\n", ans);
 return 0;
 }
 
 |