题意
待修莫队与树上莫队合并起来的练手题。
code:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=1e5+10;
const int maxm=1e5+10;
const int maxQ=1e5+10;
int n,m,Q,cnt_edge,tim,cnt1,cnt2;
int head[maxn],val[maxm],w[maxn],a[maxn],b[maxn],dep[maxn];
int f[maxn][20];
inline int read()
{
char c=getchar();int res=0,f=1;
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9')res=res*10+c-'0',c=getchar();
return res*f;
}
struct edge{int to,nxt;}e[maxn<<1];
inline void add(int u,int v)
{
e[++cnt_edge].nxt=head[u];
head[u]=cnt_edge;
e[cnt_edge].to=v;
}
int ouler[maxn<<1],st[maxn],ed[maxn];
void dfs(int x,int fa)
{
for(int i=1;i<=18;i++)f[x][i]=f[f[x][i-1]][i-1];
dep[x]=dep[fa]+1;
ouler[++tim]=x;st[x]=tim;
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(y==fa)continue;
f[y][0]=x;dfs(y,x);
}
ouler[++tim]=x;ed[x]=tim;
}
inline int lca(int x,int y)
{
if(dep[x]>dep[y])swap(x,y);
for(int i=18;~i;i--)if(dep[f[y][i]]>=dep[x])y=f[y][i];
if(x==y)return x;
for(int i=18;~i;i--)if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
return f[x][0];
}
int nowl=1,nowr,nowtim;
int pos[maxn<<1],cnt[maxm];
ll nowans;
ll ans[maxQ];
bool vis[maxn];
struct Query{int tim,l,r,lca,id;}qr[maxQ];
struct Times{int pos,last,now;}times[maxQ];
inline bool cmp(Query x,Query y){return pos[x.l]==pos[y.l]?(pos[x.r]==pos[y.r]?x.tim<y.tim:x.r<y.r):x.l<y.l;}
inline void del(int c){nowans-=1ll*val[c]*w[cnt[c]--];}
inline void add(int c){nowans+=1ll*val[c]*w[++cnt[c]];}
inline void change(int pos,int k)
{
if(vis[pos])del(a[pos]),add(k);
a[pos]=k;
}
inline void work(int pos)
{
if(vis[ouler[pos]])del(a[ouler[pos]]);
else add(a[ouler[pos]]);
vis[ouler[pos]]^=1;
}
int main()
{
//freopen("test.in","r",stdin);
//freopen("test.out","w",stdout);
n=read(),m=read(),Q=read();
for(int i=1;i<=m;i++)val[i]=read();
for(int i=1;i<=n;i++)w[i]=read();
for(int i=1;i<n;i++)
{
int u=read(),v=read();
add(u,v),add(v,u);
}
dfs(1,0);
for(int i=1;i<=n;i++)a[i]=b[i]=read();
for(int i=1;i<=Q;i++)
{
int op=read(),x=read(),y=read();
if(op)qr[++cnt1]=(Query){cnt2,x,y,0,cnt1};
else times[++cnt2]=(Times){x,b[x],y},b[x]=y;
}
for(int i=1;i<=Q;i++)
{
if(st[qr[i].l]>st[qr[i].r])swap(qr[i].l,qr[i].r);
int z=lca(qr[i].l,qr[i].r);
if(z==qr[i].l)qr[i].l=st[qr[i].l],qr[i].r=st[qr[i].r];
else qr[i].l=ed[qr[i].l],qr[i].r=st[qr[i].r],qr[i].lca=z;
}
int t=pow(2*n,0.6666666666);
for(int i=1;i<=2*n;i++)pos[i]=(i-1)/t+1;
sort(qr+1,qr+cnt1+1,cmp);
for(int i=1;i<=cnt1;i++)
{
while(nowtim<qr[i].tim)change(times[nowtim+1].pos,times[nowtim+1].now),nowtim++;
while(nowtim>qr[i].tim)change(times[nowtim].pos,times[nowtim].last),nowtim--;
while(nowl<qr[i].l)work(nowl++);
while(nowl>qr[i].l)work(--nowl);
while(nowr<qr[i].r)work(++nowr);
while(nowr>qr[i].r)work(nowr--);
if(qr[i].lca)work(st[qr[i].lca]);
ans[qr[i].id]=nowans;
if(qr[i].lca)work(st[qr[i].lca]);
}
for(int i=1;i<=cnt1;i++)printf("%lld\n",ans[i]);
return 0;
}