P6177 Count on a tree II

给你一棵 \(n\) 个节点且带点权的树,\(m\) 个询问,每次查询链颜色数。

强制在线。

\(1\leq n \leq 4\times 10^4,1\leq m\leq 10^5\)。

sol

首先如果不强制在线,用树上莫队即可。

但多了个强制在线,容易想到是预处理题。

查询链颜色数,比较好的一种方法是使用 bitset,对值域建 bitset,答案就是 bitset 中 \(1\) 的数量。

那么现在的问题就是怎么把一条路径上的 bitset 并起来。

记得离散化。

法一

考虑树分块。

考虑用一种简单的树分块技巧——树上撒点。

简单来说就是先设一个阈值 \(S\),在树上选择不超过 \(\frac{n}{S}\) 个点作为关键点,满足每个关键点到离它最近的祖先关键点的距离不超过 \(S\)。

具体地,每次选择当前深度最大的一个非关键点,若它的 \(1 \sim S\) 级祖先都不是关键点,则把它的 \(S\) 级祖先标记为关键点。

由于上述方法中每标记一个关键点,至少有 \(S\) 个点不会被标记,所以关键点的数量是正确的。

仔细思考,容易发现每个关键点到离它最近的祖先关键点的距离不超过 \(S\) 这个条件也满足。

撒完关键点,再记录两关键点间的 bitset,先用 \(\mathcal O(S)\) 的时间求出相邻两关键点的 bitset,再处理出两两之间的即可,预处理总时间复杂度为 \(\mathcal O(\frac{n^2}{S}+\frac{n^3}{wS^2})\)。

然后考虑询问,此时询问的路径就被拆成了两个散块和一个整块,散块暴力,整块 bitset 取交集即可,总时间复杂度为 \(\mathcal O(mS+\frac{nm}{w})\)。

取 \(S=\sqrt n\),则总时间复杂度为 \(\mathcal O((n+m)\sqrt n+\frac{n^2+nm}{w})\),可过。

\(\text{1.88s / 20.14MB / 4.15KB C++14 (GCC 9) O2}\)。

#include <cstdio>
#include <bitset>
#include <algorithm>

#define re register

namespace Fread
{
    const int SIZE = 1 << 23;
    char buf[SIZE], *S, *T;
    inline char getchar()
    {
        if (S == T)
        {
            T = (S = buf) + fread(buf, 1, SIZE, stdin);
            if (S == T)
                return '\n';
        }
        return *S++;
    }
}
namespace Fwrite
{
    const int SIZE = 1 << 23;
    char buf[SIZE], *S = buf, *T = buf + SIZE;
    inline void flush()
    {
        fwrite(buf, 1, S - buf, stdout);
        S = buf;
    }
    inline void putchar(char c)
    {
        *S++ = c;
        if (S == T)
            flush();
    }
    struct NTR
    {
        ~NTR()
        {
            flush();
        }
    } ztr;
}

#ifdef ONLINE_JUDGE
#define getchar Fread::getchar
#define putchar Fwrite::putchar
#endif

inline int read()
{
	re int x = 0, f = 1;
	re char c = getchar();
	while (c < '0' || c > '9')
	{
		if (c == '-')
			f = -1;
		c = getchar();
	}
	while (c >= '0' && c <= '9')
	{
		x = x * 10 + c - '0';
		c = getchar();
	}
	return x * f;
}

inline void write(re int x)
{
	if (x < 0)
	{
		putchar('-');
		x = -x;
	}
	if (x > 9)
		write(x / 10);
	putchar(x % 10 + '0');
}

inline void swap(re int &x, re int &y)
{
	x ^= y ^= x ^= y;
}

const int _ = 4e4 + 7;

std::bitset<_> bt[42][42], nw;

int n, m, B, arr[_], a[_], fa[_], dep[_], mxd[_], FF[_], siz[_], tp[_], hson[_];

int id[_], cnt, head[_], tot, ans, sta[_], top, gg[_];

struct edge
{
	int to, nxt;
} e[_ << 1];

void dfs1(re int now, re int D)
{
	siz[now] = 1;
	dep[now] = D;
	mxd[now] = dep[now];
	for (re int i = head[now]; i; i = e[i].nxt)
	{
		re int v = e[i].to;
		if (!dep[v])
		{
			fa[v] = now;
			dfs1(v, D + 1);
			siz[now] += siz[v];
			if (mxd[v] > mxd[now])
				mxd[now] = mxd[v];
			if (siz[hson[now]] < siz[v])
				hson[now] = v;
		}
	}
	if (mxd[now] - dep[now] >= 1000)
		id[now] = ++tot, mxd[now] = dep[now];
}

void dfs2(re int now)
{
	for (re int i = head[now]; i; i = e[i].nxt)
	{
		re int v = e[i].to;
		if (dep[v] > dep[now])
		{
			if (id[v])
			{
				int ip = id[sta[top]], in = id[v];
				for (int x = v; x != sta[top]; x = fa[x])
					bt[ip][in].set(a[x]);
				nw = bt[ip][in];
				for (int i = 1; i < top; ++i)
				{
					std::bitset<_> &bs = bt[id[sta[i]]][in];
					bs = bt[id[sta[i]]][ip];
					bs |= nw;
				}
				FF[v] = sta[top], gg[v] = gg[sta[top]] + 1;
				sta[++top] = v;
			}
			dfs2(v);
			if (id[v])
				--top;
		}
	}
}

void dfs3(re int now, re int tf)
{
	tp[now] = tf;
	if (hson[now])
		dfs3(hson[now], tf);
	for (re int i = head[now]; i; i = e[i].nxt)
	{
		re int v = e[i].to;
		if (!tp[v])
			dfs3(v, v);
	}
}

inline int LCA(re int x, re int y)
{
	while (tp[x] != tp[y])
	{
		if (dep[tp[x]] < dep[tp[y]])
			swap(x, y);
		x = fa[tp[x]];
	}
	return dep[x] < dep[y] ? x : y;
}

signed main()
{
	n = read(), m = read();
	for (re int i = 1; i <= n; ++i)
		arr[i] = a[i] = read();
	std::sort(arr + 1, arr + 1 + n);
	arr[0] = std::unique(arr + 1, arr + 1 + n) - arr - 1;
	for (re int i = 1; i <= n; ++i)
		a[i] = std::lower_bound(arr + 1, arr + 1 + arr[0], a[i]) - arr;
	for (re int i = 1, u, v; i < n; ++i)
	{
		u = read(), v = read();
		e[++cnt] = (edge){v, head[u]}, head[u] = cnt;
		e[++cnt] = (edge){u, head[v]}, head[v] = cnt;
	}
	dfs1(1, 1);
	if (!id[1])
		id[1] = ++tot;
	sta[top = 1] = gg[1] = 1;
	dfs2(1);
	dfs3(1, 1);
	while (m--)
	{
		int u = read() ^ ans, v = read();
		nw.reset();
		int l = LCA(u, v);
		while (u != l && !id[u])
			nw.set(a[u]), u = fa[u];
		while (v != l && !id[v])
			nw.set(a[v]), v = fa[v];
		if (u != l)
		{
			int pre = u;
			while (dep[FF[pre]] >= dep[l])
				pre = FF[pre];
			if (pre != u)
				nw |= bt[id[pre]][id[u]];
			while (pre != l)
				nw.set(a[pre]), pre = fa[pre];
		}
		if (v != l)
		{
			int pre = v;
			while (dep[FF[pre]] >= dep[l])
				pre = FF[pre];
			if (pre != v)
				nw |= bt[id[pre]][id[v]];
			while (pre != l)
				nw.set(a[pre]), pre = fa[pre];
		}
		nw.set(a[l]);
		write(ans = nw.count()), putchar('\n');
	}
	return 0;
}

法二

考虑轻重链剖分,询问时将路径上的若干条重链的 bitset 并起来即可。

由于重链上的点的 dfn 序是连续的,序列分块即可。

每次询问,跳重链分块计算这条重链的贡献即可。

时间复杂度为 \(\mathcal O(\frac{n^2}{w}+m\log n(\sqrt{n}+\frac{n}{w}))\)。

\(\text{1.05s / 18.14MB / 3.47KB C++14 (GCC 9) O2}\)。

#include <cstdio>
#include <bitset>
#include <algorithm>

#define re register

namespace Fread
{
	const int SIZE = 1 << 23;
	char buf[SIZE], *S, *T;
	inline char getchar()
	{
		if (S == T)
		{
			T = (S = buf) + fread(buf, 1, SIZE, stdin);
			if (S == T)
				return '\n';
		}
		return *S++;
	}
}
namespace Fwrite
{
	const int SIZE = 1 << 23;
	char buf[SIZE], *S = buf, *T = buf + SIZE;
	inline void flush()
	{
		fwrite(buf, 1, S - buf, stdout);
		S = buf;
	}
	inline void putchar(char c)
	{
		*S++ = c;
		if (S == T)
			flush();
	}
	struct NTR
	{
		~NTR()
		{
			flush();
		}
	} ztr;
}

#ifdef ONLINE_JUDGE
#define getchar Fread::getchar
#define putchar Fwrite::putchar
#endif

inline int read()
{
	re int x = 0, f = 1;
	re char c = getchar();
	while (c < '0' || c > '9')
	{
		if (c == '-')
			f = -1;
		c = getchar();
	}
	while (c >= '0' && c <= '9')
	{
		x = x * 10 + c - '0';
		c = getchar();
	}
	return x * f;
}

inline void write(int x)
{
	if (x < 0)
	{
		putchar('-');
		x = -x;
	}
	if (x > 9)
		write(x / 10);
	putchar(x % 10 + '0');
}

const int _ = 4e4 + 7, W = 4e4 + 7, B = 1e3;

std::bitset<W> f[42][42], nw;

int n, m, cnt_node, ans, tot, a[_], arr[_], fa[_], dep[_], siz[_], hson[_], top[_], dfn[_], b[_], bel[_], L[_], R[_], head[_], to[_ << 1], nxt[_ << 1];

inline void swap(re int &x, re int &y)
{
	x ^= y ^= x ^= y;
}

inline void add(re int u, re int v)
{
	to[++tot] = v;
	nxt[tot] = head[u];
	head[u] = tot;
}

void dfs1(re int u, re int D)
{
	dep[u] = D, siz[u] = 1;
	for (re int i = head[u]; i; i = nxt[i])
	{
		re int v = to[i];
		if (siz[v])
			continue;
		fa[v] = u;
		dfs1(v, D + 1);
		siz[u] += siz[v];
		if (siz[hson[u]] < siz[v])
			hson[u] = v;
	}
}

void dfs2(re int u, re int tf)
{
	top[u] = tf, dfn[u] = ++cnt_node, a[cnt_node] = b[u];
	if (!hson[u])
		return;
	dfs2(hson[u], tf);
	for (re int i = head[u]; i; i = nxt[i])
	{
		re int v = to[i];
		if (top[v])
			continue;
		dfs2(v, v);
	}
}

inline void pre()
{
	for (re int i = 1; i <= n; ++i)
	{
		bel[i] = (i - 1) / B + 1;
		f[bel[i]][bel[i]].set(a[i]);
	}
	for (re int i = 1; i <= bel[n]; ++i)
		L[i] = R[i - 1] + 1, R[i] = i * B;
	R[bel[n]] = n;
	for (re int i = 1; i < bel[n]; ++i)
		for (re int j = i + 1; j <= bel[n]; ++j)
			f[i][j] = f[i][j - 1] | f[j][j];
}

inline void Query_on_block(re int l, re int r)
{
	if (bel[l] == bel[r])
	{
		for (re int i = l; i <= r; ++i)
			nw.set(a[i]);
		return;
	}
	nw |= f[bel[l] + 1][bel[r] - 1];
	for (re int i = l; i <= R[bel[l]]; ++i)
		nw.set(a[i]);
	for (re int i = L[bel[r]]; i <= r; ++i)
		nw.set(a[i]);
}

inline void Query_on_tree(re int u, re int v)
{
	while (top[u] != top[v])
	{
		if (dep[top[u]] < dep[top[v]])
			swap(u, v);
		Query_on_block(dfn[top[u]], dfn[u]);
		u = fa[top[u]];
	}
	if (dep[u] > dep[v])
		swap(u, v);
	Query_on_block(dfn[u], dfn[v]);
}

signed main()
{
	n = read(), m = read();
	for (re int i = 1; i <= n; ++i)
		arr[i] = b[i] = read();
	std::sort(arr + 1, arr + 1 + n);
	arr[0] = std::unique(arr + 1, arr + 1 + n) - arr - 1;
	for (re int i = 1; i <= n; ++i)
		b[i] = std::lower_bound(arr + 1, arr + 1 + arr[0], b[i]) - arr;
	re int u, v;
	for (re int i = 1; i < n; ++i)
	{
		u = read(), v = read();
		add(u, v), add(v, u);
	}
	dfs1(1, 1), dfs2(1, 1);
	pre();
	while (m--)
	{
		nw.reset();
		u = read() ^ ans, v = read();
		Query_on_tree(u, v);
		write(ans = nw.count()), putchar('\n');
	}
}
上一篇:一个完整的、全面 k8s 化的集群稳定架构(值得借鉴)


下一篇:【YBT2022寒假Day1 C】相似子串(SA)(RMQ)(LCP)