这是一道树链剖分的题目;
很容易想到,我们在树剖后,对于操作1,直接单点修改;
对于答案查询,我们直接的时候,我们假设查询的点是3,那么我们在查询的时候可分为两部分;
第一部分:查找出除3这颗子树以外有多少个蘑菇,然后将蘑菇数*此路径;
然后再一一枚举3这颗树的各个子树即可;
这种做法在牛客上能过,不过比赛时的测评应该会超时,比如当出现菊花图的时候,复杂度就会到n^2log n;
先把这份代码贴上:
1 #include<bits/stdc++.h> 2 using namespace std; 3 const int maxx = 1e6+10; 4 typedef long long LL; 5 struct node 6 { 7 int to,val,next; 8 }e[maxx*2]; 9 int head[maxx],tot=0; 10 int son[maxx],id[maxx],fa[maxx],dep[maxx],siz[maxx],top[maxx],cnt=0; 11 int a[maxx]; 12 LL t[maxx<<2],lazy[maxx<<2]; 13 int n; 14 void update(int l,int r,int p,int q,int k,int rt) 15 { 16 if(l==r){ 17 t[rt]+=1LL*k; 18 return; 19 } 20 int mid=(l+r)/2; 21 if(p<=mid)update(l,mid,p,q,k,rt*2); 22 else update(mid+1,r,p,q,k,rt*2+1); 23 t[rt]=t[rt*2]+t[rt*2+1]; 24 } 25 LL query(int l,int r,int L,int R,int rt) 26 { 27 if(L<=l&&R>=r){ 28 return t[rt]; 29 } 30 int mid=(l+r)/2; 31 LL ans=0; 32 if(L<=mid) ans+=query(l,mid,L,R,rt<<1); 33 if(R>mid) ans+=query(mid+1,r,L,R,rt<<1|1); 34 return ans; 35 } 36 void add(int u,int v,int w) 37 { 38 e[++tot].to=v;e[tot].val=w; 39 e[tot].next=head[u];head[u]=tot; 40 } 41 void dfs1(int x,int f,int deep) 42 { 43 dep[x]=deep; 44 fa[x]=f; 45 siz[x]=1; 46 int maxson=-1; 47 for(int i=head[x];i;i=e[i].next) 48 { 49 int y=e[i].to; 50 if(y==f)continue; 51 dfs1(y,x,deep+1); 52 a[y]=e[i].val; 53 siz[x]+=siz[y]; 54 if(siz[y]>maxson)son[x]=y,maxson=siz[y]; 55 } 56 } 57 void dfs2(int x,int topf) 58 { 59 id[x]=++cnt; 60 top[x]=topf; 61 if(!son[x])return; 62 dfs2(son[x],topf); 63 for(int i=head[x];i;i=e[i].next) 64 { 65 int y=e[i].to; 66 if(y==fa[x]||y==son[x])continue; 67 dfs2(y,y); 68 } 69 } 70 void change(int x,int y,int k) 71 { 72 while(top[x]!=top[y]) 73 { 74 if(dep[top[x]]<dep[top[y]])swap(x,y); 75 // update(1,n,id[top[x]],id[x],k,1); 76 x=fa[top[x]]; 77 } 78 if(dep[x]>dep[y])swap(x,y); 79 // update(1,n,id[x],id[y],k,1); 80 } 81 LL getsum(int x) 82 { 83 LL ans=0; 84 ans+=(query(1,n,id[1],id[1]+siz[1]-1,1)-query(1,n,id[x],id[x]+siz[x]-1,1))*a[x]; 85 for(int i=head[x];i;i=e[i].next) 86 { 87 int y=e[i].to; 88 if(y==fa[x])continue; 89 ans+=query(1,n,id[y],id[y]+siz[y]-1,1)*e[i].val; 90 } 91 return ans; 92 } 93 int main() 94 { 95 scanf("%d",&n); 96 int u,v,w; 97 for(int i=1;i<n;i++){ 98 scanf("%d%d%d",&u,&v,&w); 99 add(u,v,w);add(v,u,w); 100 } 101 dfs1(1,0,1); 102 dfs2(1,1); 103 int q; 104 scanf("%d",&q); 105 int op,st=1,x,k; 106 while(q--){ 107 scanf("%d",&op); 108 if(op==1){ 109 scanf("%d%d",&x,&k); 110 update(1,n,id[x],id[x],k,1); 111 // change(1,x,k); 112 } 113 else scanf("%d",&st); 114 printf("%lld\n",getsum(st)); 115 } 116 return 0; 117 }
那么应该如何优化呢,这就需要充分理解树剖的轻重链;
优化之后的做法分为3部分(需要预处理出目前有多少个蘑菇,已经每个节点有多少个蘑菇)
1.求出某节点的重儿子这棵树有多少个蘑菇,再*上重儿子的权值;
2.求出某节点的轻儿子的最后答案;
3.剩下的蘑菇数就是除这颗树以外的所有蘑菇,我们用总数减去以上两部分,再减去这个节点的蘑菇数(这个节点的蘑菇数贡献为0),得出的数乘上此节点的路径权值即可;
这思路代码我没有自己写,所以贴上某神犇的代码;神犇代码风格与上文略有不同;
我的是单点修改区间查询;
1 #include<bits/stdc++.h> 2 #define rint register int 3 #define deb(x) cerr<<#x<<" = "<<(x)<<'\n'; 4 #define fi first 5 #define se second 6 using namespace std; 7 typedef long long ll; 8 using pii = pair <ll,ll>; 9 const int maxn = 1e6 + 5; 10 int n, q, dep[maxn], fa[maxn], fv[maxn], size[maxn]; 11 ll sum, t[maxn<<2], lz[maxn<<2], cnt[maxn]; 12 int dfn[maxn], id[maxn], tot, son[maxn], top[maxn]; 13 vector <pii> g[maxn]; 14 pii ans[maxn]; 15 16 void dfs1(int u, int f, int de) { 17 dep[u] = de, fa[u] = f, size[u] = 1; 18 for(auto tmp : g[u]) { 19 int v = tmp.fi; 20 int w = tmp.se; 21 if(v == f) continue; 22 dfs1(v, u, de+1); 23 fv[v] = w; 24 size[u] += size[v]; 25 if(size[son[u]] < size[v]) son[u] = v; 26 } 27 } 28 29 void dfs2(int u, int tp) { 30 top[u] = tp, dfn[++tot] = u, id[u] = tot; 31 if(son[u]) dfs2(son[u], tp); 32 for(auto tmp : g[u]) { 33 int v = tmp.fi; 34 if(v == fa[u]) continue; 35 if(v == son[u]) continue; 36 dfs2(v, v); 37 } 38 } 39 40 void pushdown(int rt) { 41 if(lz[rt]) { 42 t[rt<<1] += lz[rt]; 43 t[rt<<1|1] += lz[rt]; 44 lz[rt<<1] += lz[rt]; 45 lz[rt<<1|1] += lz[rt]; 46 lz[rt] = 0; 47 } 48 } 49 50 void update(ll x, int L, int R, int l, int r, int rt) { 51 if(l>R || r<L) return; 52 if(l>=L && r<=R) { 53 t[rt] += x; 54 lz[rt] += x; 55 return; 56 } 57 pushdown(rt); 58 int mid = l + r >> 1; 59 update(x, L, R, l, mid, rt<<1); 60 update(x, L, R, mid+1, r, rt<<1|1); 61 t[rt] = t[rt<<1] + t[rt<<1|1]; 62 } 63 64 ll query(int pos, int l, int r, int rt) { 65 if(pos>r || pos<l) return 0; 66 if(l == r) return t[rt]; 67 pushdown(rt); 68 int mid = l + r >> 1; ll ret = 0; 69 ret += query(pos, l, mid, rt<<1); 70 ret += query(pos, mid+1, r, rt<<1|1); 71 return ret; 72 } 73 74 void gao(int u, int x) { 75 while(u) { 76 update(x, id[top[u]], id[u], 1, n, 1); 77 u = top[u]; 78 ans[fa[u]].fi += 1ll * x * fv[u]; 79 ans[fa[u]].se += x; 80 u = fa[u]; 81 } 82 } 83 84 void solve(int u) { 85 ll res = 0, num = query(id[son[u]], 1, n, 1); 86 res += 1ll * num * fv[son[u]]; 87 res += 1ll * (sum - cnt[u] - num - ans[u].se) * fv[u]; 88 res += ans[u].fi; 89 printf("%lld\n", res); 90 } 91 92 int main() { 93 scanf("%d", &n); 94 for(int i=1, u, v, w; i<n; i++) { 95 scanf("%d%d%d", &u, &v, &w); 96 g[u].push_back({v, w}); 97 g[v].push_back({u, w}); 98 } 99 dfs1(1, 0, 1); 100 dfs2(1, 1); 101 scanf("%d", &q); 102 int op, v, x, rt = 1; 103 while(q--) { 104 scanf("%d", &op); 105 if(op == 1) { 106 scanf("%d%d", &v, &x); 107 sum += x; 108 cnt[v] += x; 109 gao(v, x); 110 } else scanf("%d", &rt); 111 solve(rt); 112 } 113 }