【题解】[ZJOI2019]语言

Problem

\(\text{Solution:}\)

菜鸡太菜了想了好久没有思路……只知道要求树上链的并集但不知道咋整……虽然题目算法线段树合并和树上差分看到题就能想到……但是怎么做还是有思维难度……(对笔者来说)

看了好久的题解都没有看懂 这次写的详细一点。

  • 所谓矩阵面积并

某些题解中说了一句 “可以用树剖套扫描线” “就是一个矩阵面积并” 的做法,复杂度 \(O(n\log^3 n).\)

实际上,我们对树进行树剖后,每一条路径 \((s,t)\) 都可以被划分为 \(\log n\) 个线段。我们可以把每一段被划分的重链区间看做矩形,那我们对每个点求的链并集,实际上就是一个抽象的 “矩形面积并” 。

复杂度是树剖的 \(\log\) 和扫面线本身的两个 \(\log\). 我没有实现这种做法。

  • 线段树合并、树上差分与树剖

常规的 \(\log^2 n\) 做法是这个。

由于有序对不好算,我们可以算无序对的个数,除以二即可。

记 \(S_u\) 表示 \(u\) 可以到达的点的集合(不包括它自己)。那么 \(ans=\frac{\sum S_i}{2}.\)

如何对每一个点计算出它的 \(S\) 呢?

首先明确一点:我们对每一个点建立的线段树是一棵 以 dfs序 为基础的线段树,它维护了整棵树的信息。

那么,对每一个点都有这样一棵线段树,对于每一个语言传递过程,我们在路上的每一个点上都进行一次区间覆盖操作,最终每一个点上线段树上被覆盖的长度就是它的 \(S\) .

显然这玩意复杂度不对。

首先,对树上路径修改,我们需要用到树剖。

那么对于一条路径上的点,我们可以自然想到用树上差分的思路来优化。

并不知道一棵支持区间修改的树咋合并) 上文所述,我们需要维护一棵线段树上被覆盖的长度,并且是 全局询问

那这个东西就长得很像扫描线了。 维护区间被覆盖的次数来更新即可。

那么合并呢?

合并要注意一下细节:每一次合并到的点都要记录信息,因为这种写法的区间覆盖次数没有什么下传操作,不要只在叶子上维护 pushup 操作,这样维护的信息是错误而不全面的。

其他写法都一样,复杂度是一个 \(\log\).

那么总结就是:考虑暴力,每一个点都记录一下链上信息,又因为需要树上修改路径需要树剖操作,而对于一条路径上的点我们可以考虑树上差分优化,进而利用线段树合并来解决这题。

时间复杂度:\(O(n\log^2 n).\)

空间分析

之前从神鱼那里见到,空间复杂度是 \(2n\log n\) 的。

计算一下:\(Num=2\cdot 10^5 \cdot \log 10^5=3.32\cdot 10^6\) 级别。然而代码中,进行修改的操作达到了 \(O(m\log n)\) 级别的个数,也就是多了个 \(\log\) ,约为\(5.5\cdot 10^7\)级别。空间没有卡满,代码开到 \(2\cdot 10^7\) 可以过去。

#include<bits/stdc++.h>
using namespace std;
const int MAXN=2e7+10;
int ls[MAXN],rs[MAXN],topp,rub[MAXN],node;
int cnt[MAXN],len[MAXN],id[MAXN],rk[MAXN];
int rt[MAXN],pa[MAXN],siz[MAXN],son[MAXN];
int dep[MAXN],head[MAXN],tot,n,m,top[MAXN];
long long ans;
struct E {
	int nxt,to;
} e[MAXN];

inline int Max(int x,int y){return x>y?x:y;}

inline void add(int x,int y) {
	e[++tot] = ( E ) {
		head [ x ] , y
	} ;
	head[x]=tot;
}

void dfs1(int x,int fa) {
	pa [ x ] = fa ;
	siz [ x ] = 1 ;
	dep [ x ] = dep [ fa ] + 1 ;
	for ( int i=head[x]; i; i=e[i].nxt) {
		int j=e[i].to;
		if(j==fa)continue;
		dfs1(j,x);
		siz[x]+=siz[j];
		if(siz[j]>siz[son[x]])son[x]=j;
	}
}

int dfstime;
void dfs2(int u,int t) {
	top [ u ] = t ;
	id [ u ] = ++ dfstime ;
	if ( ! son [ u ] )
		return ;
	dfs2 ( son [ u ] , t ) ;
	for ( int i = head [ u ] ; i ; i = e [ i ] .nxt ) {
		int j = e [ i ] .to ;
		if ( j == son [ u ] || j == pa [ u ] )
			continue ;
		dfs2 ( j , j ) ;
	}
}

inline void del(int x) {
	rub[++topp]=x;
	len[x]=cnt[x]=ls[x]=rs[x]=0;
}

inline int New() {
	if(topp)return rub[topp--];
	return ++node;
}

inline void pushup(int x,int l,int r) {
	if(cnt[x]>0)
		len[x]=r-l+1;
	else
		len[x]=len[ls[x]]+len[rs[x]];
}

void change(int &x,int L,int R,int l,int r,int v) {
	if(!x)x=New();
	if ( l <= L && R <= r ) {
		cnt[x]+=v;
		pushup(x,L,R);
		return;
	}
	int mid=(L+R)>>1;
	if(l<=mid)change(ls[x],L,mid,l,r,v);
	if(mid<r)change(rs[x],mid+1,R,l,r,v);
	pushup(x,L,R);
}

int merge(int x,int y,int l,int r) {
	if(!x||!y)return x+y;
	cnt[x]+=cnt[y];
	if(l==r) {
		pushup(x,l,r);
		del(y);
		return x;
	}
	int mid=(l+r)>>1;
	ls[x]=merge(ls[x],ls[y],l,mid);
	rs[x]=merge(rs[x],rs[y],mid+1,r);
	pushup(x,l,r);
	del(y);
	return x;
}

int query(int x,int L,int R,int l,int r) {
	if(L>=l&&R<=r)return len[x];
	int mid=(L+R)>>1,val=0;
	if(l<=mid)val+=query(ls[x],L,mid,l,r);
	if(mid<r)val+=query(rs[x],mid+1,R,l,r);
	return val;
}

void changes(int root,int x,int y,int v) {
	while(top[x]!=top[y]) {
		if(dep[top[x]]>=dep[top[y]]) {
			change(rt[root],1,n,id[top[x]],id[x],v);
			x=pa[top[x]];
		} else {
			change(rt[root],1,n,id[top[y]],id[y],v);
			y=pa[top[y]];
		}
	}
	if(id[x]<=id[y])change(rt[root],1,n,id[x],id[y],v);
	else change(rt[root],1,n,id[y],id[x],v);
}

inline int LCA(int x,int y) {
	while ( top [ x ] != top [ y ] ) {
		if(dep[top[x]]>=dep[top[y]])x=pa[top[x]];
		else y=pa[top[y]];
	}
	return dep [ x ] < dep [ y ] ? x : y ;
}

void dfs3 ( int x ) {
	for ( int i = head [ x ] ; i ; i = e [ i ] .nxt ) {
		int j = e [ i ] .to ;
		if ( j == pa [ x ] )
			continue ;
		dfs3 ( j ) ;
		rt [ x ] = merge ( rt [ x ] , rt [ j ] , 1 , n ) ;
	}
	ans += Max(query ( rt [ x ] , 1 , n , 1 , n ) -1,0) ;
}

signed main() {
	freopen("1.in","r",stdin);
	freopen("my.out","w",stdout);
	scanf("%d%d",&n,&m);
	for(int i=1; i<n; ++i) {
		int x,y;
		scanf("%d%d",&x,&y);
		add(x,y);
		add(y,x);
	}
	dfs1(1,1);
	dfs2(1,1);
	for(; m; m--) {
		int s,t;
		scanf("%d%d",&s,&t);
		int lca=LCA(s,t);
		changes(s,s,t,1);
		changes(t,s,t,1);
		changes(lca,s,t,-1);
		if(lca==1)continue;
		changes(pa[lca],s,t,-1);
	}
	dfs3(1);
	ans>>=1;
	printf("%lld\n",ans);
	return 0;
}
上一篇:Codeforces Round #714 (Div. 2) 题解(A-D)


下一篇:PAT-B1016. 部分A+B 题解