Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),
如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
1
2
HINT
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
第一眼看上去肯定是树链剖分,然后就是想怎么用线段树维护区间色段。
我们用线段树维护一个区间最左边的颜色,最右边的颜色,和颜色段数。如果一个节点的左儿子的右颜色和右儿子的左颜色相同,那么它的色段数是左+右-1,否则是左+右。
但是在查询时一定要注意,跑完每一条重链,和下一条重链中的轻链时,他们在线段树上并不是一起查询的。我们需要单点找出当前重链的顶端和下一个重链的底端的颜色,如果颜色相同,那么ans-1.
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstdlib>
#include <cstring>
#define in(a) a=read()
#define REP(i,k,n) for(int i=k;i<=n;i++)
#define MAXN 100010
using namespace std;
inline int read(){
int x=,f=;
char ch=getchar();
for(;!isdigit(ch);ch=getchar())
if(ch=='-')
f=-;
for(;isdigit(ch);ch=getchar())
x=x*+ch-'';
return x*f;
}
int n,m,a,b,d;
char c;
int input[MAXN];
int total,head[MAXN],nxt[MAXN<<],to[MAXN<<];
int depth[MAXN],size[MAXN],son[MAXN],f[MAXN];
int cnt,dfn[MAXN],top[MAXN],link[MAXN];
struct node{
int l,r,lc,rc,s,lt;
}tree[MAXN<<];
inline void adl(int a,int b){
total++;
to[total]=b;
nxt[total]=head[a];
head[a]=total;
return ;
}
inline void getson(int u,int fa){//得到重儿子
size[u]=;
for(int e=head[u];e;e=nxt[e])
if(to[e]!=fa){
depth[to[e]]=depth[u]+;
f[to[e]]=u;
getson(to[e],u);
size[u]+=size[to[e]];
if(!son[u] || size[to[e]]>size[son[u]]) son[u]=to[e];
}
return ;
}
inline void getdfn(int u,int t){//得到重边
top[u]=t;
dfn[u]=++cnt;
link[cnt]=u;
if(!son[u]) return ;
getdfn(son[u],t);
for(int e=head[u];e;e=nxt[e])
if(to[e]!=f[u] && to[e]!=son[u])
getdfn(to[e],to[e]);
return ;
}
inline void build(int i,int l,int r){//建树
tree[i].l=l;
tree[i].r=r;
if(l==r){
tree[i].s=,tree[i].lc=tree[i].rc=input[link[l]];
return ;
}
int mid=(l+r)>>;
build(i<<,l,mid);
build(i<<|,mid+,r);
if(tree[i<<].rc==tree[i<<|].lc) tree[i].s=tree[i<<].s+tree[i<<|].s-;
else tree[i].s=tree[i<<].s+tree[i<<|].s;
tree[i].lc=tree[i<<].lc;
tree[i].rc=tree[i<<|].rc;
}
inline void pushdown(int i){//下传懒标记
if(!tree[i].lt) return ;
int k=tree[i].lt;
tree[i<<].s=tree[i<<|].s=;
tree[i<<].lc=tree[i<<].rc=tree[i<<|].lc=tree[i<<|].rc=k;
tree[i<<].lt=tree[i<<|].lt=k;
tree[i].lt=;
return ;
}
inline void add(int i,int l,int r,int k){//修改颜色
if(tree[i].l>=l && tree[i].r<=r){
tree[i].s=;
tree[i].lt=tree[i].lc=tree[i].rc=k;
return ;
}
pushdown(i);
if(tree[i<<].r>=l) add(i<<,l,r,k);
if(tree[i<<|].l<=r) add(i<<|,l,r,k);
if(tree[i<<].rc==tree[i<<|].lc) tree[i].s=tree[i<<].s+tree[i<<|].s-;
else tree[i].s=tree[i<<].s+tree[i<<|].s;
tree[i].lc=tree[i<<].lc;
tree[i].rc=tree[i<<|].rc;
return ;
}
inline void updates(int x,int y,int z){//枚举两点间每一条重边
int tx=top[x],ty=top[y];
while(tx!=ty){
if(depth[tx]<depth[ty]) swap(tx,ty),swap(x,y);
add(,dfn[tx],dfn[x],z);
x=f[tx];
tx=top[x],ty=top[y];
}
if(depth[x]<depth[y]) swap(x,y);
add(,dfn[y],dfn[x],z);
}
inline int query(int i,int l,int r){//区间查询
int sum=;
if(tree[i].l>=l && tree[i].r<=r) return tree[i].s;
pushdown(i);
if(tree[i<<].r>=l) sum+=query(i<<,l,r);
if(tree[i<<|].l<=r) sum+=query(i<<|,l,r);
if(tree[i<<].r>=l && tree[i<<|].l<=r && tree[i<<].rc==tree[i<<|].lc) sum--;
return sum;
}
inline int getcolor(int i,int dis){//查询单点颜色
if(tree[i].l==tree[i].r) return tree[i].lc;
pushdown(i);
int mid=(tree[i].l+tree[i].r)>>;
if(dis<=mid) return getcolor(i<<,dis);
else return getcolor(i<<|,dis);
}
inline int getsum(int x,int y){//枚举查询时两点间的重边
int tx=top[x],ty=top[y],ans=;
while(tx!=ty){
if(depth[tx]<depth[ty]) swap(tx,ty),swap(x,y);
ans+=query(,dfn[tx],dfn[x]);
if(getcolor(,dfn[tx])==getcolor(,dfn[f[tx]])) ans--;//看轻边两点的颜色是否相同
x=f[tx];
tx=top[x],ty=top[y];
}
if(depth[x]<depth[y]) swap(x,y);
ans+=query(,dfn[y],dfn[x]);
return ans;
}
int main(){
in(n),in(m);
REP(i,,n) in(input[i]);
REP(i,,n-) in(a),in(b),adl(a,b),adl(b,a);
depth[]=;
getson(,);
getdfn(,);
build(,,n);
REP(i,,m){
cin>>c;
if(c=='C') in(a),in(b),in(d),updates(a,b,d);
if(c=='Q') in(a),in(b),printf("%d\n",getsum(a,b));
}
}