题解 Revive

传送门

式子本身人畜无害,但有一个地方误导性极强:

  • 给定一棵带边权的树,要求支持修改边权,并查询一个点到其子树内所有点的距离之和
    这个子问题并没有时间复杂度可以接受的解法!

考场上就卡死在这里了……以为把它维护出来就能 AC,结果死活维护不出来
正解是另一种思路:
题解 Revive

  • \((\sum a_i)^2 = \sum a_i^2 + \sum\limits_{x<y} 2a_xa_y\)
    即:一条路径长度的平方,等于路径上各边权的平方和,再加上路径上每一对边的边权乘积的两倍。

首先,初始的 \(ans\) 可以很方便地用树形 DP 求出来:对每个点 \(u\) 维护 \(\sum\limits_{v\in \mathrm{subtree}(u)}dis(u, v)\) 和 \(\sum\limits_{v\in \mathrm{subtree}(u)}dis^2(u, v)\),合并子树时考虑跨过 \(u\) 的路径即可
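修改时,若边 \(a\) 的边权增加 \(w\),则每条经过 \(a\) 的路径长度都增加 \(w\),答案增加 \(cnt_a\cdot w^2 + 2w\sum\limits_{b} val_b\cdot cnt(a, b)\),其中 \(cnt(a, b)\) 为同时经过边 \(a, b\) 的路径数(\(b=a\) 时即为经过 \(a\) 的路径数 \(cnt_a\))。
于是考虑如何维护同时经过 \(a, b\) 两条边的路径数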
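\(a\) 是给定的,\(b\) 要分情况讨论(下面 \(siz_x\) 表示边 \(x\) 下端点的子树大小):\(b\) 在 \(a\) 的子树里,路径数为 \(siz_b\cdot(n-siz_a)\);\(b\) 在 \(a\) 到根节点的路径上,路径数为 \(siz_a\cdot(n-siz_b)\);\(b\) 在其它子树中,路径数为 \(siz_a\cdot siz_b\)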
这三种情况都可以拆到 dfs 序(进入/离开子树的时间戳)上,用树状数组维护
case 1:直接维护
case 2:差分一下也可以直接维护
case 3:开两个树状数组,把贡献记在离开子树的时间戳上,再扣掉祖先链上的贡献,同样可以维护
复杂度 \(O(n+(n+q)\cdot 13\log n)\),完全不卡常、跑得巨快,吊打复杂度为 \(O(n+q\cdot 2\log n)\) 的线段树做法

Code:
#include <bits/stdc++.h>
using namespace std;
// Compile-time constants and the 64-bit integer alias used throughout the file
// (typed constants instead of object-like macros, so they obey scope and type checking).
constexpr int INF = 0x3f3f3f3f;   // conventional "infinity"; not referenced elsewhere in this file
constexpr int N = 100010;         // capacity for per-node arrays (edge / timestamp arrays use N<<1)
using ll = long long;
//#define int long long 

// Fast input: stdin is read in 2 MiB chunks into buf; [p1, p2) is the still-unread part.
char buf[1<<21], *p1=buf, *p2=buf;

// Next input byte as a value in [0, 255], or EOF once stdin is exhausted.
// Returning it through unsigned char keeps bytes >= 0x80 from reaching isdigit() as
// negative values, which is undefined behaviour. This is a plain function rather than a
// redefinition of the standard getchar() name.
inline int read_char() {
	if (p1 == p2) {
		p1 = buf;
		p2 = buf + fread(buf, 1, sizeof(buf), stdin);
		if (p1 == p2) return EOF;
	}
	return static_cast<unsigned char>(*p1++);
}

// Read a decimal integer. Non-digit bytes before it are skipped; every '-' among them flips
// the sign. The byte right after the number is consumed. If the input ends, the value read so
// far (0 if no digit was seen) is returned instead of spinning forever on EOF.
inline int read() {
	int ans = 0, f = 1;
	int c = read_char();
	while (c != EOF && !isdigit(c)) {
		if (c == '-') f = -f;
		c = read_char();
	}
	while (c != EOF && isdigit(c)) {
		ans = ans * 10 + (c - '0');
		c = read_char();
	}
	return ans * f;
}

// n: node count, q: number of updates, sub: subtask id read from the input (not used to dispatch).
int n, q, sub;
// Adjacency lists: head[u] is u's first edge index (-1 = none) and ecnt is the number of stored edges.
// ecnt was previously called `size`, which is ambiguous with std::size under `using namespace std;`
// in C++17 and made add() fail to compile.
// dep: depth (the solvers set dep[1]=1); id/rk: dfs position and its inverse; top: heavy-chain top;
// fa: parent; msiz/mson: size and index of the largest child. tot is the shared dfs-position counter;
// it starts at 1, so the first position handed out by ++tot is 2.
int head[N], ecnt, dep[N], id[N], rk[N], tot=1, top[N], fa[N], msiz[N], mson[N];
// val[i]: weight attached to node i (read together with its edge in main); siz: subtree sizes.
// sum is not referenced anywhere in this file.
ll val[N], sum[N], siz[N];
const ll mod=1e9+7;
// a += b, then subtract mod once; the result is fully reduced only when a + b < 2*mod.
inline void md(ll& a, ll b) {a+=b; a=a>=mod?a-mod:a;}
struct edge{int to, next;}e[N<<1];
// Append directed edge s -> t to the front of s's list; edge indices start at 1.
inline void add(int s, int t) {e[++ecnt].to=t; e[ecnt].next=head[s]; head[s]=ecnt;}

// Brute-force reference solution (main in this file does not call it).
// The answer is the sum, over all unordered node pairs (i, j), of dist(i, j)^2 mod 1e9+7.
// It is recomputed from scratch after every update using O(n^2) path queries. Each query is
// answered by heavy-light decomposition over a Fenwick tree that holds val[v] at id[v].
namespace force{
	// Fenwick tree over dfs positions (positions start at 2 because tot starts at 1).
	ll bin[N];
	// Point add via md; md subtracts mod at most once, so entries stay congruent mod p but
	// are not necessarily fully reduced when dat is not already in [0, mod).
	inline void upd(int i, ll dat) {for (; i<=tot; i+=i&-i) md(bin[i], dat);}
	// Prefix sum over positions 1..i.
	inline ll query(int i) {ll ans=0; for (; i; i-=i&-i) md(ans, bin[i]); return ans;}
	// Store each non-root node's weight at its dfs position.
	void build() {for (int i=2; i<=n; ++i) upd(id[i], val[i]%mod);}
	// First HLD pass: depth, parent, subtree size and heavy child of every node.
	void dfs1(int u, int pa) {
		siz[u]=1;
		// ~i is 0 only for i == -1, the end-of-list sentinel set in main.
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==pa) continue;
			dep[v]=dep[u]+1, fa[v]=u, dfs1(v, u);
			siz[u]+=siz[v];
			if (siz[v]>msiz[u]) msiz[u]=siz[v], mson[u]=v;
		}
	}
	// Second HLD pass: chain tops and dfs order. The heavy child is visited first so every
	// heavy chain occupies a contiguous range of positions.
	void dfs2(int u, int f, int t) {
		top[u]=t;
		id[u]=++tot;
		rk[tot]=u;
		if (!mson[u]) return ;
		dfs2(mson[u], u, t);
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==f || v==mson[u]) continue;
			dfs2(v, u, v);
		}
	}
	// Sum of val[] over the nodes on the a-b path excluding their LCA, i.e. the path length
	// when val[v] is the weight of edge (fa[v], v). Returns a value in [0, mod).
	ll qsum(int a, int b) {
		ll ans=0;
		while (top[a]!=top[b]) {
			if (dep[top[a]]<dep[top[b]]) swap(a, b);
			ans=(ans+query(id[a])-query(id[top[a]]-1))%mod;
			a=fa[top[a]];
		}
		if (dep[a]>dep[b]) swap(a, b);
		// Same chain now: a is the LCA; its own slot (id[a]) is excluded.
		ans=(ans+query(id[b])-query(id[a]))%mod;
		return (ans+mod)%mod;
	}
	// Prints the initial answer, then handles q updates "u w" (val[u] += w) by recomputing
	// every pair. Terminates the program via exit(0).
	void solve() {
		dep[1]=1; dfs1(1, 0); dfs2(1, 0, 1); build();
		ll ans=0, tem;
		for (int i=1; i<=n; ++i)
			for (int j=i+1; j<=n; ++j) {
				tem=qsum(i, j);
				ans=(ans+tem*tem)%mod;
			}
		printf("%lld\n", ans);
		for (int i=1,u,w; i<=q; ++i) {
			u=read(); w=read();
			// NOTE(review): w is added unreduced; qsum's final % mod normalizes the result.
			upd(id[u], w);
			ans=0;
			for (int j=1; j<=n; ++j)
				for (int k=j+1; k<=n; ++k) {
					tem=qsum(j, k);
					ans=(ans+tem*tem)%mod;
				}
			printf("%lld\n", ans);
		}
		exit(0);
	}
}

// Alternative solution (main in this file does not call it). The initial answer, the sum of
// dist(x, y)^2 over all unordered pairs mod 1e9+7, comes from a tree DP. Each update walks
// every ancestor of u, so an update costs O(depth(u)).
namespace task1{
	// bin: Fenwick tree filled by build(); solve() never queries it (qsum is unused here).
	// sva[u] / sva2[u]: sum of dist(u, v) / dist(u, v)^2 over v in subtree(u), mod p.
	// ans: running answer; it may hold a negative residue after updates.
	ll bin[N], sva[N], sva2[N], ans;
	inline void upd(int i, ll dat) {for (; i<=tot; i+=i&-i) md(bin[i], dat);}
	inline ll query(int i) {ll ans=0; for (; i; i-=i&-i) md(ans, bin[i]); return ans;}
	void build() {for (int i=2; i<=n; ++i) upd(id[i], val[i]%mod);}
	// Tree DP. Besides siz/heavy child it computes sva/sva2 and adds to ans the squared lengths
	// of all pairs whose LCA is u. When child v is merged, each node already in u's partial
	// subtree, other than u itself, pairs with every node of subtree(v). Through edge (u, v) the
	// shifted sums are
	//   S1 = sva[v] + siz[v]*val[v]   and   S2 = sva2[v] + 2*val[v]*sva[v] + siz[v]*val[v]^2.
	// Pairs (u, y) are added once at the end through sva2[u].
	void dfs1(int u, int pa) {
		siz[u]=1;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==pa) continue;
			dep[v]=dep[u]+1, fa[v]=u, dfs1(v, u);
			// (x, y) with x in the partial subtree, y in subtree(v): (d(u,x) + d(u,y))^2 summed.
			ans = (ans + (siz[v]*sva2[u])%mod
			 + (siz[u]-1)*((sva2[v]+2*val[v]*sva[v]%mod+siz[v]*val[v]%mod*val[v]%mod)%mod)%mod
			 + 2*sva[u]*(sva[v]+siz[v]*val[v]%mod)%mod)%mod;
			siz[u]+=siz[v];
			sva2[u]=(sva2[u]+sva2[v]+2*val[v]*sva[v]%mod+siz[v]*val[v]%mod*val[v]%mod)%mod;
			sva[u]=(sva[u]+sva[v]+siz[v]*val[v])%mod;
			if (siz[v]>msiz[u]) msiz[u]=siz[v], mson[u]=v;
		}
		ans=(ans+sva2[u])%mod;
	}
	// HLD ordering (heavy child first); only id/top/rk are set here.
	void dfs2(int u, int f, int t) {
		top[u]=t;
		id[u]=++tot;
		rk[tot]=u;
		if (!mson[u]) return ;
		dfs2(mson[u], u, t);
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==f || v==mson[u]) continue;
			dfs2(v, u, v);
		}
	}
	// Path sum over bin (same as force::qsum); not called by task1::solve.
	ll qsum(int a, int b) {
		ll ans=0;
		while (top[a]!=top[b]) {
			if (dep[top[a]]<dep[top[b]]) swap(a, b);
			ans=(ans+query(id[a])-query(id[top[a]]-1))%mod;
			a=fa[top[a]];
		}
		if (dep[a]>dep[b]) swap(a, b);
		ans=(ans+query(id[b])-query(id[a]))%mod;
		return (ans+mod)%mod;
	}
	// Prints the initial answer, then handles q updates "u w" (val[u] += w, with val[u] read
	// by dfs1 as the weight of edge (fa[u], u)). Terminates the program via exit(0).
	void solve() {
		dep[1]=1; dfs1(1, 0); dfs2(1, 0, 1); build();
		ll tem, dlt;
		printf("%lld\n", ans);
		for (int i=1,u,w; i<=q; ++i) {
			u=read(); w=read(); tem=0, dlt=0;
			// Walk up from u; dlt = d(u, now). Each step adds sum of d(u, y) over
			// y in subtree(now) \ subtree(lst), so tem ends as sum of d(u, y) for y outside subtree(u).
			for (int now=fa[u],lst=u; now; lst=now,now=fa[now]) {
				dlt = (dlt+val[lst])%mod;
				tem = (tem + sva[now]-sva[lst]-val[lst]*siz[lst]%mod + (siz[now]-siz[lst])*dlt%mod)%mod;
			}
			// Sum of d(x, y) over x in subtree(u), y outside it (old weights).
			tem=(tem*siz[u]%mod + sva[u]*(n-siz[u]))%mod;
			// Each such path grows by w: (d+w)^2 - d^2 = 2*w*d + w^2.
			ans = (ans+2ll*w*tem%mod+(n-siz[u])*siz[u]%mod*w%mod*w%mod)%mod;
			val[u]=(val[u]+w)%mod;
			// Every node of subtree(u) is now w farther from each proper ancestor of u.
			// sva2 is not refreshed; later updates only read sva, val and siz.
			for (int now=fa[u]; now; now=fa[now]) {
				sva[now]=(sva[now]+1ll*w*siz[u]%mod)%mod;
			}
			printf("%lld\n", (ans%mod+mod)%mod);
		}
		exit(0);
	}
}

// Solution used by main. The answer is the sum, over all unordered node pairs, of dist(x, y)^2
// mod 1e9+7. The initial value comes from a tree DP (dfs1). Each update "u w" adds w to val[u],
// which dfs1 treats as the weight of edge (fa[u], u), and is applied with a constant number of
// Fenwick-tree operations on the entry/exit timestamps in[] / out[] produced by dfs2.
namespace task{
	// Fenwick trees indexed by timestamps (every node owns two: in[u] and out[u]):
	//   in_tree    : val[b]*siz[b] at in[b]; the range (in[u], out[u]] covers the strict descendants of u.
	//   on_chain   : +val[b]*(n-siz[b]) at in[b], the negation at out[b]; the prefix up to in[u]-1
	//                sums the proper ancestors b of u.
	//   other_tree : val[b]*siz[b] at out[b].
	//   other_dlt  : -val[b]*siz[b] at in[b], the negation at out[b]; the prefix up to in[u]-1 is
	//                minus the same quantity summed over the proper ancestors of u.
	// Only non-root nodes (b >= 2) are ever inserted. sva/sva2 are as in the tree DP below.
	// Entries and ans are signed residues in (-mod, mod); the output is normalized when printed.
	ll in_tree[N<<1], on_chain[N<<1], other_tree[N<<1], other_dlt[N<<1], sva[N], sva2[N], ans, in[N], out[N];
	inline void upd(ll* bin, int i, ll dat) {for (; i<=tot; i+=i&-i) bin[i]=(bin[i]+dat)%mod;}
	inline ll query(ll* bin, int i) {ll ans=0; for (; i; i-=i&-i) ans=(ans+bin[i])%mod; return ans;}
	// Insert every non-root node's initial contributions into the four trees.
	void build() {
		for (int i=2; i<=n; ++i) {
			//cout<<"i: "<<i<<' '<<in[i]<<' '<<out[i]<<endl;
			upd(in_tree, in[i], val[i]*siz[i]%mod);
			upd(on_chain, in[i], val[i]*(n-siz[i])%mod);
			upd(on_chain, out[i], -val[i]*(n-siz[i])%mod);
			upd(other_dlt, in[i], -val[i]*siz[i]%mod);
			upd(other_dlt, out[i], val[i]*siz[i]%mod);
			upd(other_tree, out[i], val[i]*siz[i]%mod);
		}
	}
	// Tree DP. It fills siz, sva[u] = sum of dist(u, v) and sva2[u] = sum of dist(u, v)^2 over
	// v in subtree(u), and the heavy child. It also adds to ans the squared lengths of all pairs
	// whose LCA is u. When child v is merged, each node of u's partial subtree other than u pairs
	// with every node of subtree(v); through edge (u, v) the shifted sums are
	//   sva[v] + siz[v]*val[v]   and   sva2[v] + 2*val[v]*sva[v] + siz[v]*val[v]^2.
	// Pairs (u, y) are added once at the end through sva2[u].
	void dfs1(int u, int pa) {
		//cout<<"dfs1 "<<u<<' '<<pa<<endl;
		siz[u]=1;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==pa) continue;
			dep[v]=dep[u]+1, fa[v]=u, dfs1(v, u);
			ans = (ans + (siz[v]*sva2[u])%mod
			 + (siz[u]-1)*((sva2[v]+2*val[v]*sva[v]%mod+siz[v]*val[v]%mod*val[v]%mod)%mod)%mod
			 + 2*sva[u]*(sva[v]+siz[v]*val[v]%mod)%mod)%mod;
			siz[u]+=siz[v];
			sva2[u]=(sva2[u]+sva2[v]+2*val[v]*sva[v]%mod+siz[v]*val[v]%mod*val[v]%mod)%mod;
			sva[u]=(sva[u]+sva[v]+siz[v]*val[v]%mod)%mod;
			if (siz[v]>msiz[u]) msiz[u]=siz[v], mson[u]=v;
		}
		ans=(ans+sva2[u])%mod;
	}
	// DFS in HLD order (heavy child first). in[u] (= id[u]) is taken on entry and out[u] on exit,
	// both from the shared counter tot. Since tot starts at 1, timestamps run from 2 to 2n+1,
	// which fits the N<<1 arrays as long as n < N.
	void dfs2(int u, int f, int t) {
		top[u]=t;
		in[u]=id[u]=++tot;
		rk[tot]=u;
		if (!mson[u]) {out[u]=++tot; return ;}
		dfs2(mson[u], u, t);
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==f || v==mson[u]) continue;
			dfs2(v, u, v);
		}
		out[u]=++tot;
	}
	// NOTE(review): not called anywhere in this file. on_chain does not hold plain edge weights,
	// so this does not compute a path length; it looks like a leftover copy of force::qsum.
	ll qsum(int a, int b) {
		ll ans=0;
		while (top[a]!=top[b]) {
			if (dep[top[a]]<dep[top[b]]) swap(a, b);
			ans=(ans+query(on_chain, id[a])-query(on_chain, id[top[a]]-1))%mod;
			a=fa[top[a]];
		}
		if (dep[a]>dep[b]) swap(a, b);
		ans=(ans+query(on_chain, id[b])-query(on_chain, id[a]))%mod;
		return (ans+mod)%mod;
	}
	// Prints the initial answer, then handles q updates "u w" where edge (fa[u], u) gains w.
	// All cnt = siz[u]*(n-siz[u]) paths through that edge grow by w, so
	//   delta = sum over those paths P of (2*w*len(P) + w^2)
	//         = cnt*(2*val[u]*w + w^2) + 2*w * sum over edges b != u of val[b]*paths(u, b),
	// where paths(u, b) counts paths using both edges:
	//   b below u : siz[b]*(n-siz[u])
	//   b above u : siz[u]*(n-siz[b])
	//   otherwise : siz[u]*siz[b]
	// Terminates the program via exit(0).
	void solve() {
		dep[1]=1; dfs1(1, 0); dfs2(1, 0, 1); build();
		ll w;
		printf("%lld\n", ans);
		for (int i=1,u; i<=q; ++i) {
			//cout<<"i: "<<i<<endl;
			u=read(); w=read();
			// cnt*(2*val[u]*w + w^2): the edge's own weight plus the w^2 term.
			ans = (ans+siz[u]*(n-siz[u])%mod*(2*val[u]*w%mod+w*w%mod)%mod)%mod;
			// Disabled debug output.
			#if 0
			cout<<"ans1: "<<ans<<endl;
			cout<<"query: "<<2*w*query(other_tree, in[u])<<' '<<
							2*w*query(other_dlt, in[u]-1)<<' '<<2*w*(query(other_tree, tot)-query(other_tree, out[u]))<<endl;
			cout<<"assert: "<<query(other_tree, out[2])<<endl;
			cout<<"now: "<< (n-siz[u])*(query(in_tree, out[u])-query(in_tree, in[u]))%mod<<' '<<
						siz[u]*query(on_chain, in[u]-1)%mod<<' '<<
						siz[u]*((query(other_tree, in[u])+
							query(other_dlt, in[u]-1)+query(other_tree, tot)-query(other_tree, out[u]))%mod)%mod<<endl;
			#endif
			// 2*w times the three cases:
			//   descendants: (n-siz[u]) * sum of val*siz over (in[u], out[u]] of in_tree;
			//   ancestors:   siz[u] * on_chain prefix up to in[u]-1;
			//   others:      siz[u] * (val*siz of nodes finished before u starts (other_tree up to in[u])
			//                + nodes with out > out[u] (ancestors and later nodes) - ancestors (other_dlt)).
			ans = (ans + 2*w*( (n-siz[u])*(query(in_tree, out[u])-query(in_tree, in[u]))%mod
						+ siz[u]*query(on_chain, in[u]-1)%mod
						+ siz[u]*((query(other_tree, in[u])
							+query(other_dlt, in[u]-1)+query(other_tree, tot)-query(other_tree, out[u]))%mod)%mod)%mod)%mod;
			
			val[u]=(val[u]+w)%mod;
			
			// Apply the weight change to u's own entries (same layout as build()).
			upd(in_tree, in[u], w*siz[u]%mod);
			upd(on_chain, in[u], w*(n-siz[u])%mod);
			upd(on_chain, out[u], -w*(n-siz[u])%mod);
			upd(other_dlt, in[u], -w*siz[u]%mod);
			upd(other_dlt, out[u], w*siz[u]%mod);
			upd(other_tree, out[u], w*siz[u]%mod);
			
			printf("%lld\n", (ans%mod+mod)%mod);
		}
		exit(0);
	}
}


// Entry point: reads the subtask id, the tree and the update count, then runs task::solve,
// which prints every answer and terminates the program itself.
signed main()
{
	// -1 terminates every adjacency list (add() links new edges in front of head[]).
	std::fill(std::begin(head), std::end(head), -1);
	sub = read();
	n = read();
	q = read();
	// For each node 2..n the input gives a neighbour and the weight stored on that node.
	// NOTE(review): the solvers treat val[child] as the weight of edge (fa[child], child) after
	// rooting at 1, so this presumably relies on the neighbour being child's parent — confirm
	// against the problem statement.
	for (int child = 2; child <= n; ++child) {
		const int parent = read();
		val[child] = read();
		add(parent, child);
		add(child, parent);
	}
	task::solve();
	
	return 0;
}
上一篇:P4134 [BJOI2012]连连看


下一篇:P4133 [BJOI2012]最多的方案 二分+DP