题目
点这里看题目。
分析
看到链的限制很奇葩, " 存在一个重要的选择 " ,于是就不难想到容斥。
首先定义 \((u_1,v_1)\cup (u_2,v_2)\) 表示求两条路径的边的并集。
显然容斥式子长这个样子:
\[\sum_{S\subset Q} (-1)^{|S|}\times 2^{n-1-|\bigcup_{p\in S}p|} \]
24 pts
考虑暴力选出容斥式子中的 \(S\) ,然后用一番树上差分求出 \(|\bigcup_{p\in S}p|\) 。
时间复杂度是 \(O(2^mn)\) 。
32 pts
同样是暴力选出容斥式子中的 \(S\) ,只不过我们这一次用树剖动态维护 \(|\bigcup_{p\in S} p|\) 。
现在的时间是 \(O(2^mm\log_2^2n)\) 。
如果我们按照 Gray Code 的顺序来枚举 \(S\) ,就可以省掉一个 \(m\) ,将时间优化到 \(O(2^m\log_2^2n)\) 。
虽然这并不可以让你获得更高的分数。
40 pts
唉我好像不太会。
64 pts
直接暴力枚举 \(S\) 的方法已经行不通了,我们需要换种方法统计后面这一大堆。
不难想到要用树形 DP 。这里的状态设计比较奇妙,可以先自己思考一会儿。
猜到了吗?最暴力的状态还是:
\(f(u,S)\):以 \(u\) 为根的子树内,选中了 \(S\) 集合中的链的容斥方案。
然后你就用另一种方式获得暴力容斥的分数。
注意到我们选取 \(S\) 集合,实际上是想要知道,从 \(u\) 到祖先上,哪些边会被限制到。
但是有一点是:如果有 \(u\) 是 \(v\) 的祖先,且 \(u\) 的父边已经被限制,那么 \(v\) 的父边一定被限制。
这意味着,我们只需要记录最浅的被限制的边的深度(一条边的深度是较浅的那个端点的深度),于是状态可以改为:
\(f(u,d,k)\):只考虑以 \(u\) 为根的子树内的边,且此时被限制的边的最小深度是 \(d\) ,选择了 \(k\) 条链的方案数。
想到一个老 trick,冷静分析一下,发现完全不需要记录 \(k\) ,直接在转移的时候乘 \(-1\) 就行了。于是状态就变成了两维。
(以下记 \(dep_u\) 为 \(u\) 的深度)
转移如下:
-
尝试合并一棵子树 \(v\) 。此时我们需要考虑 \((u,v)\) 是否被限制。记 \(g\) 为合并了 \((u,v)\) 之后 \(v\) 的贡献。
\[g(d)=f(v,d)\times 2^{[d\ge dep_v]} \]
然后合并到 \(u\) 上来就有:
\[f(u,d)=\sum_{\min\{i,j\}=d}f(u,i)\times g(j) \]
这个看似卷积又不太像卷积的东西( \(\min\) 卷积),我们待会儿再来谈怎么优化它。
-
考虑加入一条链,链顶的深度为 \(d\) ,于是有转移:
\[f(u,i)=f(u,i)-\sum_{\min\{k,d\}=i} f(u,k) \]
冷静分析一下之后发现:
\[f(u,k)=\begin{cases}0&k<d\\-\sum_{i=d+1}^{n}f(u,i)&k=d\\f(u,k)&k>d\end{cases} \]
另外还有一个边界条件,就是最初的时候有 \(f(u,n)=1\) ,因为压根儿没有限制。
于是你就可以拿到 \(O(n^2)\) 的暴力 DP 分数。结合 40 pts 的某些做法就可以获得 64 pts 的高分。
100 pts
这里需要用到一个叫做 " 整体 DP " 的技巧。
说得这么玄乎,其实这就是使用数据结构维护 DP 数组,方便进行结构性操作和区间操作。
在这个问题中,由于转移过程中 DP 数组并不会变长或者变短,因此我们可以考虑使用线段树维护 DP 数组。
那么考虑不同的转移:
-
合并子树。此时求 \(g\) 就相当于直接做区间乘法。
考虑之后合并到 \(u\) 上来。我们首先记后缀和:
\[\begin{aligned}su(d)&=\sum_{i=d}^nf(u,d)\\sg(d)&=\sum_{i=d}^ng(d)\end{aligned} \]
那么转移就是:
\[f(u,d)=f(u,d)\times sg(d+1)+g(d)\times su(d+1)+f(u,d)\times g(d) \]
如果这是正常的对应项求积,我们就可以直接线段树合并,然后在叶节点乘起来。
现在我们还是可以线段树合并,不过我们需要求出 \(sg\) 和 \(su\) ,并且在叶节点加上额外贡献。
当然每次直接查询是不科学的,我们应该利用 cdq 分治的技巧。我们在线段树合并的时候,如果当前区间是 \([l,r]\) ,我们就需要维护 \(su(r+1)\) 和 \(sg(r+1)\) 。
我们可以首先递归左子树,并且将右子树的 \(su\) 和 \(sg\) 贡献给左子树。之后再贡献给右子树。
于是就可以在线段树合并的过程中处理掉这个 \(\min\) 卷积。
-
新增一条链,可以发现加入一条链做了三件事情:
- 区间清空。
- 区间求和。
- 单点赋值。
这三个东西都可以用线段树快速处理。第一个可以直接区间乘法。
一点优化:如果当前点 \(u\) 是多条链的链尾,我们可以只处理链顶深度最大的那条链。
这是由于,如果有两条链,链顶深度分别为 \(d\) 和 \(d'\) ,且 \(d'<d\) ,我们可以说明只有 \(d\) 会影响结果。
如果我们让 \(d\) 在 \(d'\) 之后加入 DP ,那么 \(d'\) 的影响就被清除了。
否则就是 \(d'\) 在 \(d\) 之后加入 DP 。首先对于 \(i<d'\) 和 \(i>d'\) , \(f(u,i)\) 不会受到影响。考虑 \(f(u,d')\) ,简单算一下 \(\sum_{k=d'+1}^nf(u,k)\) 可以发现它也是 0 。所以 \(d'\) 对答案仍然没有影响。
于是这道题就做完了,时间是 \(O(n\log_2n)\) 。实际上代码也不怎么难写,基本上都是常规的算法。
我不是指思维的常规。不过仔细想想这道题的思路和技巧也并不是特别难。
顺便,它有亿点点卡常。
代码
#include <cstdio>
typedef long long LL;
const int mod = 998244353;
const int MAXN = 5e5 + 5, MAXM = 5e5 + 5, MAXLOG = 40;
const int MAXS = MAXN * MAXLOG;
inline void read( int &x )
{
x = 0;char s = getchar();int f = 1;
while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
x *= f;
}
inline void write( int x )
{
if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
if( 9 < x ){ write( x / 10 ); }
putchar( x % 10 + '0' );
}
inline int MAX( const int a, const int b )
{
return a > b ? a : b;
}
struct edge
{
int to, nxt;
}Graph[MAXN << 1];
int ttag[MAXS];
int su[MAXS], lch[MAXS], rch[MAXS];
int siz;
int head[MAXN], dep[MAXN], rt[MAXN], mx[MAXN];
int N, M, cnt;
inline int newNode() { int cur = ++ siz; ttag[cur] = 1; return cur; }
inline void upt( const int x ) { su[x] = ( su[lch[x]] + su[rch[x]] ) % mod; }
inline void sub( int &x, const int v ) { x = ( x < v ? x - v + mod : x - v ); }
inline int mul( LL x, const int v ) { x *= v; if( x >= mod ) x %= mod; return x; }
inline void add( int &x, const int v ) { x = ( x + v >= mod ? x + v - mod : x + v ); }
inline void mult( int x, const int p ) { if( ! x ) return ; su[x] = mul( su[x], p ), ttag[x] = mul( ttag[x], p ); }
inline void normalize( const int x ) { if( ! x || ttag[x] == 1 ) return; mult( lch[x], ttag[x] ), mult( rch[x], ttag[x] ), ttag[x] = 1; }
inline void addEdge( const int from, const int to )
{
Graph[++ cnt].to = to, Graph[cnt].nxt = head[from];
head[from] = cnt;
}
void multiply( int &x, const int l, const int r, const int segL, const int segR, const int delt )
{
if( ! x ) return ;
if( segL <= l && r <= segR ) { mult( x, delt ); return ; }
int mid = l + r >> 1; normalize( x );
if( segL <= mid ) multiply( lch[x], l, mid, segL, segR, delt );
if( mid < segR ) multiply( rch[x], mid + 1, r, segL, segR, delt );
upt( x );
}
void setVal( int &x, const int l, const int r, const int p, const int v )
{
x = x ? x : newNode();
if( l == r ) { su[x] = v; return ; }
int mid = l + r >> 1; normalize( x );
if( p <= mid ) setVal( lch[x], l, mid, p, v );
else setVal( rch[x], mid + 1, r, p, v );
upt( x );
}
int query( const int x, const int l, const int r, const int segL, const int segR )
{
if( ! x ) return 0;
if( segL <= l && r <= segR ) return su[x];
int mid = l + r >> 1, ret = 0; normalize( x );
if( segL <= mid ) add( ret, query( lch[x], l, mid, segL, segR ) );
if( mid < segR ) add( ret, query( rch[x], mid + 1, r, segL, segR ) );
return ret;
}
int merg( const int x, const int y, const int l, const int r, const int sufX, const int sufY )
{
if( ! x ) { mult( y, sufX ); return y; }
if( ! y ) { mult( x, sufY ); return x; }
if( l == r )
{
su[x] = ( ( LL ) mul( su[x], sufY ) + mul( su[y], sufX ) + mul( su[x], su[y] ) ) % mod;
return x;
}
int mid = l + r >> 1; normalize( x ), normalize( y );
lch[x] = merg( lch[x], lch[y], l, mid, ( sufX + su[rch[x]] ) % mod, ( sufY + su[rch[y]] ) % mod );
rch[x] = merg( rch[x], rch[y], mid + 1, r, sufX, sufY );
upt( x );
return x;
}
void DFS( const int u, const int fa )
{
dep[u] = dep[fa] + 1;
setVal( rt[u], 1, N, N, 1 );
for( int i = head[u], v ; i ; i = Graph[i].nxt )
if( ( v = Graph[i].to ) ^ fa )
{
DFS( v, u );
multiply( rt[v], 1, N, dep[v], N, 2 );
rt[u] = merg( rt[u], rt[v], 1, N, 0, 0 );
}
if( mx[u] )
{
multiply( rt[u], 1, N, 1, mx[u], 0 );
int tVal = query( rt[u], 1, N, mx[u] + 1, N );
setVal( rt[u], 1, N, mx[u], mod - tVal );
}
}
void init( const int u, const int fa )
{
dep[u] = dep[fa] + 1;
for( int i = head[u], v ; i ; i = Graph[i].nxt )
if( ( v = Graph[i].to ) ^ fa ) init( v, u );
}
int main()
{
read( N ); for( int i = 1, a, b ; i < N ; i ++ ) read( a ), read( b ), addEdge( a, b ), addEdge( b, a );
init( 1, 0 ), read( M ); for( int i = 1, fr, to ; i <= M ; i ++ ) read( fr ), read( to ), mx[to] = MAX( mx[to], dep[fr] );
DFS( 1, 0 ); write( su[rt[1]] ), putchar( '\n' );
return 0;
}