【UNR 3】配对树(线段树合并)

传送门

题意

给定一棵 \(n\) 个结点的树和一个长度为 \(m\) 的结点序列。对于一个大小为偶数的点集 \(S\)(集合元素可重复),定义 \(w(S)\) 为:把 \(S\) 中的点两两匹配,每对匹配的树上距离之和的最小值。现在要对序列中所有长度为偶数的区间 \([l,r]\),求出 \(w(\{a_l,a_{l+1},\cdots,a_r\})\) 的和。

\(n,m\le 10^5\),答案对 \(998244353\) 取模。

分析

首先观察到 \(w(S)\) 可以拆到每条边上算。对于一条边:

  • 它两边的子树中都有奇数个 \(S\) 中的点时会被算 \(1\) 次。
  • 它两边的子树中都有偶数个 \(S\) 中的点时会被算 \(0\) 次。

因此我们可以钦定一个根,然后对于每个子树,计算有多少个长度为偶数的区间 \([l,r]\) 满足 \(a_l,a_{l+1},\cdots,a_r\) 中有奇数个点在这棵子树中。

我们只讨论 \(l\) 为奇数、\(r\) 为偶数的情况,另一种情况是一样的。设 \(c_{u,i}\) 表示 \(a_1,a_2,\cdots,a_i\) 在 \(u\) 子树内出现的次数。观察到 \([l,r]\) 合法当且仅当 \(c_{u,l-1}\) 与 \(c_{u,r}\) 的奇偶性不同。因此我们只要算出使 \(c_{u,2i}\) 为奇数的 \(i\) 的个数。设它为 \(x\),则 \(u\) 到父亲这条边就要算 \(x\times(\lfloor{n\over2}\rfloor+1-x)\) 次。

我们可以用线段树合并来维护 \(b_{u_i}=c_{u,2i} \bmod 2\)。\(b_u\) 实际上是一个 01 数组 \(b'_u\) 的前缀异或和。因此对于线段树上的一个结点,我们只要记录对应区间上 \(b_u\) 的和以及 \(b'_u\) 中 1 的个数的奇偶性即可。时间复杂度 \(O(n\log n)\)。

实现

#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define per(i,a,b) for(int i=a;i>=b;i--)
#define pb push_back
#define md ((l+r)>>1)
using namespace std;
typedef long long ll;

const int mod=998244353,maxn=1e5+5,maxm=maxn*17;
int n,m,L,N,ans,a[maxn],lc[maxm],rc[maxm],cn[maxm],su[maxm],rt[maxn];
struct edge{int v,w;};
vector<edge>G[maxn];
vector<int>V[maxn];

inline int inc(int x,int y){return x+=y-mod,x+=x>>31&mod;}
inline int mul(int x,int y){return ll(x)*y%mod;}
inline int nnd(){++N,lc[N]=rc[N]=cn[N]=su[N]=0;return N;}

void up(int x,int l,int r){
	cn[x]=cn[lc[x]]^cn[rc[x]];
	su[x]=su[lc[x]]+(cn[lc[x]]?r-md-su[rc[x]]:su[rc[x]]);
}

int merge(int x,int y,int l,int r){
	if(!x||!y)return x+y;
	if(l==r)return su[x]=cn[x]^=cn[y],x;
	lc[x]=merge(lc[x],lc[y],l,md);
	rc[x]=merge(rc[x],rc[y],md+1,r);
	return up(x,l,r),x;
}

void modify(int&x,int l,int r,int i){
	if(!x)x=nnd();
	if(l==r){su[x]=cn[x]^=1;return;}
	i<=md?modify(lc[x],l,md,i):modify(rc[x],md+1,r,i);up(x,l,r);
}

void dfs(int u,int _v=0,int w=0){
	rt[u]=0;
	for(int i:V[u])modify(rt[u],1,L,i);
	for(edge e:G[u])if(e.v!=_v)dfs(e.v,u,e.w),rt[u]=merge(rt[u],rt[e.v],1,L);
	ans=inc(ans,mul(w,mul(su[rt[u]],L-su[rt[u]])));
}

int main(){
	scanf("%d%d",&n,&m);
	rep(i,1,n-1){
		int u,v,w;scanf("%d%d%d",&u,&v,&w);
		G[u].pb({v,w}),G[v].pb({u,w});	
	}
	rep(i,1,m)scanf("%d",&a[i]);
	rep(k,1,2){
		L=(m+k)>>1,N=0;
		rep(i,1,n)V[i].clear();
		rep(i,1,m)if((i+k+1)>>1<=L)V[a[i]].pb((i+k+1)>>1);
		dfs(1);
	}
	printf("%d\n",ans);
	return 0;
}
上一篇:MySQL 默认隔离级别是RR,为什么阿里等大厂会改成RC?


下一篇:linux ATA驱动中的宏AHCI_MAX_CLKS