题解 Count on a tree II/【模板】树分块

link

Description

给出一个大小为 \(n\) 的树,每个点有点权,有 \(m\) 次查询,每次查询 \(u\to v\) 的不同点权个数。强制在线。

\(n\le 4\times 10^4,m\le 10^5\)

Solution

不知道这是不是正宗的树分块。

我们考虑假如我们能取出约 \(\Theta(n/B)\) 个点,使得任意一个点到其最近的一个点距离都 \(\le B\),那么我们就可以提前处理任意两两这些点预处理信息,再把剩下的距离手动加进去。这样的话我们复杂度就可以做到 \(\Theta(n^2/B^2\times t+qB)\) 。其中 \(t\) 是一次预处理的复杂度。

我们发现其实 \(B=\sqrt n\) 的时候最优。那我们怎么取呢?我们发现直接在 \(\text{dep}\equiv 0\pmod{B}\) 的时候就好了,因为显然。

然后这个题目的话我们可以用 bitset 来维护一下就好了。稍微有点卡空间和卡时间,\(B=800\) 的时候比较优秀。

Code

#include <bits/stdc++.h>
using namespace std;
 
#define Int register int
#define MAXN 40005
#define MAXM 205

template <typename T> inline void read (T &t){t = 0;char c = getchar();int f = 1;while (c < '0' || c > '9'){if (c == '-') f = -f;c = getchar();}while (c >= '0' && c <= '9'){t = (t << 3) + (t << 1) + c - '0';c = getchar();} t *= f;}
template <typename T,typename ... Args> inline void read (T &t,Args&... args){read (t);read (args...);}
template <typename T> inline void write (T x){if (x < 0){x = -x;putchar ('-');}if (x > 9) write (x / 10);putchar (x % 10 + '0');}
template <typename T> inline void chkmax (T &a,T b){a = max (a,b);}
template <typename T> inline void chkmin (T &a,T b){a = min (a,b);}

vector <int> h[MAXN];
int n,m,uni,cnt,ind,B,dfn[MAXN],siz[MAXN],val[MAXN],tmp[MAXN],pat[MAXN],dep[MAXN],tur[MAXN],par[MAXN][21];

void dfs1 (int u,int fa){
	dep[u] = dep[fa] + 1,par[u][0] = fa,dfn[u] = ++ ind,siz[u] = 1;
	if (dep[u] % B == 0) tur[u] = ++ cnt,tmp[cnt] = u;
	for (Int i = 1;i <= 20;++ i) par[u][i] = par[par[u][i - 1]][i - 1];
	for (Int v : h[u]) if (v ^ fa) dfs1 (v,u),siz[u] += siz[v];
}

int hav[MAXM][MAXM],have;
bitset <MAXN> Now,con[MAXM][MAXM];
void dfs2 (int st,int u,int fa){
	int lst = Now[val[u]];have += (!Now[val[u]]),Now[val[u]] = 1;
	if (tur[u]) con[st][tur[u]] = Now,hav[st][tur[u]] = have;
	for (Int v : h[u]) if (v ^ fa) dfs2 (st,v,u);
	Now[val[u]] = lst,have -= (!lst);
}

int getlca (int u,int v){
	if (dep[u] < dep[v]) swap (u,v);
	for (Int i = 20,dis = dep[u] - dep[v];~i;-- i) if (dis >> i & 1) u = par[u][i];
	if (u == v) return u;
	else{
		for (Int i = 20;~i;-- i) if (par[u][i] ^ par[v][i]) u = par[u][i],v = par[v][i];
		return par[u][0];
	}
}

bool checkin (int u,int v){return dfn[u] <= dfn[v] && dfn[v] <= dfn[u] + siz[u] - 1;}

int query (int u,int v){
	int lca = getlca (u,v),tmp1 = 0,tmp2 = 0,tmpu = u,tmpv = v;
	while(dep[tmpu] >= dep[lca]){
		if (tur[tmpu]){
			tmp1 = tur[tmpu];
			break;
		}
		tmpu = par[tmpu][0];
	}
	while (dep[tmpv] >= dep[lca]){
		if (tur[tmpv]){
			tmp2 = tur[tmpv];
			break;
		}
		tmpv = par[tmpv][0];
	}
	if (tmp1 && tmp2){
		int ans = hav[tmp1][tmp2];Now = con[tmp1][tmp2];
		for (Int f1 = u;f1 != tmpu;f1 = par[f1][0]) ans += (!Now[val[f1]]),Now[val[f1]] = 1;
		for (Int f2 = v;f2 != tmpv;f2 = par[f2][0]) ans += (!Now[val[f2]]),Now[val[f2]] = 1;
		return ans;
	}
	else if (!tmp1 && !tmp2){
		Now.reset ();int ans = 0;
		for (Int f1 = u;dep[f1] >= dep[lca];f1 = par[f1][0]) ans += (!Now[val[f1]]),Now[val[f1]] = 1;
		for (Int f2 = v;dep[f2] >= dep[lca];f2 = par[f2][0]) ans += (!Now[val[f2]]),Now[val[f2]] = 1;
		return ans;
	}
	else{
		if (tmp2) swap (tmp1,tmp2),swap (u,v),swap (tmpu,tmpv);tmp2 = 0,tmpv = 0; 
		for (Int i = 1;i <= cnt;++ i) if (checkin (tmp[i],tmpu) && checkin (lca,tmp[i])){
			if (tmp2 == 0 || dep[tmp[tmp2]] > dep[tmp[i]]) tmpv = tmp[i],tmp2 = i;
		}
		int ans = hav[tmp1][tmp2];Now = con[tmp1][tmp2];
		for (Int f1 = tmpv;dep[f1] >= dep[lca];f1 = par[f1][0]) ans += (!Now[val[f1]]),Now[val[f1]] = 1;
		for (Int f2 = v;dep[f2] >= dep[lca];f2 = par[f2][0]) ans += (!Now[val[f2]]),Now[val[f2]] = 1;
		for (Int f3 = u;f3 != tmpu;f3 = par[f3][0]) ans += (!Now[val[f3]]),Now[val[f3]] = 1;
		return ans;
	}
}

signed main(){
	read (n,m),B = 800;
	for (Int i = 1;i <= n;++ i) read (val[i]),tmp[i] = val[i];
	sort (tmp + 1,tmp + n + 1),uni = unique (tmp + 1,tmp + n + 1) - tmp - 1;
	for (Int i = 1;i <= n;++ i) val[i] = lower_bound (tmp + 1,tmp + n + 1,val[i]) - tmp;
	for (Int i = 2,u,v;i <= n;++ i) read (u,v),h[u].push_back (v),h[v].push_back (u);
	cnt = 0,dfs1 (1,0);
	for (Int i = 1;i <= cnt;++ i) dfs2 (i,tmp[i],0);
	int lstans = 0;
	while (m --> 0){
		int u,v;read (u,v),u ^= lstans;
		write (lstans = query (u,v)),putchar ('\n');
	}
	return 0;
}
上一篇:[NOIP2004 提高组] 合唱队形


下一篇:11.第三章 Linux文件管理和IO重定向 -- 文件操作命令和文件元数据和节点表结构(四)