题目链接
https://www.luogu.org/problem/P4707
题解
最近被神仙题八连爆了……
首先Min-Max容斥肯定都能想到,问题是这题要用一个扩展版的——Kth Min-Max容斥
这个东西需要对Min-Max容斥的本质有着比较深刻的理解。
首先我们从另一个角度证明Min-Max容斥的正确性: \(\max(S)=\sum_{T\in S}f(|T|)\min(T)\), 对于第\((x+1)\)大来说它被计算的次数是\(\sum_{k\ge 0} {x\choose k}f(k+1)\),则有\([x=0]=\sum_{k\ge 0} {x\choose k}f(k+1)\), 二项式反演之后令\(f(k)=(-1)^{k-1}\)即可达到目的。
那么考虑把刚才式子中\([x=0]\)换成\([x=k-1]\)会怎样? 依然采取构造系数的思路,得出的结果是: \(f[x]=(-1)^{x-k}{x-1\choose k-1}\).
问题相当于求第\(K\)大的期望,所以可以转化成子集最小值: \(\text{kthmax}(S)=\sum_{T\in S}(-1)^{|T|-K}{|T|-1\choose k-1}\min(T)=\sum_{T\in S}\frac{m}{p_T}(-1)^{|T|-K}{|T|-1\choose k-1}\), 其中\(p_T=\sum_{i\in T} p_i\)
这个东西可以用一个dp来搞: 设\(dp[i][j][k]\)表示前\(i\)个数\(p\)之和为\(j\), 当组合数的下指标为\(k\)时每种方案乘以容斥系数之和。
考虑转移: 如果第\(i\)个元素不属于\(T\), 显然是加上\(dp[i-1][j][k]\); 如果属于\(T\), 那么要求的组合数\({|T|-1\choose k-1}={|T|-2\choose k-1}+{|T|-2\choose k-2}\), 对于前一项直接是\(dp[i-1][j-p_i][k]\), 对于后一项因为是从\(k-1\)转移过来,所以要多乘个\(-1\), 最终结果是减去\((dp[i-1][j-p_i][k]-dp[i-1][j-p_i][k-1])\).
然而这个dp的边界问题很难处理。这时我们不妨脱离实际问题,去思考一下组合数在指标为负数时的定义,直接代入可得\(dp[0][0][k]=-1 (k>0)\).
(然而感觉这种边界设置方法并不严谨)
时间复杂度\(O(nm(n-k))\).
代码
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cassert>
#include<iostream>
#define llong long long
using namespace std;
inline int read()
{
int x=0; bool f=1; char c=getchar();
for(;!isdigit(c);c=getchar()) if(c=='-') f=0;
for(; isdigit(c);c=getchar()) x=(x<<3)+(x<<1)+(c^'0');
if(f) return x;
return -x;
}
const int N = 1000;
const int S = 1e4;
const int M = 11;
const int P = 998244353;
llong quickpow(llong x,llong y)
{
llong cur = x,ret = 1ll;
for(int i=0; y; i++)
{
if(y&(1ll<<i)) {y-=(1ll<<i); ret = ret*cur%P;}
cur = cur*cur%P;
}
return ret;
}
llong mulinv(llong x) {return quickpow(x,P-2);}
llong a[N+3];
llong dp[S+3][M+2];
int n,m,s;
int main()
{
scanf("%d%d%d",&n,&m,&s); m = n-m+1;
for(int i=1; i<=n; i++) scanf("%lld",&a[i]);
for(int i=1; i<=m; i++) dp[0][i] = P-1;
for(int i=1; i<=n; i++)
{
for(int j=s; j>=a[i]; j--)
{
for(int k=1; k<=m; k++)
{
dp[j][k] = (dp[j][k]+dp[j-a[i]][k-1]-dp[j-a[i]][k]+P)%P;
}
}
}
llong ans = 0ll;
for(int i=1; i<=s; i++)
{
ans = (ans+dp[i][m]*s%P*mulinv(i))%P;
}
printf("%lld\n",ans);
return 0;
}