题意:给定树上k个关键点,每个点属于离它最近的关键点,距离相同时属于其中编号最小的那个。求每个关键点管辖多少个点。
解:虚树 + DP。
虚树不解释。主要是DP。用二元组存虚树上每个点的归属和距离。这一部分是二次扫描与换根法。
然后把关键点改为虚树节点,统计每个虚树节点管辖多少个节点,用SIZ表示,初始时SIZ = siz,SIZ[RT] = n。
如果一条虚树边两端点的归属相同,那么SIZ[fa] -= siz[son]。
否则用树上倍增,在这条虚树边对应的原树路径上找到最靠上的、归属与son相同的点y,然后SIZ[fa] -= siz[y],SIZ[son] = siz[y]。
// Labelled by its author as the accepted ("AC") code.
//
// Input: a tree with n nodes, then q queries. Each query lists k key nodes.
// Every tree node belongs to the key node closest to it (ties go to the key
// node with the smaller index). For each query the program prints, in input
// order, how many tree nodes each key node owns.
//
// Method: build the virtual tree of the key nodes, compute the owner of every
// virtual-tree node with a bottom-up pass followed by a top-down pass, then
// split every virtual edge between the owners of its two endpoints using
// binary lifting on the original tree.
//
// NOTE(review): DFS_1, getSmall, getEXsmall and del are recursive with one
// frame per tree level, so a path-shaped input needs a stack that can hold
// about n frames.
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cstring>

typedef long long LL;
const int N = 300010, INF = 0x3f3f3f3f;

// Linked adjacency-list entry. `edge` holds the input tree, `EDGE` holds the
// virtual tree of the current query; `len` is only written for EDGE.
struct Edge {
    int nex, v, len;
} edge[N * 2], EDGE[N * 2];
int tp, TP;  // number of entries used in edge / EDGE

// (owner, distance) pair: x is the owning key node, d the distance to it.
struct Node {
    int x, d;
    Node(int X = 0, int D = 0) {
        x = X;
        d = D;
    }
} small[N];

// e / E      : list heads for the input tree / the virtual tree
// RT         : root of the current virtual tree
// siz, d     : subtree size and depth in the input tree (d[1] == 1, d[0] == 0)
// pos, num   : DFS preorder index and its running counter
// pw[i]      : floor(log2(i)) for i >= 1
// Time       : query stamp; now[x] == Time marks x as a key node of this query,
//              use[x] == Time marks that E[x] was already reset for this query
// imp / imp2 : key nodes sorted by pos / in input order
// stk, top   : stack used while building the virtual tree
// SIZ[x]     : number of input-tree nodes credited to virtual-tree node x
// fa[x][j]   : the 2^j-th ancestor of x (0 when it does not exist)
int e[N], E[N], RT, siz[N], d[N], num, pos[N], pw[N], Time, imp[N], stk[N], top, now[N], imp2[N];
int ans[N], fr[N], SIZ[N], n, fa[N][20], use[N];

// Appends the directed edge x -> y to the input tree.
inline void add(int x, int y) {
    tp++;
    edge[tp].v = y;
    edge[tp].nex = e[x];
    e[x] = tp;
    return;
}

// Appends the directed virtual-tree edge x -> y; its length is the depth
// difference of the two endpoints in the input tree.
inline void ADD(int x, int y) {
    TP++;
    EDGE[TP].v = y;
    EDGE[TP].len = d[y] - d[x];
    EDGE[TP].nex = E[x];
    E[x] = TP;
    return;
}

// Orders nodes by DFS preorder index.
inline bool cmp(const int &a, const int &b) {
    return pos[a] < pos[b];
}

// Fills fa[x][0], d, siz and pos for the subtree of x, whose parent is `father`.
void DFS_1(int x, int father) {
    fa[x][0] = father;
    d[x] = d[father] + 1;
    siz[x] = 1;
    pos[x] = ++num;
    for (int i = e[x]; i; i = edge[i].nex) {
        int y = edge[i].v;
        if (y == father) {
            continue;
        }
        DFS_1(y, x);
        siz[x] += siz[y];
    }
    return;
}

// Lowest common ancestor of x and y via binary lifting.
inline int lca(int x, int y) {
    if (d[x] > d[y]) {
        std::swap(x, y);
    }
    // Lift the deeper node y to the depth of x.
    int t = pw[n];
    while (t >= 0 && d[x] < d[y]) {
        if (d[fa[y][t]] >= d[x]) {
            y = fa[y][t];
        }
        t--;
    }
    if (x == y) {
        return x;
    }
    // Lift both nodes while their ancestors still differ.
    t = pw[n];
    while (t >= 0 && fa[x][0] != fa[y][0]) {
        if (fa[x][t] != fa[y][t]) {
            x = fa[x][t];
            y = fa[y][t];
        }
        t--;
    }
    return fa[x][0];
}

// Resets the virtual-tree list head of x once per query.
inline void clear(int x) {
    if (use[x] != Time) {
        use[x] = Time;
        E[x] = 0;
    }
    return;
}

// Builds the virtual tree of imp[1..k] (k >= 1) and stores its root in RT.
// imp is sorted by DFS order as a side effect.
inline void build_t(int k) {
    std::sort(imp + 1, imp + k + 1, cmp);
    TP = top = 0;
    clear(imp[1]);
    stk[++top] = imp[1];
    for (int i = 2; i <= k; i++) {
        int x = imp[i], y = lca(stk[top], x);
        clear(x);
        clear(y);
        // Pop every stack entry whose parent on the stack is still at or below y.
        while (top > 1 && pos[y] <= pos[stk[top - 1]]) {
            ADD(stk[top - 1], stk[top]);
            top--;
        }
        // y lies strictly above the stack top: it replaces the top as its parent.
        if (y != stk[top]) {
            ADD(y, stk[top]);
            stk[top] = y;
        }
        stk[++top] = x;
    }
    while (top > 1) {
        ADD(stk[top - 1], stk[top]);
        top--;
    }
    RT = stk[top];
    return;
}

// Debug helper: prints the virtual-tree nodes below x in preorder.
void out_t(int x) {
    printf("out x = %d \n", x);
    for (int i = E[x]; i; i = EDGE[i].nex) {
        int y = EDGE[i].v;
        out_t(y);
    }
    return;
}

// Bottom-up pass: small[x] becomes the best (distance, index) owner found
// inside the virtual subtree of x. Also initialises SIZ[x] to siz[x].
void getSmall(int x) {
    if (now[x] == Time) {
        small[x] = Node(x, 0);
    }
    else {
        small[x] = Node(n + 1, INF);
    }
    SIZ[x] = siz[x];
    for (int i = E[x]; i; i = EDGE[i].nex) {
        int y = EDGE[i].v;
        getSmall(y);
        if (small[x].d > small[y].d + EDGE[i].len) {
            small[x] = small[y];
            small[x].d += EDGE[i].len;
        }
        else if (small[x].d == small[y].d + EDGE[i].len) {
            small[x].x = std::min(small[x].x, small[y].x);
        }
    }
    return;
}

// Top-down pass: t is the best owner offered by the parent side. small[x] is
// replaced when t is closer, or equally close with a smaller index.
void getEXsmall(int x, Node t) {
    if (small[x].d > t.d || (small[x].d == t.d && small[x].x > t.x)) {
        small[x] = t;
    }
    for (int i = E[x]; i; i = EDGE[i].nex) {
        int y = EDGE[i].v;
        getEXsmall(y, Node(small[x].x, small[x].d + EDGE[i].len));
    }
    return;
}

// Climbs from x towards its virtual parent f with binary lifting. An ancestor
// `mid` is accepted when its distance to x's owner is smaller than its distance
// to f's owner, or equal while x's owner has the smaller index. Returns the
// last accepted node, or x itself when none is accepted.
inline int getPos(int x, int f) {
    int t = pw[d[x] - d[f]], y = x;
    while (t >= 0) {
        int mid = fa[y][t];
        if (d[x] - d[mid] + small[x].d < d[mid] - d[f] + small[f].d) {
            y = mid;
        }
        else if (d[x] - d[mid] + small[x].d == d[mid] - d[f] + small[f].d && small[f].x > small[x].x) {
            y = mid;
        }
        t--;
    }
    return y;
}

// Distributes node counts along the virtual edge f -> x (f == 0 for the root),
// recurses into the children, then adds SIZ[x] to the answer of x's owner.
void del(int x, int f) {
    if (f) {
        if (small[x].x == small[f].x) {
            // Same owner at both ends: only the subtree of x leaves f's share.
            SIZ[f] -= siz[x];
        }
        else {
            // Different owners: everything below the split node y goes to x.
            int y = getPos(x, f);
            SIZ[f] -= siz[y];
            SIZ[x] = siz[y];
        }
    }
    for (int i = E[x]; i; i = EDGE[i].nex) {
        int y = EDGE[i].v;
        del(y, x);
    }
    ans[small[x].x] += SIZ[x];
    return;
}

// Reads one query (k followed by k key nodes) and prints one answer line.
// A query whose header or key nodes cannot be read, or are out of range, is
// dropped without output instead of being processed with garbage values.
inline void solve() {
    int k = 0;
    if (scanf("%d", &k) != 1 || k < 0 || k > n) {
        return;
    }
    if (k == 0) {
        // No key nodes: nothing to build, the answer line is empty.
        printf("\n");
        return;
    }
    Time++;
    for (int i = 1; i <= k; i++) {
        if (scanf("%d", &imp[i]) != 1 || imp[i] < 1 || imp[i] > n) {
            return;
        }
        now[imp[i]] = Time;
        ans[imp[i]] = 0;
    }
    // build_t sorts imp, so keep the input order for printing.
    memcpy(imp2 + 1, imp + 1, k * sizeof(int));
    build_t(k);

    getSmall(RT);
    getEXsmall(RT, Node(n + 1, INF));
    // The virtual root is also credited with every node outside its own subtree.
    SIZ[RT] = n;
    del(RT, 0);
    for (int i = 1; i <= k; i++) {
        printf("%d ", ans[imp2[i]]);
    }
    printf("\n");
    return;
}

// Reads the tree, precomputes depths, sizes and the ancestor table, then
// answers q queries. Stops quietly on truncated or out-of-range input.
int main() {
    if (scanf("%d", &n) != 1 || n < 1 || n >= N) {
        return 0;
    }
    for (int i = 1, x, y; i < n; i++) {
        if (scanf("%d%d", &x, &y) != 2 || x < 1 || x > n || y < 1 || y > n) {
            return 0;
        }
        add(x, y);
        add(y, x);
    }
    DFS_1(1, 0);
    for (int i = 2; i <= n; i++) {
        pw[i] = pw[i >> 1] + 1;
    }
    for (int j = 1; j <= pw[n]; j++) {
        for (int i = 1; i <= n; i++) {
            fa[i][j] = fa[fa[i][j - 1]][j - 1];
        }
    }
    int q = 0;
    if (scanf("%d", &q) != 1) {
        return 0;
    }
    while (q-- > 0) {
        solve();
    }
    return 0;
}