题解 鼠树(留坑)

传送门

考场上思路对了,只差一个子树求和想不到如何做,于是喜提10pts

首先,用树剖可以动态维护每个点的归属点(即离它最近的黑色祖先,包括它自身):从当前点沿重链往上跳,在每段重链区间里找最靠下的黑点即可
先不考虑操作6(把黑点变白),可以把修改操作都挂在所归属的黑点上
问题在于子树求和,但其实很简单
在每个黑点上再维护一个域,存它管辖的所有点的权值和(即该黑点权值 × 管辖点个数)
这一部分可以在 dfs 序上做区间查询
但会有一部分白点统计不到:它们在子树内,却归属于子树外的黑点(只有子树根是白点时才会出现)
用子树大小减去子树内所有黑点管辖点个数之和,就能得到这部分白点的个数
而它们的权值都等于从子树根向上找到的最近黑点的权值
这样就可以处理前5个操作了

考虑删除(操作6)
难点在于如何保留被删除黑点原先管辖的那些点的权值,考场上没想到
做法是再开一棵线段树,存每个点额外的偏移量
删除黑点 k 时,对 k 的子树区间加上 (k 的权值 − 新归属点的权值),以留下这部分贡献
但这样子树内其它黑点管辖的点也被多加了,可以利用操作4的子树加,给这些黑点加上相反数减回来
代码有点长,不过相对还好写
本代码查询归属点时,每条重链各做一次线段树查询,单次 \(O(\log^2 n)\);题解说可以在链上维护信息,做到单次 \(O(\log n)\)
但我没看懂怎么做,留个坑吧

Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 300010
#define usd unsigned
#define ll long long 
//#define int long long 

// Fast stdin reader: `buf` holds one 2 MiB chunk of input, and [p1, p2) is the unread part.
char buf[1<<21], *p1=buf, *p2=buf;
// Refill `buf` from stdin when it is exhausted.
// Yields the next byte as a value in 0..255, or EOF when input is exhausted.
// The unsigned char cast keeps bytes >= 0x80 distinct from EOF and valid for isdigit().
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)
// Read the next decimal integer.
// Every '-' seen among the skipped non-digit characters flips the sign.
// Returns 0 if the input ends before any digit appears.
// The EOF checks stop the skip loop from spinning forever at end of input.
inline int read() {
	int ans=0, f=1, c=getchar();
	while (c!=EOF && !isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (c!=EOF && isdigit(c)) {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

int n, m;
// Child lists stored as a linked "forward star":
// - head[u] is the first edge out of u (-1 = none; main fills it with -1).
// - e[i].next is the following edge.
// - ecnt is the number of edges added so far.
// NOTE: the counter must not be called `size`. With `using namespace std;` in C++17
// it collides with std::size, and every unqualified use becomes ambiguous.
int head[N], ecnt;
// Per-vertex values and black flags, used by the brute-force solution in namespace force.
usd val[N];
bool black[N];
struct edge{int to, next;}e[N];
// Append the directed edge s -> t (main always calls it as parent -> child).
inline void add(int s, int t) {e[++ecnt].to=t; e[ecnt].next=head[s]; head[s]=ecnt;}

namespace force{
	// Brute-force reference solution: every operation walks the tree directly.
	// It is kept only for cross-checking; its call in main is commented out.

	// Add `dat` to u and to every descendant reachable from u without entering a black vertex.
	void dfs2(int u, usd dat) {
		val[u]+=dat;
		for (int edge_id=head[u]; edge_id!=-1; edge_id=e[edge_id].next) {
			const int child=e[edge_id].to;
			if (black[child]) continue;
			dfs2(child, dat);
		}
	}
	// Sum of val[] over the whole subtree of u (modulo 2^32).
	usd dfs3(int u) {
		usd total=0;
		for (int edge_id=head[u]; edge_id!=-1; edge_id=e[edge_id].next) total+=dfs3(e[edge_id].to);
		return total+val[u];
	}
	// For every black vertex b in the subtree of u, apply dfs2(b, dat).
	void dfs4(int u, usd dat) {
		if (black[u]) dfs2(u, dat);
		for (int edge_id=head[u]; edge_id!=-1; edge_id=e[edge_id].next) dfs4(e[edge_id].to, dat);
	}
	// Process all m queries, then terminate the process.
	void solve() {
		for (int step=1; step<=m; ++step) {
			const int op=read();
			const int k=read();
			switch (op) {
				case 1: printf("%u\n", val[k]); break;
				case 2: dfs2(k, read()); break;
				case 3: printf("%u\n", dfs3(k)); break;
				case 4: dfs4(k, read()); break;
				case 5: black[k]=1; break;
				default: black[k]=0; break;
			}
		}
		exit(0);
	}
}

namespace task{
	// Segment-tree + heavy-light-decomposition solution.
	//
	// Model: each black vertex b "governs" the vertices whose nearest black ancestor
	// (b itself included) is b. The root is made black at start-up, and qfa assumes
	// every vertex has a black ancestor.
	//
	// The value of a vertex x is
	//     val(leaf of x's governing black vertex) + offset(x).
	// offset(x) lives in segment tree #2 and is only changed by operation 6.
	// All arithmetic is on unsigned ints, i.e. modulo 2^32.
	//
	// HLD arrays, filled by dfs1/dfs2:
	//   siz       - subtree size.
	//   msiz/mson - size and index of the heavy child.
	//   id        - DFS position with the heavy child first, so subtree(u) = [id[u], id[u]+siz[u]-1].
	//   rk        - inverse of id.
	//   dep       - depth.
	//   top       - head of the heavy chain.
	//   fa        - parent.
	// sta/stop: scratch stack of segment-tree node indices. It is used to push lazy tags
	// from the root down to a single leaf before that leaf's fields are read directly.
	int siz[N], msiz[N], mson[N], id[N], rk[N], tot, dep[N], top[N], fa[N], sta[N], stop;
	// disable[t]: skip chain t in qfa. Nothing in this file ever sets it.
	bool disable[N];
	// Segment tree #1 is over DFS positions. A leaf is active iff its vertex is black.
	//   rid - segment-tree node index of the rightmost active leaf in the node (0 = none).
	//   val - (leaf) value of that black vertex.
	//   cnt - number of vertices that black vertex governs.
	//   sum - sum of val*cnt over the node.
	//   tag - pending add to the val of the node's active leaves.
	// Segment tree #2 has the same tl/tr shape. It is a plain range-add / range-sum tree
	// over per-vertex offsets: sum2 (sum), len (interval length), tag2 (pending add).
	int tl[N<<2], tr[N<<2], rid[N<<2]; usd sum[N<<2], val[N<<2], cnt[N<<2], tag[N<<2], sum2[N<<2], len[N<<2], tag2[N<<2];
	#define tl(p) tl[p]
	#define tr(p) tr[p]
	#define rid(p) rid[p]
	#define sum(p) sum[p]
	#define cnt(p) cnt[p]
	#define val(p) val[p]
	#define tag(p) tag[p]
	#define sum2(p) sum2[p]
	#define len(p) len[p]
	#define tag2(p) tag2[p]
	// Recompute p's rid/sum/cnt from its children.
	// Callers are expected to have spread(p) first, so no pending tag is lost.
	void pushup(int p) {
		rid(p)=rid(p<<1|1)?rid(p<<1|1):rid(p<<1);
		sum(p)=sum(p<<1)+sum(p<<1|1);
		cnt(p)=cnt(p<<1)+cnt(p<<1|1);
	}
	// Push p's pending tag into the children that contain at least one black leaf.
	// A child without black leaves is skipped: upd1 sets a leaf's val explicitly
	// when that leaf turns black.
	void spread(int p) {
		if (!tag(p)) return ;
		if (rid(p<<1)) {
			val(p<<1)+=tag(p);
			sum(p<<1)+=tag(p)*cnt(p<<1), tag(p<<1)+=tag(p);
		}
		if (rid(p<<1|1)) {
			val(p<<1|1)+=tag(p);
			sum(p<<1|1)+=tag(p)*cnt(p<<1|1), tag(p<<1|1)+=tag(p);
		}
		tag(p)=0;
	}
	// Record node intervals of segment tree #1 (tree #2 reuses tl/tr).
	void build(int p, int l, int r) {
		tl(p)=l; tr(p)=r;
		if (l==r) return ;
		int mid=(l+r)>>1;
		build(p<<1, l, mid);
		build(p<<1|1, mid+1, r);
	}
	// Add (op=1) or delete (op=0) a black leaf at position pos.
	// The leaf's val is overwritten with v and its cnt with c.
	void upd1(int p, int pos, bool op, usd v, usd c) {
		if (tl(p)==tr(p)) {rid(p)=op?p:0; val(p)=v; cnt(p)=c; sum(p)=v*c; return ;}
		spread(p);
		int mid=(tl(p)+tr(p))>>1;
		if (pos<=mid) upd1(p<<1, pos, op, v, c);
		else upd1(p<<1|1, pos, op, v, c);
		pushup(p);
	}
	// Add dat to the val of every black leaf in [l, r], lazily.
	// Nodes without black leaves are skipped.
	void upd2(int p, int l, int r, usd dat) {
		if (!rid(p)) return ;
		if (tl(p)==tr(p)) {sum(p)+=dat*cnt(p); val(p)+=dat; return ;}
		if (l<=tl(p) && r>=tr(p)) {sum(p)+=dat*cnt(p); tag(p)+=dat; return ;}
		spread(p);
		int mid=(tl(p)+tr(p))>>1;
		if (l<=mid) upd2(p<<1, l, r, dat);
		if (r>mid) upd2(p<<1|1, l, r, dat);
		pushup(p);
	}
	// Point update of the leaf at pos: val += v, cnt += c.
	// It does not check that the leaf is black.
	void upd3(int p, int pos, usd v, usd c) {
		if (tl(p)==tr(p)) {val(p)+=v; cnt(p)+=c; sum(p)=val(p)*cnt(p); return ;}
		spread(p);
		int mid=(tl(p)+tr(p))>>1;
		if (pos<=mid) upd3(p<<1, pos, v, c);
		else upd3(p<<1|1, pos, v, c);
		pushup(p);
	}
	// Segment-tree node index of the rightmost black leaf in [l, r], or 0 if there is none.
	int qpoint(int p, int l, int r) {
		if (l<=tl(p) && r>=tr(p)) return rid(p);
		spread(p);
		int mid=(tl(p)+tr(p))>>1, ans=0;
		if (r>mid) ans=qpoint(p<<1|1, l, r);
		if (ans) return ans;
		if (l<=mid) ans=qpoint(p<<1, l, r);
		return ans;
	}
	// val of the leaf at pos; the assert checks the leaf invariant sum == val*cnt.
	usd qval(int p, int pos) {
		if (tl(p)==tr(p)) {assert(sum(p)==val(p)*cnt(p)); return val(p);}
		spread(p);
		int mid=(tl(p)+tr(p))>>1;
		if (pos<=mid) return qval(p<<1, pos);
		else return qval(p<<1|1, pos);
	}
	// Total cnt over [l, r]: how many vertices are governed by the black vertices there.
	usd qcnt(int p, int l, int r) {
		if (l<=tl(p) && r>=tr(p)) return cnt(p);
		spread(p);
		int mid=(tl(p)+tr(p))>>1; usd ans=0;
		if (l<=mid) ans+=qcnt(p<<1, l, r);
		if (r>mid) ans+=qcnt(p<<1|1, l, r);
		return ans;
	}
	// Sum of val*cnt over [l, r].
	usd qsum(int p, int l, int r) {
		if (l<=tl(p) && r>=tr(p)) return sum(p);
		spread(p);
		int mid=(tl(p)+tr(p))>>1; usd ans=0;
		if (l<=mid) ans+=qsum(p<<1, l, r);
		if (r>mid) ans+=qsum(p<<1|1, l, r);
		return ans;
	}
	// Fill dep/fa/siz/msiz/mson for the subtree rooted at u; dep[u] must already be set.
	// This is iterative on purpose: a recursive DFS is ~n frames deep on a path-shaped
	// tree (n up to ~3e5), which can overflow a default-sized stack.
	void dfs1(int u) {
		std::vector<int> order, stk(1, u);
		// Pass 1: pre-order walk. Every parent is emitted before its children.
		while (!stk.empty()) {
			int x=stk.back(); stk.pop_back();
			order.push_back(x);
			for (int i=head[x],v; ~i; i=e[i].next) {
				v = e[i].to;
				dep[v]=dep[x]+1, fa[v]=x;
				stk.push_back(v);
			}
		}
		// Pass 2: process children before parents.
		// Children are scanned in adjacency order, so the heavy-child tie-break
		// (first strictly larger subtree) matches the recursive version.
		for (int j=(int)order.size()-1; j>=0; --j) {
			int x=order[j];
			siz[x]=1;
			for (int i=head[x],v; ~i; i=e[i].next) {
				v = e[i].to;
				siz[x]+=siz[v];
				if (siz[v]>msiz[x]) msiz[x]=siz[v], mson[x]=v;
			}
		}
	}
	// Assign DFS positions and chain heads, starting at u whose chain head is t.
	// Iterative, but it yields exactly the recursive pre-order: the heavy child first,
	// then the light children in adjacency order, each one started as a new chain.
	void dfs2(int u, int t) {
		std::vector<std::pair<int,int> > stk(1, std::make_pair(u, t));
		std::vector<int> light;
		while (!stk.empty()) {
			int x=stk.back().first, tp=stk.back().second; stk.pop_back();
			top[x]=tp;
			id[x]=++tot;
			rk[tot]=x;
			if (!mson[x]) continue;
			light.clear();
			for (int i=head[x]; ~i; i=e[i].next) if (e[i].to!=mson[x]) light.push_back(e[i].to);
			// LIFO order: push the light children reversed and the heavy child last,
			// so the heavy child's whole subtree is numbered next.
			for (int j=(int)light.size()-1; j>=0; --j) stk.push_back(std::make_pair(light[j], light[j]));
			stk.push_back(std::make_pair(mson[x], tp));
		}
	}
	// Segment-tree leaf index of the nearest black ancestor of vertex a (a itself included).
	// Positions along a heavy chain grow with depth, so the rightmost black leaf in
	// [id[top[a]], id[a]] is the deepest black vertex on that part of the chain.
	// Otherwise the search continues from the parent of the chain head.
	// Prints "error" and returns 0 if no black ancestor exists.
	int qfa(int a) {
		int ans=0;
		while (top[a]) {
			if (!disable[top[a]]) {
				ans=qpoint(1, id[top[a]], id[a]);
				if (ans) return ans;
			}
			a=fa[top[a]];
		}
		puts("error");
		return 0;
	}
	// Push p's pending offset add into both children (segment tree #2).
	void spread2(int p) {
		if (!tag2(p)) return ;
		sum2(p<<1)+=tag2(p)*len(p<<1), tag2(p<<1)+=tag2(p);
		sum2(p<<1|1)+=tag2(p)*len(p<<1|1), tag2(p<<1|1)+=tag2(p);
		tag2(p)=0;
	}
	// Record interval lengths of segment tree #2.
	void build2(int p, int l, int r) {
		len(p)=r-l+1;
		if (l==r) return ;
		int mid=(l+r)>>1;
		build2(p<<1, l, mid);
		build2(p<<1|1, mid+1, r);
	}
	// Add dat to the offset of every position in [l, r].
	void upd4(int p, int l, int r, usd dat) {
		if (l<=tl(p) && r>=tr(p)) {sum2(p)+=dat*len(p); tag2(p)+=dat; return ;}
		spread2(p);
		int mid=(tl(p)+tr(p))>>1;
		if (l<=mid) upd4(p<<1, l, r, dat);
		if (r>mid) upd4(p<<1|1, l, r, dat);
		sum2(p)=sum2(p<<1)+sum2(p<<1|1);
	}
	// Sum of offsets over [l, r].
	usd qdlt(int p, int l, int r) {
		if (l<=tl(p) && r>=tr(p)) return sum2(p);
		spread2(p);
		int mid=(tl(p)+tr(p))>>1; usd ans=0;
		if (l<=mid) ans+=qdlt(p<<1, l, r);
		if (r>mid) ans+=qdlt(p<<1|1, l, r);
		return ans;
	}
	// Build everything, answer the m queries, then terminate the process.
	void solve() {
		usd ans, t1, t2;
		dep[1]=1; dfs1(1); dfs2(1, 1); build(1, 1, n); build2(1, 1, n);
		// Initially only the root is black. It governs all n vertices with value 0.
		upd1(1, id[1], 1, 0, n);
		for (int i=1,op,k,f; i<=m; ++i) {
			op=read(); k=read();
			if (op==1) {
				// Value of k = value of its governing black vertex + k's offset.
				// f is a leaf index, so first push every tag on its root path,
				// then read val[f] directly.
				stop=0; f=qfa(k);
				for (int j=f>>1; j; j>>=1) sta[++stop]=j;
				while (stop) spread(sta[stop]), spread2(sta[stop--]);
				printf("%u\n", val[f]+qdlt(1, id[k], id[k]));
			}
			// Add w to the black vertex k, i.e. to everything it governs.
			// Presumably k is black here; it is not checked.
			else if (op==2) upd3(1, id[k], read(), 0);
			else if (op==3) {
				// Subtree sum is the sum of three parts:
				//   - val[f] * (vertices of subtree(k) governed from outside it),
				//   - the val*cnt of the black vertices inside subtree(k),
				//   - all offsets inside subtree(k).
				ans=0; f=qfa(k); stop=0;
				for (int j=f>>1; j; j>>=1) sta[++stop]=j;
				while (stop) spread(sta[stop--]);
				ans+=val[f]*(siz[k]-qcnt(1, id[k], id[k]+siz[k]-1));
				ans+=qsum(1, id[k], id[k]+siz[k]-1);
				ans+=qdlt(1, id[k], id[k]+siz[k]-1);
				printf("%u\n", ans);
			}
			else if (op==4) {
				// Add w to every black vertex in subtree(k), i.e. to what each of them governs.
				upd2(1, id[k], id[k]+siz[k]-1, read());
			}
			else if (op==5) {
				// Make k black (presumably k is white here; it is not checked).
				// k takes over the vertices of subtree(k) that are not governed by a black
				// vertex inside it, and inherits their current governor f's value.
				f=qfa(k); int newsiz=siz[k]-qcnt(1, id[k], id[k]+siz[k]-1);
				stop=0;
				for (int j=f>>1; j; j>>=1) sta[++stop]=j;
				while (stop) spread(sta[stop--]);
				upd1(1, id[k], 1, val[f], newsiz);
				cnt[f]-=newsiz; sum(f)=val(f)*cnt(f);
				for (int j=f>>1; j; j>>=1) pushup(j);
			}
			else {
				// Make k white (presumably k is black and not the root).
				// Its vertices go back to f, the black ancestor above k.
				// Adding (val k - val f) to the offsets of subtree(k) keeps those values
				// unchanged. The black vertices still inside subtree(k) get
				// (val f - val k) to cancel that offset for what they govern.
				f=qfa(fa[k]); stop=0;
				for (int j=f>>1; j; j>>=1) sta[++stop]=j;
				while (stop) spread(sta[stop--]);
				t1=val(f); t2=qval(1, id[k]);
				upd4(1, id[k], id[k]+siz[k]-1, t2-t1);
				upd1(1, id[k], 0, 0, 0);
				upd2(1, id[k], id[k]+siz[k]-1, t1-t2);
				int newsiz=siz[k]-qcnt(1, id[k], id[k]+siz[k]-1);
				cnt(f)+=newsiz; sum(f)=val(f)*cnt(f);
				for (int j=f>>1; j; j>>=1) pushup(j);
			}
		}
		exit(0);
	}
}

signed main()
{
	memset(head, -1, sizeof(head));
	n=read(); m=read();
	black[1]=1;
	for (int i=2; i<=n; ++i) add(read(), i);
	//force::solve();
	task::solve();
	
	return 0;
}
上一篇:撤回骗分: 莫队Ⅴ


下一篇:Python 1