洛谷P3233 世界树

题意:给定一棵树和其中的k个关键点,每个点归属于离它最近的关键点(距离相同时取编号最小的)。求每个关键点管辖多少个点。

解:虚树 + DP。

虚树不解释。主要是DP。用二元组存虚树上每个点的归属和距离。这一部分是二次扫描与换根法。

然后把关键点改为虚树节点,统计每个虚树节点管辖多少个节点,用SIZ表示,初始时SIZ = siz,SIZ[RT] = n。

如果一条虚树边两端点的归属相同,那么SIZ[fa] -= siz[son]。

否则用树上倍增,在这条虚树边对应的原树路径上找到最靠上的、与son归属相同的点y,然后SIZ[fa] -= siz[y],SIZ[son] = siz[y]。

洛谷P3233 世界树
  1 #include <cstdio>
  2 #include <algorithm>
  3 #include <vector>
  4 #include <cstring>
  5 
  6 typedef long long LL;
  7 const int N = 300010, INF = 0x3f3f3f3f;
  8 
  9 struct Edge {
 10     int nex, v, len;
 11 }edge[N * 2], EDGE[N * 2]; int tp, TP;
 12 
 13 struct Node {
 14     int x, d;
 15     Node(int X = 0, int D = 0) {
 16         x = X;
 17         d = D;
 18     }
 19 }small[N];
 20 
 21 int e[N], E[N], RT, siz[N], d[N], num, pos[N], pw[N], Time, imp[N], stk[N], top, now[N], imp2[N];
 22 int ans[N], fr[N], SIZ[N], n, fa[N][20], use[N];
 23 
 24 inline void add(int x, int y) {
 25     tp++;
 26     edge[tp].v = y;
 27     edge[tp].nex = e[x];
 28     e[x] = tp;
 29     return;
 30 }
 31 
 32 inline void ADD(int x, int y) {
 33 //    printf("ADD %d %d \n", x, y);
 34     TP++;
 35     EDGE[TP].v = y;
 36     EDGE[TP].len = d[y] - d[x];
 37     EDGE[TP].nex = E[x];
 38     E[x] = TP;
 39     return;
 40 }
 41 
 42 inline bool cmp(const int &a, const int &b) {
 43     return pos[a] < pos[b];
 44 }
 45 
 46 void DFS_1(int x, int father) {
 47     fa[x][0] = father;
 48     d[x] = d[father] + 1;
 49     siz[x] = 1;
 50     pos[x] = ++num;
 51     for(int i = e[x]; i; i = edge[i].nex) {
 52         int y = edge[i].v;
 53         if(y == father) {
 54             continue;
 55         }
 56         DFS_1(y, x);
 57         siz[x] += siz[y];
 58     }
 59     return;
 60 }
 61 
 62 inline int lca(int x, int y) {
 63     if(d[x] > d[y]) {
 64         std::swap(x, y);
 65     }
 66     int t = pw[n];
 67     while(t >= 0 && d[x] < d[y]) {
 68         if(d[fa[y][t]] >= d[x]) {
 69             y = fa[y][t];
 70         }
 71         t--;
 72     }
 73     if(x == y) {
 74         return x;
 75     }
 76     t = pw[n];
 77     while(t >= 0 && fa[x][0] != fa[y][0]) {
 78         if(fa[x][t] != fa[y][t]) {
 79             x = fa[x][t];
 80             y = fa[y][t];
 81         }
 82         t--;
 83     }
 84     return fa[x][0];
 85 }
 86 
 87 inline void clear(int x) {
 88     if(use[x] != Time) {
 89         use[x] = Time;
 90         E[x] = 0;
 91     }
 92     return;
 93 }
 94 
 95 inline void build_t(int k) {
 96     std::sort(imp + 1, imp + k + 1, cmp);
 97     TP = top = 0;
 98     clear(imp[1]);
 99     stk[++top] = imp[1];
100     for(int i = 2; i <= k; i++) {
101         int x = imp[i], y = lca(stk[top], x);
102         clear(x); clear(y);
103         while(top > 1 && pos[y] <= pos[stk[top - 1]]) {
104             ADD(stk[top - 1], stk[top]);
105             top--;
106         }
107         if(y != stk[top]) {
108             ADD(y, stk[top]);
109             stk[top] = y;
110         }
111         stk[++top] = x;
112     }
113     while(top > 1) {
114         ADD(stk[top - 1], stk[top]);
115         top--;
116     }
117     RT = stk[top];
118     return;
119 }
120 
121 void out_t(int x) {
122     printf("out x = %d \n", x);
123     for(int i = E[x]; i; i = EDGE[i].nex) {
124         int y = EDGE[i].v;
125 //        printf("EDGE %d y %d \n", i, y);
126         out_t(y);
127     }
128     return;
129 }
130 
131 void getSmall(int x) {
132     (now[x] == Time) ? small[x] = Node(x, 0) : small[x] = Node(n + 1, INF);
133 //    printf("getSmall x = %d small = %d \n", x, small[x].x);
134     SIZ[x] = siz[x];
135     for(int i = E[x]; i; i = EDGE[i].nex) {
136         int y = EDGE[i].v;
137         getSmall(y);
138         if(small[x].d > small[y].d + EDGE[i].len) {
139             small[x] = small[y];
140             small[x].d += EDGE[i].len;
141         }
142         else if(small[x].d == small[y].d + EDGE[i].len) {
143             small[x].x = std::min(small[x].x, small[y].x);
144         }
145     }
146     return;
147 }
148 
149 void getEXsmall(int x, Node t) {
150 //    printf("EX x = %d small = %d \n", x, small[x].x);
151     if(small[x].d > t.d || (small[x].d == t.d && small[x].x > t.x)) {
152         small[x] = t;
153     }
154 //    printf("x = %d small = %d \n", x, small[x].x);
155     for(int i = E[x]; i; i = EDGE[i].nex) {
156         int y = EDGE[i].v;
157         getEXsmall(y, Node(small[x].x, small[x].d + EDGE[i].len));
158     }
159     return;
160 }
161 
162 inline int getPos(int x, int f) {
163     int t = pw[d[x] - d[f]], y = x;
164     while(t >= 0) {
165         int mid = fa[y][t];
166         if(d[x] - d[mid] + small[x].d < d[mid] - d[f] + small[f].d) {
167             y = mid;
168         }
169         else if(d[x] - d[mid] + small[x].d == d[mid] - d[f] + small[f].d && small[f].x > small[x].x) {
170             y = mid;
171         }
172         t--;
173     }
174     return y;
175 }
176 
177 void del(int x, int f) {
178 //    printf("del x = %d small = %d %d \n", x, small[x].x, small[x].d);
179     if(f) {
180         if(small[x].x == small[f].x) {
181             SIZ[f] -= siz[x];
182 //            printf("SIZ %d -= %d = %d \n", f, siz[x], SIZ[f]);
183         }
184         else {
185             int y = getPos(x, f);
186             SIZ[f] -= siz[y];
187 //            printf("SIZ %d -= %d = %d \n", f, siz[y], SIZ[f]);
188             SIZ[x] = siz[y];
189 //            printf("SIZ %d = siz %d  %d \n", x, y, siz[y]);
190         }
191     }
192     for(int i = E[x]; i; i = EDGE[i].nex) {
193         int y = EDGE[i].v;
194         del(y, x);
195     }
196     ans[small[x].x] += SIZ[x];
197     return;
198 }
199 
200 inline void solve() {
201     int k;
202     scanf("%d", &k);
203     Time++;
204     for(int i = 1; i <= k; i++) {
205         scanf("%d", &imp[i]);
206         now[imp[i]] = Time;
207         ans[imp[i]] = 0;
208     }
209     memcpy(imp2 + 1, imp + 1, k * sizeof(int));
210     build_t(k);
211 
212 //    out_t(RT);
213     // get Small
214     getSmall(RT); // get Size
215     getEXsmall(RT, Node(n + 1, INF));
216     //
217     SIZ[RT] = n;
218     del(RT, 0);
219     for(int i = 1; i <= k; i++) {
220         printf("%d ", ans[imp2[i]]);
221     }
222     printf("\n");
223     return;
224 }
225 
226 int main() {
227 
228 //    freopen("in.in", "r", stdin);
229 //    freopen("a.out", "w", stdout);
230 
231     scanf("%d", &n);
232     for(int i = 1, x, y; i < n; i++) {
233         scanf("%d%d", &x, &y);
234         add(x, y);
235         add(y, x);
236     }
237     DFS_1(1, 0);
238     for(int i = 2; i <= n; i++) {
239         pw[i] = pw[i >> 1] + 1;
240     }
241     for(int j = 1; j <= pw[n]; j++) {
242         for(int i = 1; i <= n; i++) {
243             fa[i][j] = fa[fa[i][j - 1]][j - 1];
244         }
245     }
246     int q;
247     scanf("%d", &q);
248     while(q--) {
249         solve();
250     }
251     return 0;
252 }
AC代码

 

上一篇:LeetCode——1438. 绝对差不超过限制的最长连续子数组(Longest Continuous Subarray With Absolute Diff...)[中等]——分析及代码(Java)


下一篇:mysql和Python3 连接 pymysql 模块