Description
Solution
本题要不难想到要用线段树或树状数组之类的数据结构,但是题目要求在树上操作,我们该如何处理呢?
首先我们要用到一个叫dfs序的概念。其实并不难,刚接触的同学不要被它吓到,它本质上就是一棵树的先序遍历,所谓先序遍历就是先遍历根,然后遍历左子节点,最后遍历右子节点。我们需要把dfs序存在pos数组中,并把每个节点第一次遍历到的时间点和第二次遍历到的时间点存到in和out数组中,这样就成功地把一棵树转换为了线性结构。对一棵子树进行操作时,只需对这棵子树的根节点两次遍历到的时间戳中间的部分进行操作即可。
求dfs序的代码:
inline void dfs(int x, int fa){
in[x] = ++tim, pos[tim] = x;
for(int i = head[x]; i; i = edge[i].nxt){
int y = edge[i].v;
if(y != fa) dfs(y, x);
}
out[x] = tim;
}
然后我们就可以用dfs序,也就是pos数组对线段树进行操作了,不过需要用到状态压缩,要把颜色压缩成二进制数到线段树中,所以要开long long。接下来基本上都是线段树区间修改,区间查询的模板了。需要注意的是,查询出来的值是一个经过状压后的数,我们需要把它分解。这里可以借鉴树状数组的思想,即每次减去一个lowbit(一棵树上有数值的节点的最低位,不会的话可以先去学习一下树状数组,这里不再过多赘述)再让ans++,因为状压后只有0和1,有值的话一定是1。ans就是最后的答案。
Code
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#define ls rt << 1
#define rs rt << 1 | 1
#define ll long long
using namespace std;
inline int read(){
int x = 0;
char ch = getchar();
while(ch < '0' || ch > '9') ch = getchar();
while(ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
return x;
}
const int N = 4e5 + 10;
int n, m;
int a[N];
struct node{
int v, nxt;
}edge[N << 1];
int head[N], tot;
int in[N], out[N], pos[N], tim;
inline void add(int x, int y){
edge[++tot] = (node){y, head[x]};
head[x] = tot;
}
inline void dfs(int x, int fa){
in[x] = ++tim, pos[tim] = x;
for(int i = head[x]; i; i = edge[i].nxt){
int y = edge[i].v;
if(y != fa) dfs(y, x);
}
out[x] = tim;
}
struct Seg_tree{
ll sum, cov;
}t[N << 2];
inline void pushup(int rt){
t[rt].sum = t[ls].sum | t[rs].sum;
}
inline void pushdown(int rt){
if(t[rt].cov){
t[ls].sum = t[rs].sum = t[rt].cov;
t[ls].cov = t[rs].cov = t[rt].cov;
t[rt].cov = 0;
}
}
inline void build(int l, int r, int rt){
if(l == r){
t[rt].sum = (1ll << a[pos[l]]);
return;
}
int mid = (l + r) >> 1;
build(l, mid, ls);
build(mid + 1, r, rs);
pushup(rt);
}
inline void update(int L, int R, int k, int l, int r, int rt){
if(L <= l && r <= R){
t[rt].sum = (1ll << k);
t[rt].cov = (1ll << k);
return;
}
pushdown(rt);
int mid = (l + r) >> 1;
if(L <= mid) update(L, R, k, l, mid, ls);
if(R > mid) update(L, R, k, mid + 1, r, rs);
pushup(rt);
}
inline ll query(int L, int R, int l, int r, int rt){
if(L <= l && r <= R)
return t[rt].sum;
pushdown(rt);
int mid = (l + r) >> 1;
ll res = 0;
if(L <= mid) res |= query(L, R, l, mid, ls);
if(R > mid) res |= query(L, R, mid + 1, r, rs);
return res;
}
inline int calc(ll x){
int res = 0;
for(; x; x -= x & (-x)) res++;
return res;
}
int main(){
n = read(), m = read();
for(int i = 1; i <= n; ++i)
a[i] = read();
for(int i = 1; i < n; ++i){
int u = read(), v = read();
add(u, v), add(v, u);;
}
dfs(1, 0);
build(1, n, 1);
for(int i = 1; i <= m; ++i){
int op = read(), x = read();
if(op == 1){
int c = read();
update(in[x], out[x], c, 1, n, 1);
}else printf("%d\n", calc(query(in[x], out[x], 1, n, 1)));
}
return 0;
}