Problem
Solution
先找到序列中 \(n\) 的位置,那么在 \(n\) 之前必须有 \(a-1\) 个前缀最大值,之后有 \(b-1\) 个后缀最大值。
设 \(f[i][j]\) 表示长度为 \(i\) 的排列,有 \(j\) 个前缀最大值的方案数。
那么\(ans=\sum_{i=1}^n f[i-1][a-1]\times f[n-i][b-1]\times \binom n {i-1}\)
枚举最小值的位置,那么当且仅当它在第一个位置上时才会贡献一个前缀最大值,则 \(f[i][j]=f[i-1][j-1]+(i-1)f[i-1][j]\)。不难发现这其实是第一类斯特林数的递推式。
怎么理解呢?不妨记第 \(i\) 个前缀最大值的出现位置为 \(p_i\),把\([p_i,p_{i+1})\)视为一个盒子,盒子内的排列的方案数就是圆排列的方案数,所以这就是第一类斯特林数。
前后其实是一个对称的问题,那把它们放在一起考虑,答案就是 \(f[n-1][a+b-2]\binom {a+b-2} {b-1}\),即把这些盒子排在一起,然后再选 \(b-1\) 个盒子放到后面去。
第一类斯特林数可以用它的生成函数来做,倍增FFT加速即可做到\(O(n\log n)\)。
Code
#include <algorithm>
#include <cstring>
#include <cstdio>
using namespace std;
typedef long long ll;
const int maxn=300010,mod=998244353,G=3;
template <typename Tp> int getmin(Tp &x,Tp y){return y<x?x=y,1:0;}
template <typename Tp> int getmax(Tp &x,Tp y){return y>x?x=y,1:0;}
template <typename Tp> void read(Tp &x)
{
x=0;char ch=getchar();int f=0;
while(ch!='-'&&(ch<'0'||ch>'9')) ch=getchar();
if(ch=='-') ch=getchar(),f=1;
while(ch>='0'&&ch<='9') x=x*10+(ch-'0'),ch=getchar();
if(f) x=-x;
}
int n,a,b,N,l,ans,fac[maxn],inv[maxn],rev[maxn],f[maxn],g[maxn];
int pw[maxn],ta[maxn],tb[maxn];
int pls(int x,int y){return x+y>=mod?x+y-mod:x+y;}
int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
int c(int n,int m){return m>n?0:(ll)fac[n]*inv[m]%mod*inv[n-m]%mod;}
int power(int x,int y)
{
int res=1;
for(;y;y>>=1,x=(ll)x*x%mod)
if(y&1)
res=(ll)res*x%mod;
return res;
}
void NTT(int *a,int f)
{
for(int i=1;i<N;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int i=1;i<N;i<<=1)
{
int gn=power(G,(mod-1)/(i<<1));
for(int j=0;j<N;j+=(i<<1))
{
int g=1,x,y;
for(int k=0;k<i;++k,g=(ll)g*gn%mod)
{
x=a[j+k];y=(ll)g*a[j+k+i]%mod;
a[j+k]=pls(x,y);a[j+k+i]=dec(x,y);
}
}
}
if(f==-1)
{
int inv=power(N,mod-2);reverse(a+1,a+N);
for(int i=0;i<N;i++) a[i]=(ll)a[i]*inv%mod;
}
}
void solve(int n)//x^n
{
if(n==1){f[1]=1;return ;}
int h=n>>1;solve(n>>1);
for(N=1,l=0;N<=n;N<<=1) ++l;
for(int i=1;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
for(int i=1;i<=h;i++) pw[i]=(ll)pw[i-1]*h%mod;
for(int i=0;i<=h;i++) ta[i]=(ll)fac[i]*f[i]%mod;
for(int i=0;i<=h;i++) tb[i]=(ll)pw[i]*inv[i]%mod;
for(int i=h+1;i<N;i++) ta[i]=tb[i]=0;
reverse(tb,tb+h+1);
NTT(ta,1);NTT(tb,1);NTT(f,1);
for(int i=0;i<N;i++) g[i]=(ll)ta[i]*tb[i]%mod;
NTT(g,-1);
for(int i=0;i<=h;i++) g[i]=(ll)g[i+h]*inv[i]%mod;
for(int i=h+1;i<N;i++) g[i]=0;
NTT(g,1);
for(int i=0;i<N;i++) f[i]=(ll)f[i]*g[i]%mod;
NTT(f,-1);
if(n&1)
{
for(int i=n;i;i--) f[i]=pls(f[i-1],(ll)f[i]*(n-1)%mod);
f[0]=(ll)f[0]*(n-1)%mod;
}
}
int main()
{
read(n);read(a);read(b);fac[0]=pw[0]=1;
if(a<1||b<1){puts("0");return 0;}
if(n==1){printf("%d\n",(a==1&&b==1));return 0;}
for(int i=1;i<=n;i++) fac[i]=(ll)fac[i-1]*i%mod;
inv[n]=power(fac[n],mod-2);
for(int i=n-1;~i;i--) inv[i]=(ll)inv[i+1]*(i+1)%mod;
solve(n-1);
ans=(ll)f[a+b-2]*c(a+b-2,a-1)%mod;
printf("%d\n",ans);
return 0;
}