https://www.luogu.com.cn/problem/P4332
成功复习了LCT
首先发现状态改变的一定是叶子向上的一条路径
我们记录\(sum[u]\)表示\(u\)节点有几个儿子是\(1\),\(val[u]=sum[u]>1\)
一种做法是在\(Splay\)上维护\(sum\)的最大值和最小值,每次\(access\)之后在\(Splay\)上二分,复杂度为\(O(n\log^2 n)\)。
更好的做法是对\(Splay\)的每棵子树分别记录其中最深的\(sum \neq 1\)的点和最深的\(sum \neq 2\)的点,每次修改时只需把该点以下的这一段路径整体加减\(1\),再单独修改该点即可
代码实现有些细节要注意
code:
#include<bits/stdc++.h>
#define N 2500050
using namespace std;
/// Reads the next non-negative decimal integer from stdin.
///
/// Skips every byte that is not an ASCII digit, then accumulates the following
/// run of digits. The character is kept as `int` (the real return type of
/// getchar) so EOF is distinguishable from data on every platform.
///
/// @return the parsed value, or 0 if end of input is reached before any digit.
inline int rd() {
    int c = getchar();
    // Stop skipping at EOF; otherwise a truncated input would loop forever.
    while (c != EOF && (c < '0' || c > '9')) c = getchar();
    int x = 0;
    for (; c >= '0' && c <= '9'; c = getchar()) x = x * 10 + (c - '0');
    return x;
}
/**
 * Array-based link-cut tree (access/splay only; no link, cut or reroot).
 *
 * Node 0 is the null node. Per-node fields:
 *  - ch[x][0/1] : left/right child inside x's splay (left = shallower on the path).
 *  - fa[x]      : splay parent, or the path-parent when x is a splay root.
 *  - sum[x]     : counter maintained by the caller (number of children outputting 1).
 *  - val[x]     : sum[x] > 1, refreshed whenever padd touches x.
 *  - tg[x]      : lazy amount already applied to x but still owed to x's children.
 *  - id[x][1]   : deepest node in x's splay subtree with sum != 1 (0 if none).
 *  - id[x][2]   : deepest node in x's splay subtree with sum != 2 (0 if none).
 *                 id[x][0] is unused.
 */
struct LCT {
#define ls ch[x][0]
#define rs ch[x][1]
    int ch[N][2], val[N], id[N][3], tg[N], sum[N], fa[N];
    // Scratch buffer for pushall: the root-to-x chain of the current splay.
    std::vector<int> path;

    /// 1 if x is the right child of fa[x], else 0.
    int get(int x) { return ch[fa[x]][1] == x; }

    /// Non-zero iff x is a real splay child of fa[x] (i.e. x is not a splay root).
    int nrt(int x) { return ch[fa[x]][0] == x || ch[fa[x]][1] == x; }

    /// Recomputes id[x][1..2] from x's children and x's own sum.
    /// The right subtree is deeper, so it is preferred; then x itself; then the left subtree.
    /// Must only be called when x has no pending tag (children up to date).
    void update(int x) {
        id[x][1] = id[rs][1], id[x][2] = id[rs][2];
        if (!id[x][1]) {
            if (sum[x] != 1) id[x][1] = x;
            else id[x][1] = id[ls][1];
        }
        if (!id[x][2]) {
            if (sum[x] != 2) id[x][2] = x;
            else id[x][2] = id[ls][2];
        }
    }

    /// Adds o to sum of every node in x's splay subtree (lazily).
    /// Swapping id[x][1] and id[x][2] is the correct aggregate update only when the whole
    /// subtree currently holds sum == 1 (o = +1) or sum == 2 (o = -1); callers guarantee that.
    void padd(int x, int o) {
        if (!x) return;  // never scribble on the null node
        tg[x] += o, sum[x] += o, val[x] = sum[x] > 1;
        std::swap(id[x][1], id[x][2]);
    }

    /// Pushes x's pending tag to both children.
    void pushdown(int x) {
        if (tg[x]) {
            padd(ls, tg[x]), padd(rs, tg[x]);
            tg[x] = 0;
        }
    }

    /// Standard single rotation lifting x above its splay parent.
    void rotate(int x) {
        int f = fa[x], gf = fa[f], k = get(x);
        if (nrt(f)) ch[gf][get(f)] = x;  // only relink when f is a real child of gf
        fa[x] = gf;                      // unconditional: also carries the path-parent
        ch[f][k] = ch[x][!k];
        if (ch[x][!k]) fa[ch[x][!k]] = f;
        ch[x][!k] = f;
        fa[f] = x;
        update(f);
        update(x);
    }

    /// Pushes tags top-down along the chain from x's splay root to x.
    /// Iterative on purpose: a splay can degenerate into a chain as long as the whole
    /// preferred path, and one recursion frame per ancestor may overflow the call stack.
    void pushall(int x) {
        path.clear();
        path.push_back(x);
        while (nrt(x)) {
            x = fa[x];
            path.push_back(x);
        }
        for (std::size_t i = path.size(); i-- > 0;) pushdown(path[i]);
    }

    /// Splays x to the root of its splay tree.
    void splay(int x) {
        pushall(x);
        while (nrt(x)) {
            int f = fa[x];
            if (nrt(f)) rotate(get(f) == get(x) ? f : x);
            rotate(x);
        }
    }

    /// Makes the root-to-x path a single preferred path (x becomes its deepest node).
    void access(int x) {
        for (int y = 0; x; y = x, x = fa[x]) {
            splay(x), rs = y, update(x);
        }
    }
} T;
int n, m;
vector<int> g[N];
/// Orients the tree stored in g away from u and initialises the LCT counters.
///
/// For every visited node v other than u, sets T.fa[v] to its tree parent. For every
/// visited node w, sets T.sum[w] to the sum of T.val over its children; for w <= n it
/// also sets T.val[w] = T.sum[w] > 1 (nodes > n keep the val assigned by the caller).
///
/// Uses an explicit stack that mirrors the recursive post-order exactly: the tree can
/// be a chain of n nodes, and one call frame per level may overflow the call stack.
///
/// @param u   node to start from
/// @param fa  neighbour of u to treat as its parent (skipped during traversal)
void dfs(int u, int fa) {
    struct Frame {
        int node;
        int parent;
        std::size_t next;  // index of the next neighbour of `node` to examine
    };
    std::vector<Frame> stk;
    T.sum[u] = 0;
    stk.push_back(Frame{u, fa, 0});
    while (!stk.empty()) {
        Frame &top = stk.back();
        const int cur = top.node;
        if (top.next < g[cur].size()) {
            const int v = g[cur][top.next++];
            if (v == top.parent) continue;
            T.fa[v] = cur;
            T.sum[v] = 0;
            stk.push_back(Frame{v, cur, 0});  // invalidates `top`; it is not used afterwards
        } else {
            // All children of cur are finished: finalise cur, then report to its caller.
            if (cur <= n) T.val[cur] = T.sum[cur] > 1;
            stk.pop_back();
            if (!stk.empty()) T.sum[stk.back().node] += T.val[cur];
        }
    }
}
/// Reads the tree and the leaf values, then answers m leaf-flip queries.
///
/// Input: n; for each i in 1..n the three neighbours (inputs) of node i; the 0/1 values
/// of nodes n+1..3n+1; m; then m leaf ids. After each flip the value of node 1 is printed.
int main() {
    n = rd();
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= 3; j++) {
            int x = rd();
            g[x].push_back(i), g[i].push_back(x);
        }
    }
    for (int i = n + 1; i <= 3 * n + 1; i++) T.val[i] = rd();
    dfs(1, 1);
    m = rd();
    while (m--) {
        const int y = rd();
        const int x = T.fa[y];
        // Flipping y changes sum of its parent by o, and keeps propagating upwards
        // through every ancestor whose sum equals k (1 when adding, 2 when removing).
        const int o = T.val[y] ? -1 : 1;
        const int k = (o == 1) ? 1 : 2;
        T.access(x), T.splay(x);
        // Deepest node on the root..x path whose sum != k: the propagation stops there.
        const int z = T.id[x][k];
        if (z) {
            T.splay(z);
            // Everything deeper than z has sum == k: shift it as a block. padd already
            // maintains the child's id[]; calling update() on it here would recompute
            // from children that have not received the tag yet, so it is not called.
            if (const int below = T.ch[z][1]) T.padd(below, o);
            T.sum[z] += o;
            T.val[z] = T.sum[z] > 1;
            T.update(z);
        } else {
            // The whole path has sum == k: shift all of it (x is the splay root).
            T.padd(x, o);
        }
        T.val[y] ^= 1;
        T.splay(1);  // pushes pending tags down to node 1 so val[1] is current
        printf("%d\n", T.val[1]);
    }
    return 0;
}