题意
给定 \(m\) 个不同的数 \(a_i\),每次等概率取一个,求取出大小连续的 \(k\) 个数的期望。
题解
令集合 \(S\) 为所有长度为 \(k\) 的可行段的集合,那么题目要求的是所有可行段出现时间的最小值(不会描述)。
考虑 \(\min-max\) 容斥 \(E(\min(S))=\sum \limits _ {T\in S} (-1)^ {|T|+1} E(\max(T))\),即枚举 \(T\in S\),求 \(T\) 中所有元素出现的期望时间。
出现 \(i\) 个数的期望时间:\(\sum \limits _ {j=1} ^ i \frac mi\)。考虑选出 \(i\) 个数中的一个所需要的期望时间即可得出。
这样我们就有了一个 \(dp\),令 \(f[i][j]\) 表示做到第 \(i\) 个数,选了 \(j\) 个,带容斥系数的方案数。转移只需要枚举上一个结尾处的位置即可。可以得到 \(50pts\)。优化一下可以得到 \(70pts\)。
考虑生成函数。我们把每个连续段抠出来,求它的生成函数。最后只要把各段的生成函数乘起来即可。
注意到如果我有一段长度结尾处为 \(k\),那么我的下个结尾端点必定在 \(k+1\) 或者大于 \(2k\)。
因为如果有重合部分,那么我可以选或不选后面一段没有重叠的部分,这样对选的数的个数不会造成影响,但是会影响选出的集合个数。
事实上还有一种情况是可行的:选出 \([1,k],[2,k+1][k+2,2k+1]\)。因为这种情况虽然后面两个长度为 \(k\) 的段连在了一起,但是我如果要选出 \([1,2k+1]\)。必定至少要选出 \(3\) 个段,这样就会有一个段抵消不掉(建议自行画图理解)。
于是我们可以选出若干长度为 \(k+1\) 的不交的段,其中每段可以选前 \(k\) 个,也可以选全部的 \(k+1\) 个。可以枚举选出的段数 \(i\),对于每段,之后有 \(k\) 个位置不能被选,即总共有 \(ik\) 个位置不能选。
但是这样不包含选出末位 \(k\) 个的情况但倒数 \(k+1\) 个不选的情况。对于这种情况,我们只需要钦定最后 \(k\) 个一定选,再套用之前的方法即可,这样就可以保证前面的段和最后 \(k\) 个不相交。所以式子是长这样的:
\[g_n(x)=\sum _ {i=0} ^ {\lfloor \frac {n} {k+1} \rfloor} \binom {n-ik} i x^{ik}(x-1)^i\\ f_n(x)=g_n(x)-x^kg_{n-k}(x) \]
这样就可以在 \(O(\frac {m^2}{k^2})\) 的时间内算出生成函数,有 \(90pts\),常数好可以通过。
现在的时间复杂度瓶颈在于计算 \(g_n(x)\)。考虑分治,每次递归计算左右两部分的 \(i\),并且右边需要额外乘上 \(x^{lk}(x-1)^l\)(\(l\) 为左边区间长度),使用多项式乘法。这样复杂度大概就是 \(O(m\log^2m)\)。目前 loj rank1。
代码:
#include <bits/stdc++.h>
using namespace std;
#define Re register
#define Mod 998244353
#define mkp(a,b) make_pair(a,b)
typedef pair<int,int> pii;
inline int read() {
int x=0;
char ch=getchar();
while (!isdigit(ch)) ch=getchar();
while (isdigit(ch)) x=x*10+ch-'0',ch=getchar();
return x;
}
#define N 600010
int f[N],g[N],fac[N],inv[N],Inv[N];
inline void init(int n) {
fac[0]=fac[1]=inv[0]=inv[1]=Inv[0]=Inv[1]=1;
for (int i=2;i<=n;i++) fac[i]=1LL*fac[i-1]*i%Mod;
for (int i=2;i<=n;i++) inv[i]=1LL*(Mod-Mod/i)*inv[Mod%i]%Mod;
for (int i=2;i<=n;i++) Inv[i]=1LL*Inv[i-1]*inv[i]%Mod;
}
inline int Pow(int a,int b,int p=Mod) {
int res=1;
for (;b;b>>=1,a=1LL*a*a%p)
if (b&1) res=1LL*res*a%p;
return res;
}
const int G=3,invG=332748118;
int A[N],B[N],rev[N];
inline int Dec(int x,int y) {return x-y>=0?x-y:x-y+Mod;}
inline int Pls(int x,int y) {return x+y>=Mod?x+y-Mod:x+y;}
int GPow[2][20][N];
inline void InitG() {
for (int p=1;p<=19;p++) {
int buf1=Pow(G,(Mod-1)/(1<<p));
int buf0=Pow(invG,(Mod-1)/(1<<p));
GPow[1][p][0]=GPow[0][p][0]=1;
for (int i=1;i<(1<<p);i++)
GPow[1][p][i]=1LL*GPow[1][p][i-1]*buf1%Mod,
GPow[0][p][i]=1LL*GPow[0][p][i-1]*buf0%Mod;
}
}
inline int C(int n,int m) {
if (n<m) return 0;
return 1LL*fac[n]*Inv[m]%Mod*Inv[n-m]%Mod;
}
inline void Add(int &x,int y) {x+y<Mod?x+=y:x+=y-Mod;}
inline void NTT(int *a,int len,int f) {
int k=0;
while ((1<<k)<len) ++k;
for (Re int i=0;i<len;++i) {
rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
if (i<rev[i]) swap(a[i],a[rev[i]]);
}
for (Re int l=2,cnt=1;l<=len;l<<=1,cnt++) {
int m=l>>1;
for (Re int i=0;i<len;i+=l) {
int *buf=GPow[f^1][cnt];
for (Re int j=0;j<m;j++,buf++) {
int p=a[i+j],q=1LL*(*buf)*a[i+j+m]%Mod;
a[i+j]=Pls(p,q),a[i+j+m]=Dec(p,q);
}
}
}
}
inline void DFT(int *a,int len) {
NTT(a,len,0);
}
inline void IDFT(int *a,int len) {
NTT(a,len,1); int inv=Pow(len,Mod-2);
for (int i=0;i<len;i++) a[i]=1LL*a[i]*inv%Mod;
}
inline vector<int> Mul(vector<int> a,vector<int> b) {
int n=a.size()+1,m=b.size()+1,len=1; while (len<n+m) len<<=1;
for (int i=0;i<n-1;i++) A[i]=a[i]; for (int i=n-1;i<len;i++) A[i]=0;
for (int i=0;i<m-1;i++) B[i]=b[i]; for (int i=m-1;i<len;i++) B[i]=0;
DFT(A,len),DFT(B,len); for (int i=0;i<len;i++) A[i]=1LL*A[i]*B[i]%Mod; IDFT(A,len);
vector<int> ans(len); for (int i=0;i<len;i++) ans[i]=A[i];
while (!ans.empty() && !ans.back()) ans.pop_back(); return ans;
}
inline vector<int> Pow_(int x) {
vector<int> vec(x+1);
for (int i=0;i<=x;i++) vec[i]=((x-i)&1?Mod-C(x,i):C(x,i));
return vec;
}
inline vector<int> solve(int l,int r,int n,int k) {
if (l==r) {
vector<int> v(1);
v[0]=C(n-l*k,l);
return v;
}
vector<int> ansl=solve(l,(l+r)>>1,n,k);
vector<int> ansr=solve(((l+r)>>1)+1,r,n,k);
ansr=Mul(ansr,Pow_(((l+r)>>1)-l+1));
vector<int> ans(ansr.size()+(((l+r)>>1)-l+1)*k);
for (int i=0;i<ansl.size();i++) ans[i]=ansl[i];
for (int i=0;i<ansr.size();i++) Add(ans[i+(((l+r)>>1)-l+1)*k],ansr[i]);
return ans;
}
inline vector<int> GetP(int n,int k) {
vector<int> res=solve(0,n/(k+1),n,k);
return res;
}
inline vector<int> GetPoly(int n,int k) {
vector<int> v1=GetP(n,k); int siz1=v1.size();
vector<int> v2=GetP(n-k,k); int siz2=v2.size();
v1.resize(max(siz2+k,siz1));
for (int i=0;i<siz2;i++) Add(v1[i+k],Mod-v2[i]);
while (!v1.empty() && !v1.back()) v1.pop_back();
return v1;
}
vector<vector<int>> vec;
inline vector<int> Merge() {
priority_queue<pii,vector<pii>,greater<pii>> heap;
for (int i=0;i<vec.size();i++) heap.push(mkp(vec[i].size(),i));
while (!heap.empty()) {
pii vec1=heap.top(); heap.pop();
if (heap.empty()) return vec[vec1.second];
pii vec2=heap.top(); heap.pop();
vec[vec1.second]=Mul(vec[vec1.second],vec[vec2.second]);
heap.push(mkp(vec[vec1.second].size(),vec1.second));
}
}
int vis[N];
int main() {
int n=read(),k=read(),m=n<<1,cnt=0; init(m); InitG();
for (int i=1;i<=n;i++) g[i]=(g[i-1]+1LL*n*inv[i]%Mod)%Mod;
for (int i=1;i<=n;i++) vis[read()]=true;
for (int i=1;i<=m;i++) {
if (vis[i]) {++cnt;continue;}
if (cnt>=k) vec.push_back(GetPoly(cnt,k));
cnt=0;
}
vector<int> ans=Merge(); int res=0;
for (int i=0;i<ans.size();i++) Add(res,1LL*(Mod-ans[i])*g[i]%Mod);
return printf("%d\n",res),0;
}