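// 裸的树链剖分加线段树区间修改 —— plain heavy-light decomposition (树链剖分) combined with a segment tree supporting range assignment.
// 区间合并时需要多注意一点 —— take extra care with segment direction when merging path pieces.
// 当时写的很慢，理解不深刻 —— written slowly at the time, before the technique was fully understood.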
#include<bits/stdc++.h>
using namespace std;
// Large sentinel value; not used by the code shown here.
constexpr int INF = 0x3f3f3f3f;
// Capacity of every per-node array; nodes are numbered from 1, so N must stay below this.
constexpr int MAXN = 40005;
// Argument packs for the left / right child of segment-tree node `rt`, which
// covers [l, r] with midpoint m. Each expands in place of three arguments.
#define lson l, m, rt << 1
#define rson m + 1, r, rt << 1 | 1
// Node count and operation count of the current test case.
int N, M;
struct Pode{
int to,next;
}edge[MAXN * 2];
int tot;
int head[MAXN];
int E[MAXN][3];
void addedge(int x,int y){
edge[tot].to = y; edge[tot].next = head[x]; head[x] = tot ++;
}
// Heavy-light decomposition tables (indexed by node id unless noted).
int top[MAXN];   // head (shallowest node) of the heavy chain containing the node
int fa[MAXN];    // parent in the rooted tree (0 for the root)
int son[MAXN];   // heavy child, -1 when the node has no children
int deep[MAXN];  // depth, root = 1
int num[MAXN];   // subtree size
int p[MAXN];     // position of the node in the chain-contiguous order (1-based)
int fp[MAXN];    // inverse of p: fp[position] = node
int pos;         // next free position handed out by dfs2
// First decomposition pass over the subtree rooted at x (entered from `pre`
// at depth `dep`). Records parent, depth and subtree size, and picks each
// node's heavy child: the child with the largest subtree, the first one
// found winning ties. Expects son[] to be pre-filled with -1.
void dfs1(int x,int pre,int dep){
    fa[x] = pre;
    deep[x] = dep;
    num[x] = 1;
    for (int e = head[x]; e != -1; e = edge[e].next) {
        const int child = edge[e].to;
        if (child != pre) {             // never walk back up to the parent
            dfs1(child, x, dep + 1);
            num[x] += num[child];
            const bool heavier = (son[x] == -1) || num[child] > num[son[x]];
            if (heavier) son[x] = child;
        }
    }
}
// Second decomposition pass: lay nodes out so that every heavy chain occupies
// consecutive positions. Fills top[] (chain head tp), p[] (taken from the
// global counter pos) and its inverse fp[]. The heavy child is visited first,
// so it gets the position right after x; every light child starts a new chain.
void dfs2(int x,int tp){
    top[x] = tp;
    fp[pos] = x;
    p[x] = pos++;
    const int heavy = son[x];
    if (heavy == -1) return;            // no children to lay out
    dfs2(heavy, tp);                    // heavy child continues the current chain
    for (int e = head[x]; e != -1; e = edge[e].next) {
        const int y = edge[e].to;
        if (y == heavy || y == fa[x]) continue;
        dfs2(y, y);                     // light child heads its own chain
    }
}
// Initial value for each segment-tree position: main stores at p[child] the
// value of the edge between `child` and its parent.
int ininum[MAXN] = {};
// Summary of a run of colors along a path segment, read left to right.
//   l   : color at the left end; -1 marks the empty summary (identity of +)
//   r   : color at the right end
//   num : number of maximal same-color runs in the segment
// Assumes real colors are never -1, since -1 is reserved for "empty".
struct Node{
    int l, r, num;
    Node(int a = -1, int b = 0, int c = 0) : l(a), r(b), num(c) {}

    // Concatenate this segment (left) with T (right). When the touching ends
    // share a color, the two boundary runs merge into one.
    Node operator + (const Node &T) const {
        if (l == -1) return T;          // empty + T == T
        if (T.l == -1) return *this;    // this + empty == this
        Node tt(l, T.r, num + T.num);
        if (r == T.l) tt.num--;
        return tt;
    }

    // The same segment read in the opposite direction. The empty summary must
    // stay empty: swapping its fields would yield l == 0, which `+` would then
    // treat as a real segment whose left color is 0.
    Node rev() const {
        if (l == -1) return *this;
        return Node(r, l, num);
    }
};
struct Segtree{
Node tree[MAXN<<2];
int lazy[MAXN<<2];
void Pushup(int rt){
tree[rt] = tree[rt<<1]+tree[rt<<1|1];
}
void Pushdown(int rt) {
if(lazy[rt] != -1) {
lazy[rt<<1] = lazy[rt<<1|1] = lazy[rt];
tree[rt<<1|1] = tree[rt<<1] = Node(lazy[rt], lazy[rt], 1);
lazy[rt] = -1;
}
}
void Build(int l,int r,int rt) {
lazy[rt] = -1;
if(l == r) {
tree[rt] = Node(ininum[l], ininum[l], 1);
return;
}
int m = (l+r) >>1;
Build(lson); Build(rson);
Pushup(rt);
}
void Change(int L,int R,int num,int l,int r,int rt) {
if(L <= l && r <= R) {
tree[rt] = Node(num,num,1);
lazy[rt] = num;
return;
}
int m = (l+r) >>1;
Pushdown(rt);
if(L <= m) Change(L,R,num,lson);
if(R > m) Change(L,R,num,rson);
Pushup(rt);
}
Node Sum(int L,int R,int l,int r,int rt){
if(L <= l && r <= R) {
return tree[rt];
}
int m = (l + r) >> 1;
Node ans;
Pushdown(rt);
if(L <= m) ans = ans+Sum(L,R,lson);
if(R > m) ans = ans+Sum(L,R,rson);
return ans;
}
void Find(int x,int y,int d){
int t1 = top[x]; int t2 = top[y];
while(t1 != t2) {
if(deep[t1] < deep[t2]) {
swap(t1,t2); swap(x,y);
}
Change(p[t1], p[x], d, 1,N,1);
x = fa[t1];
t1 = top[x];
}
if(x == y) return;
if(deep[x] > deep[y]) {
swap(x,y);
}
Change(p[son[x]], p[y], d,1, N, 1);
}
Node Query(int x,int y){
int t1 = top[x]; int t2 = top[y];
Node X, Y;
while(t1 != t2) {
if(deep[t1] < deep[t2]){
Y = Sum(p[t2], p[y], 1,N,1)+Y;
y = fa[t2]; t2 = top[y];
}else {
X = Sum(p[t1], p[x], 1,N,1)+X;
x = fa[t1]; t1 = top[x];
}
}
if(x == y) return X+Y;
if(deep[x] > deep[y]){
return Y.rev() + Sum(p[son[y]], p[x], 1,N,1) + X;
}else {
return X.rev() + Sum(p[son[x]], p[y], 1,N,1) + Y;
}
}
}solve;
// Reads test cases until EOF. Each case is "N M", then N-1 edges "u v value",
// then M operations:
//   'C...' b c d : paint every edge on path b -- c with value d
//   otherwise b c: print the number of same-value runs on path b -- c
int main(){
    while (scanf("%d %d", &N, &M) == 2) {
        // Reset per-case state (the other tables are fully overwritten below).
        memset(head, -1, sizeof(head));
        memset(son, -1, sizeof(son));
        tot = 0;
        pos = 1;
        for (int i = 1; i < N; ++i) {
            scanf("%d%d%d", &E[i][0], &E[i][1], &E[i][2]);
            addedge(E[i][0], E[i][1]);
            addedge(E[i][1], E[i][0]);
        }
        dfs1(1, 0, 1);
        dfs2(1, 1);
        ininum[0] = 0;
        // Store each edge's value at the position of its deeper endpoint.
        for (int i = 1; i < N; ++i) {
            if (deep[E[i][0]] < deep[E[i][1]]) swap(E[i][0], E[i][1]);
            ininum[ p[E[i][0]] ] = E[i][2];
        }
        solve.Build(1, N, 1);
        for (int i = 1; i <= M; ++i) {
            char a[10];
            int b, c, d;
            // Field width keeps an over-long command word from overflowing a[].
            if (scanf("%9s", a) != 1) break;
            if (a[0] == 'C') {
                scanf("%d %d %d", &b, &c, &d);
                solve.Find(b, c, d);
            } else {
                scanf("%d %d", &b, &c);
                printf("%d\n", solve.Query(b, c).num);
            }
        }
    }
    return 0;
}