第一次做树剖
找同学要了模板 + 各种借鉴
先用 dfs 划分轻重链并编号。install 时，统计从查询节点到根的路径上已被标记（已安装）的点数，用该节点的深度减去这个数即为答案，然后把路径上的所有点都标记。
uninstall 时，统计查询节点子树中被标记的点数作为答案，并取消该子树内的所有标记。
代码如下
/**************************************************************
Problem: 4196
Language: C++
Result: Accepted
Time:8312 ms
Memory:7416 kb
****************************************************************/ #include <cstdio>
#include <vector>
#include <cstring>
using namespace std;
// BZOJ 4196: heavy-light decomposition + range-assignment segment tree.
//
// Input: n, then the parent of each node 1..n-1 (node 0 is the root), then q
// operations "install x" / "uninstall x".
//   install x   - mark every node on the path root..x and print how many of
//                 them were not marked before.
//   uninstall x - unmark every node in the subtree of x and print how many of
//                 them were marked before.
//
// NOTE(review): every numeric literal in the pasted original was lost; the
// values used here (array bound, root depth 1, -1 sentinels, 1-based
// positions) are reconstructed from how the code uses them.
namespace {

// NOTE(review): reconstructed upper bound; assumes n < N - confirm against
// the problem's limit on n.
const int N = 100010;

int n;                          // number of nodes
int tot;                        // positions handed out so far (== n after layout)
int dep[N];                     // depth with root = 1, so dep[u] = #nodes on root..u
int fa[N];                      // parent, -1 for the root
int heavy[N];                   // child with the largest subtree, -1 for leaves
int sz[N];                      // subtree size; deliberately not called `size`,
                                // which is ambiguous with C++17 std::size once
                                // `using namespace std;` is in effect
int top[N];                     // first node of the heavy chain containing u
int pos[N];                     // 1-based position in heavy-first DFS order
std::vector<int> children[N];   // children[p] = nodes whose parent is p

// Segment tree over positions 1..tot holding a 0/1 flag per position.
int sum[N << 1];                // number of 1s inside the node's range
int tag[N << 1];                // pending assignment (0 or 1); -1 = none

// Index of the tree node covering [l, r]: a leaf gets 2*l, an internal node
// gets 2*mid + 1 with mid = (l + r) / 2.  Each internal node splits at a
// distinct mid, so indices never collide and stay below 2 * tot + 2.
inline int node(int l, int r) { return (l + r) | (l != r); }

// Assigns v to every position of [l, r] lazily.
void applyTag(int l, int r, int v) {
    const int id = node(l, r);
    tag[id] = v;
    sum[id] = v * (r - l + 1);
}

// Forwards a pending assignment of internal node [l, r] to its children.
void pushDown(int l, int r) {
    const int id = node(l, r);
    if (tag[id] < 0) return;
    const int mid = (l + r) >> 1;
    applyTag(l, mid, tag[id]);
    applyTag(mid + 1, r, tag[id]);
    tag[id] = -1;
}

// Sets every position in [ql, qr] to v and returns how many of those
// positions were 1 beforehand.  [l, r] is the range of the current node.
// One pass replaces the former getSum() + modify() pair over the same range.
int assignRange(int l, int r, int ql, int qr, int v) {
    if (ql <= l && r <= qr) {
        const int before = sum[node(l, r)];
        applyTag(l, r, v);
        return before;
    }
    // A leaf that intersects [ql, qr] is always fully covered, so only
    // internal nodes reach this point.
    pushDown(l, r);
    const int mid = (l + r) >> 1;
    int before = 0;
    if (ql <= mid) before += assignRange(l, mid, ql, qr, v);
    if (qr > mid) before += assignRange(mid + 1, r, ql, qr, v);
    sum[node(l, r)] = sum[node(l, mid)] + sum[node(mid + 1, r)];
    return before;
}

// Fills dep, fa, sz and heavy for the tree rooted at 0.  The traversal is
// iterative because a dependency chain can be as deep as the tree is large.
void computeSizes() {
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> todo{0};
    dep[0] = 1;
    fa[0] = -1;
    while (!todo.empty()) {
        const int u = todo.back();
        todo.pop_back();
        order.push_back(u);
        for (int v : children[u]) {
            dep[v] = dep[u] + 1;
            fa[v] = u;
            todo.push_back(v);
        }
    }
    // Every node appears after its parent in `order`, so a reverse sweep
    // finishes each subtree before its root.
    for (int i = static_cast<int>(order.size()) - 1; i >= 0; --i) {
        const int u = order[i];
        sz[u] = 1;
        heavy[u] = -1;
        int best = 0;
        for (int v : children[u]) {
            sz[u] += sz[v];
            if (sz[v] > best) {  // strict '>': first largest child wins ties
                best = sz[v];
                heavy[u] = v;
            }
        }
    }
}

// Assigns top and pos in heavy-first preorder: each heavy chain occupies a
// contiguous run of positions, and the subtree of u occupies
// pos[u] .. pos[u] + sz[u] - 1.
void assignPositions() {
    std::vector<int> todo{0};
    top[0] = 0;
    while (!todo.empty()) {
        const int u = todo.back();
        todo.pop_back();
        pos[u] = ++tot;
        // Light children are pushed in reverse so they pop in input order;
        // the heavy child is pushed last so it is visited right after u.
        for (auto it = children[u].rbegin(); it != children[u].rend(); ++it) {
            const int v = *it;
            if (v == heavy[u]) continue;
            top[v] = v;
            todo.push_back(v);
        }
        if (heavy[u] != -1) {
            top[heavy[u]] = top[u];
            todo.push_back(heavy[u]);
        }
    }
}

// Marks the whole root..u path and returns how many nodes were newly marked.
int install(int u) {
    int changed = dep[u];
    while (top[u] != 0) {
        changed -= assignRange(1, tot, pos[top[u]], pos[u], 1);
        u = fa[top[u]];
    }
    // u is now on the root's chain, which starts at pos[0].
    changed -= assignRange(1, tot, pos[0], pos[u], 1);
    return changed;
}

// Unmarks the subtree of u and returns how many nodes were marked before.
int uninstall(int u) {
    return assignRange(1, tot, pos[u], pos[u] + sz[u] - 1, 0);
}

}  // namespace

int main() {
    if (std::scanf("%d", &n) != 1 || n <= 0) return 0;
    for (int i = 1; i < n; ++i) {
        int parent = 0;
        if (std::scanf("%d", &parent) != 1) return 0;
        children[parent].push_back(i);
    }
    computeSizes();
    assignPositions();
    // Filling with 0xff bytes makes every int -1, i.e. "no pending tag".
    std::memset(tag, 0xff, sizeof(tag));

    int q = 0;
    if (std::scanf("%d", &q) != 1) return 0;
    char op[16];
    while (q-- > 0) {
        int x = 0;
        if (std::scanf("%15s%d", op, &x) != 2) break;
        const int ans = (op[0] == 'i') ? install(x) : uninstall(x);
        std::printf("%d\n", ans);
    }
    return 0;
}