description
solution
有一道非常相似的题目
一棵树,每条边限制两个端点的大小关系(限制 a [ u ] > a [ v ] a[u]>a[v] a[u]>a[v] 或 a [ u ] < a [ v ] a[u]<a[v] a[u]<a[v])
求有多少种符合要求的排列 a a a满足整棵树的限制。 n < = 5000 n<=5000 n<=5000
考虑如果所有边都是朝一个方向的话很好做
答案就是 n ! n! n!除以每个子树的大小
如果存在反向边的话,暴力枚举断开若干个反向边,剩下的边改为正向,然后计算答案
容斥即可。这样暴力做的复杂度是 O ( 2 n ∗ n ) O(2^n*n) O(2n∗n) 的
考虑 d p dp dp, f ( i , j , k ) f(i,j,k) f(i,j,k) 表示以 i i i 为根的子树,当前 i i i 所在连通块内有 j j j 个点,总共反向 k k k 条边的方案数
合并两棵子树时,如果边是正向的,那么直接合并;
否则要么断开,要么让 k + 1 k+1 k+1 并且按照正向合并
复杂度 n n n 的若干次方
考虑最后的容斥只需要关注 k k k 的奇偶性,因此第三维完全可以省掉
即合并两棵子树时,如果边是正向则直接合并,否则值就是断开的方案减掉把边正向的方案
因此就是一个简单的树背包,复杂度 O ( n 2 ) O(n^2) O(n2)
此题只是需要将二维
d
p
dp
dp再次优化即可
设
d
p
[
i
]
dp[i]
dp[i]表示前缀
i
i
i的合法方案数,
c
n
t
[
i
]
cnt[i]
cnt[i]表示前缀
i
i
i中
>
>
>的个数
d
p
[
i
]
i
!
=
∑
j
=
0
i
−
1
[
s
[
j
]
=
′
>
′
]
(
i
−
j
)
!
(
−
1
)
c
n
t
[
i
−
1
]
−
c
n
t
[
j
]
×
d
p
[
j
]
j
!
\frac{dp[i]}{i!}=\sum_{j=0}^{i-1}\frac{[s[j]='>']}{(i-j)!}(-1)^{cnt[i-1]-cnt[j]}\times \frac{dp[j]}{j!}
i!dp[i]=j=0∑i−1(i−j)![s[j]=′>′](−1)cnt[i−1]−cnt[j]×j!dp[j]
将
(
−
1
)
c
n
t
[
i
−
1
]
(-1)^{cnt[i-1]}
(−1)cnt[i−1]提出来,剩余部分用
N
T
T
NTT
NTT分治完成有难度
code
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
#define mod 998244353
#define int long long
#define maxn 400005
int len, inv;
char s[maxn];
int cnt[maxn];
int fac[maxn], ifac[maxn], r[maxn];
int f[maxn], g[maxn], dp[maxn];
int qkpow( int x, int y ) {
int ans = 1;
while( y ) {
if( y & 1 ) ans = ans * x % mod;
x = x * x % mod;
y >>= 1;
}
return ans;
}
void NTT( int *c, int opt ) {
for( int i = 0;i < len;i ++ )
if( i < r[i] ) swap( c[i], c[r[i]] );
for( int i = 1;i < len;i <<= 1 ) {
int omega = qkpow( opt == 1 ? 3 : mod / 3 + 1, ( mod - 1 ) / ( i << 1 ) );
for( int j = 0;j < len;j += ( i << 1 ) ) {
int w = 1;
for( int k = 0;k < i;k ++, w = w * omega % mod ) {
int x = c[j + k], y = w * c[j + k + i] % mod;
c[j + k] = ( x + y ) % mod;
c[j + k + i] = ( x - y + mod ) % mod;
}
}
}
if( opt == -1 ) {
int inv = qkpow( len, mod - 2 );
for( int i = 0;i < len;i ++ )
c[i] = c[i] * inv % mod;
}
}
void solve( int L, int R ) {
if( L == R ) {
if( ! L ) dp[L] = 1;
else dp[L] = cnt[L] & 1 ? mod - dp[L] : dp[L];//单独提出来
return;
}
int mid = ( L + R ) >> 1;
solve( L, mid );
len = 1; int l = 0;
while( len <= R - L + 1 + mid - L ) len <<= 1, l ++;
for( int i = 0;i < len;i ++ )
r[i] = ( r[i >> 1] >> 1 ) | ( ( i & 1 ) << ( l - 1 ) );
for( int i = 0;i <= mid - L;i ++ )
if( s[i + L] == '<' && i + L != 0 ) f[i] = 0;
else f[i] = cnt[i + L] & 1 ? dp[i + L] : mod - dp[i + L];//注意奇偶转换
for( int i = mid - L + 1;i < len;i ++ ) f[i] = 0;
for( int i = 0;i <= R - L + 1;i ++ ) g[i] = ifac[i];
for( int i = R - L + 2;i < len;i ++ ) g[i] = 0;
NTT( f, 1 );
NTT( g, 1 );
for( int i = 0;i < len;i ++ ) f[i] = f[i] * g[i] % mod;
NTT( f, -1 );
for( int i = mid + 1;i <= R;i ++ ) dp[i] = ( dp[i] + f[i - L] ) % mod;
solve( mid + 1, R );
}
signed main() {
scanf( "%s", s + 1 );
int n = strlen( s + 1 );
s[++ n] = '>';
fac[0] = 1;
for( int i = 1;i <= n;i ++ )
fac[i] = fac[i - 1] * i % mod;
ifac[n] = qkpow( fac[n], mod - 2 );
for( int i = n - 1;~ i;i -- )
ifac[i] = ifac[i + 1] * ( i + 1 ) % mod;
for( int i = 1;i <= n;i ++ )
cnt[i] = cnt[i - 1] + ( s[i] == '>' );
solve( 0, n );
printf( "%lld\n", dp[n] * fac[n] % mod );
return 0;
}