[NOI2020]命运

题目

点这里看题目。

分析

看到链的限制很奇葩, " 存在一个重要的选择 " ,于是就不难想到容斥。

首先定义 \((u_1,v_1)\cup (u_2,v_2)\) 表示求两条路径的边的并集。

显然容斥式子长这个样子:

\[\sum_{S\subset Q} (-1)^{|S|}\times 2^{n-1-|\bigcup_{p\in S}p|} \]

24 pts

考虑暴力选出容斥式子中的 \(S\) ,然后用一番树上差分求出 \(|\bigcup_{p\in S}p|\) 。

时间复杂度是 \(O(2^mn)\) 。

32 pts

同样是暴力选出容斥式子中的 \(S\) ,只不过我们这一次用树剖动态维护 \(|\bigcup_{p\in S} p|\) 。

现在的时间是 \(O(2^mm\log_2^2n)\) 。

如果我们按照 Gray Code 的顺序来枚举 \(S\) ,就可以省掉一个 \(m\) ,将时间优化到 \(O(2^m\log_2^2n)\) 。

虽然这并不可以让你获得更高的分数。

40 pts

唉我好像不太会。

64 pts

直接暴力枚举 \(S\) 的方法已经行不通了,我们需要换种方法统计后面这一大堆。

不难想到要用树形 DP 。这里的状态设计比较奇妙,可以先自己思考一会儿。

猜到了吗?最暴力的状态还是:

\(f(u,S)\):以 \(u\) 为根的子树内,选中了 \(S\) 集合中的链的容斥方案。

然后你就用另一种方式获得暴力容斥的分数。

注意到我们选取 \(S\) 集合,实际上是想要知道,从 \(u\) 到祖先上,哪些边会被限制到。

但是有一点是:如果有 \(u\) 是 \(v\) 的祖先,且 \(u\) 的父边已经被限制,那么 \(v\) 的父边一定被限制

这意味着,我们只需要记录最浅的被限制的边的深度(一条边的深度是较浅的那个端点的深度),于是状态可以改为:

\(f(u,d,k)\):只考虑以 \(u\) 为根的子树内的边,且此时被限制的边的最小深度是 \(d\) ,选择了 \(k\) 条链的方案数。

想到一个老 trick,冷静分析一下,发现完全不需要记录 \(k\) ,直接在转移的时候乘 \(-1\) 就行了。于是状态就变成了两维。

(以下记 \(dep_u\) 为 \(u\) 的深度)

转移如下:

  1. 尝试合并一棵子树 \(v\) 。此时我们需要考虑 \((u,v)\) 是否被限制。记 \(g\) 为合并了 \((u,v)\) 之后 \(v\) 的贡献。

    \[g(d)=f(v,d)\times 2^{[d\ge dep_v]} \]

    然后合并到 \(u\) 上来就有:

    \[f(u,d)=\sum_{\min\{i,j\}=d}f(u,i)\times g(j) \]

    这个看似卷积又不太像卷积的东西( \(\min\) 卷积),我们待会儿再来谈怎么优化它。

  2. 考虑加入一条链,链顶的深度为 \(d\) ,于是有转移:

    \[f(u,i)=f(u,i)-\sum_{\min\{k,d\}=i} f(u,k) \]

    冷静分析一下之后发现:

    \[f(u,k)=\begin{cases}0&k<d\\-\sum_{i=d+1}^{n}f(u,i)&k=d\\f(u,k)&k>d\end{cases} \]

另外还有一个边界条件,就是最初的时候有 \(f(u,n)=1\) ,因为压根儿没有限制。

于是你就可以拿到 \(O(n^2)\) 的暴力 DP 分数。结合 40 pts 的某些做法就可以获得 64 pts 的高分。

100 pts

这里需要用到一个叫做 " 整体 DP " 的技巧。

说得这么玄乎,其实这就是使用数据结构维护 DP 数组,方便进行结构性操作和区间操作。

在这个问题中,由于转移过程中 DP 数组并不会变长或者变短,因此我们可以考虑使用线段树维护 DP 数组。

那么考虑不同的转移:

  1. 合并子树。此时求 \(g\) 就相当于直接做区间乘法。

    考虑之后合并到 \(u\) 上来。我们首先记后缀和:

    \[\begin{aligned}su(d)&=\sum_{i=d}^nf(u,d)\\sg(d)&=\sum_{i=d}^ng(d)\end{aligned} \]

    那么转移就是:

    \[f(u,d)=f(u,d)\times sg(d+1)+g(d)\times su(d+1)+f(u,d)\times g(d) \]

    如果这是正常的对应项求积,我们就可以直接线段树合并,然后在叶节点乘起来。

    现在我们还是可以线段树合并,不过我们需要求出 \(sg\) 和 \(su\) ,并且在叶节点加上额外贡献。

    当然每次直接查询是不科学的,我们应该利用 cdq 分治的技巧。我们在线段树合并的时候,如果当前区间是 \([l,r]\) ,我们就需要维护 \(su(r+1)\) 和 \(sg(r+1)\) 。

    我们可以首先递归左子树,并且将右子树的 \(su\) 和 \(sg\) 贡献给左子树。之后再贡献给右子树。

    于是就可以在线段树合并的过程中处理掉这个 \(\min\) 卷积。

  2. 新增一条链,可以发现加入一条链做了三件事情:

    • 区间清空。
    • 区间求和。
    • 单点赋值。

    这三个东西都可以用线段树快速处理。第一个可以直接区间乘法。

    一点优化:如果当前点 \(u\) 是多条链的链尾,我们可以只处理链顶深度最大的那条链

    这是由于,如果有两条链,链顶深度分别为 \(d\) 和 \(d'\) ,且 \(d'<d\) ,我们可以说明只有 \(d\) 会影响结果。

    如果我们让 \(d\) 在 \(d'\) 之后加入 DP ,那么 \(d'\) 的影响就被清除了。

    否则就是 \(d'\) 在 \(d\) 之后加入 DP 。首先对于 \(i<d'\) 和 \(i>d'\) , \(f(u,i)\) 不会受到影响。考虑 \(f(u,d')\) ,简单算一下 \(\sum_{k=d'+1}^nf(u,k)\) 可以发现它也是 0 。所以 \(d'\) 对答案仍然没有影响。

于是这道题就做完了,时间是 \(O(n\log_2n)\) 。实际上代码也不怎么难写,基本上都是常规的算法。

我不是指思维的常规。不过仔细想想这道题的思路和技巧也并不是特别难。

顺便,它有亿点点卡常。

代码

#include <cstdio>

typedef long long LL;

const int mod = 998244353;

const int MAXN = 5e5 + 5, MAXM = 5e5 + 5, MAXLOG = 40;
const int MAXS = MAXN * MAXLOG;

inline void read( int &x )
{
	x = 0;char s = getchar();int f = 1;
	while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
	while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
	x *= f;
}

inline void write( int x )
{
	if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
	if( 9 < x ){ write( x / 10 ); }
	putchar( x % 10 + '0' );
}

inline int MAX( const int a, const int b )
{
	return a > b ? a : b;
}

struct edge
{
	int to, nxt;
}Graph[MAXN << 1];

int ttag[MAXS];
int su[MAXS], lch[MAXS], rch[MAXS]; 
int siz;

int head[MAXN], dep[MAXN], rt[MAXN], mx[MAXN];
int N, M, cnt;

inline int newNode() { int cur = ++ siz; ttag[cur] = 1; return cur; }
inline void upt( const int x ) { su[x] = ( su[lch[x]] + su[rch[x]] ) % mod; }
inline void sub( int &x, const int v ) { x = ( x < v ? x - v + mod : x - v ); }
inline int mul( LL x, const int v ) { x *= v; if( x >= mod ) x %= mod; return x; }
inline void add( int &x, const int v ) { x = ( x + v >= mod ? x + v - mod : x + v ); }
inline void mult( int x, const int p ) { if( ! x ) return ; su[x] = mul( su[x], p ), ttag[x] = mul( ttag[x], p ); }
inline void normalize( const int x ) { if( ! x || ttag[x] == 1 ) return; mult( lch[x], ttag[x] ), mult( rch[x], ttag[x] ), ttag[x] = 1; }

inline void addEdge( const int from, const int to )
{
	Graph[++ cnt].to = to, Graph[cnt].nxt = head[from];
	head[from] = cnt;
}

void multiply( int &x, const int l, const int r, const int segL, const int segR, const int delt )
{
	if( ! x ) return ;
	if( segL <= l && r <= segR ) { mult( x, delt ); return ; }
	int mid = l + r >> 1; normalize( x );
	if( segL <= mid ) multiply( lch[x], l, mid, segL, segR, delt );
	if( mid < segR ) multiply( rch[x], mid + 1, r, segL, segR, delt );
	upt( x );
}

void setVal( int &x, const int l, const int r, const int p, const int v )
{
	x = x ? x : newNode();
	if( l == r ) { su[x] = v; return ; }
	int mid = l + r >> 1; normalize( x );
	if( p <= mid ) setVal( lch[x], l, mid, p, v );
	else setVal( rch[x], mid + 1, r, p, v );
	upt( x );
}

int query( const int x, const int l, const int r, const int segL, const int segR )
{
	if( ! x ) return 0;
	if( segL <= l && r <= segR ) return su[x];
	int mid = l + r >> 1, ret = 0; normalize( x );
	if( segL <= mid ) add( ret, query( lch[x], l, mid, segL, segR ) );
	if( mid < segR ) add( ret, query( rch[x], mid + 1, r, segL, segR ) );
	return ret;
}

int merg( const int x, const int y, const int l, const int r, const int sufX, const int sufY )
{
	if( ! x ) { mult( y, sufX ); return y; }
	if( ! y ) { mult( x, sufY ); return x; }
	if( l == r )
	{
		su[x] = ( ( LL ) mul( su[x], sufY ) + mul( su[y], sufX ) + mul( su[x], su[y] ) ) % mod;
		return x;
	}
	int mid = l + r >> 1; normalize( x ), normalize( y );
	lch[x] = merg( lch[x], lch[y], l, mid, ( sufX + su[rch[x]] ) % mod, ( sufY + su[rch[y]] ) % mod );
	rch[x] = merg( rch[x], rch[y], mid + 1, r, sufX, sufY );
	upt( x );
	return x;
}

void DFS( const int u, const int fa )
{
	dep[u] = dep[fa] + 1;
	setVal( rt[u], 1, N, N, 1 );
	for( int i = head[u], v ; i ; i = Graph[i].nxt )
		if( ( v = Graph[i].to ) ^ fa )
		{
			DFS( v, u );
			multiply( rt[v], 1, N, dep[v], N, 2 );
			rt[u] = merg( rt[u], rt[v], 1, N, 0, 0 );
		}
	if( mx[u] ) 
	{
		multiply( rt[u], 1, N, 1, mx[u], 0 );
		int tVal = query( rt[u], 1, N, mx[u] + 1, N );
		setVal( rt[u], 1, N, mx[u], mod - tVal );
	}
}

void init( const int u, const int fa )
{
	dep[u] = dep[fa] + 1;
	for( int i = head[u], v ; i ; i = Graph[i].nxt )
		if( ( v = Graph[i].to ) ^ fa ) init( v, u );
}

int main()
{
	read( N ); for( int i = 1, a, b ; i < N ; i ++ ) read( a ), read( b ), addEdge( a, b ), addEdge( b, a );
	init( 1, 0 ), read( M ); for( int i = 1, fr, to ; i <= M ; i ++ ) read( fr ), read( to ), mx[to] = MAX( mx[to], dep[fr] );
	DFS( 1, 0 ); write( su[rt[1]] ), putchar( '\n' );
	return 0;
}
上一篇:素数筛(筛选出一定范围内的所有素数)


下一篇:Linux 用户名显示为sh-