\(\text{Problem}:\)[PKUWC2018] Slay the Spire
\(\text{Solution}:\)
最优方案显然为:从大往小取强化牌,直到已经打出 \(k-1\) 张牌或强化牌没有了为止,然后从大往小取攻击牌。证明:将一张强化牌换为攻击牌后,因为这张攻击牌不是严格最大,故不优于原来的方案。
首先将强化牌和攻击牌从大到小排序。设 \(f_{i,j,0/1}\) 表示前 \(i\) 张强化牌选了 \(j\) 张,位置 \(i\) 选或不选的所有方案的乘积和(\(0\) 表示不选,\(1\) 表示选,下同),\(g_{i,j,0/1}\) 表示前 \(i\) 张攻击牌选了 \(j\) 张,位置 \(i\) 选或不选的所有方案的和之和。
由于强化牌最多选 \(k-1\) 张,而抽出的强化牌数量可能大于 \(k-1\),故此时用 \(f\) 转移不是最优的。再设 \(h_{i,j}\) 表示前 \(i\) 张强化牌选了 \(j\) 张的所有方案的乘积和(\(j\geq k\))。有以下转移:
\[f_{i,j,0}=f_{i-1,j,0}+f_{i-1,j,1},j<k\\ f_{i,j,1}=(f_{i-1,j-1,0}+f_{i-1,j-1,1})\times w_{i},j<k\\ h_{i,j}=h_{i-1,j}+f_{i-1,k-1,1}\times \binom{n-i+1}{j-k+1},j\geq k\\ g_{i,j,0}=g_{i-1,j,0}+g_{i-1,j,1}\\ g_{i,j,1}=g_{i-1,j-1,0}+g_{i-1,j-1,1}+w_{i}\times\binom{i-1}{j-1} \]注意 \(h\) 的转移式,当 \(k=1\) 时需要特判。
解释一下为什么不对 \(g\) 记类似 \(h\) 对 \(f\) 的转移。由于我们优先选择强化牌,故当存在小于等于 \(k-1\) 张强化牌时必然全部选择强化牌,临界点只有 \(k-1\)。而对于攻击牌,它的临界点是随着强化牌的数量改变的,故无法求出类似 \(h\) 的统计数组。
然后枚举抽出的强化牌个数 \(i\),可以得到抽出的攻击牌个数 \(m-i\)。根据最优方案计算答案即可。令 \(n,m,k\) 同阶,则时间复杂度为 \(O(n^{2})\),可以通过。
\(\text{Code}:\)
#include <bits/stdc++.h>
#pragma GCC optimize(3)
//#define int long long
#define ri register
#define mk make_pair
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
#define is insert
#define es erase
#define vi vector<int>
#define vpi vector<pair<int,int>>
using namespace std; const int N=3010, Mod=998244353;
inline int read()
{
int s=0, w=1; ri char ch=getchar();
while(ch<'0'||ch>'9') { if(ch=='-') w=-1; ch=getchar(); }
while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch^48), ch=getchar();
return s*w;
}
int n,m,K,w1[N],w2[N],f[N][N][2],g[N][N][2],h[N][N],fac[N+5],inv[N+5];
inline int ksc(int x,int p) { int res=1; for(;p;p>>=1, x=1ll*x*x%Mod) if(p&1) res=1ll*res*x%Mod; return res; }
inline int C(int x,int y) { if(x<y||x<0||y<0) return 0; return 1ll*fac[x]*inv[x-y]%Mod*inv[y]%Mod; }
inline bool cp(int x,int y) { return x>y; }
signed main()
{
fac[0]=1;
for(ri int i=1;i<=N;i++) fac[i]=1ll*fac[i-1]*i%Mod;
inv[N]=ksc(fac[N],Mod-2);
for(ri int i=N;i;i--) inv[i-1]=1ll*inv[i]*i%Mod;
for(ri int T=read();T;T--)
{
n=read(), m=read(), K=read();
for(ri int i=1;i<=n;i++) w1[i]=read();
for(ri int i=1;i<=n;i++) w2[i]=read();
sort(w1+1,w1+1+n,cp);
sort(w2+1,w2+1+n,cp);
for(ri int i=0;i<=n;i++)for(ri int j=0;j<=n;j++)f[i][j][0]=f[i][j][1]=g[i][j][0]=g[i][j][1]=h[i][j]=0;
f[0][0][0]=1;
for(ri int i=1;i<=n;i++)
{
for(ri int j=0;j<=n;j++)
{
f[i][j][0]=(f[i-1][j][0]+f[i-1][j][1])%Mod;
if(j)
{
if(j<K) f[i][j][1]=1ll*(f[i-1][j-1][0]+f[i-1][j-1][1])*w1[i]%Mod;
else h[i][j]=(h[i-1][j]+1ll*f[i-1][K-1][1]*C(n-i+1,j-K+1)%Mod)%Mod;
}
g[i][j][0]=(g[i-1][j][0]+g[i-1][j][1])%Mod;
if(j) g[i][j][1]=(1ll*(g[i-1][j-1][0]+g[i-1][j-1][1])%Mod+1ll*C(i-1,j-1)*w2[i]%Mod)%Mod;
}
}
if(K==1)
{
for(ri int i=1;i<=n;i++) h[n][i]=C(n,i);
}
int ans=0;
for(ri int i=0;i<=m;i++)
{
int k=K-min(K-1,i);
if(k>m-i || i>n || m-i>n) continue;
int gg=0;
for(ri int j=k;j<=n;j++) gg=(gg+1ll*g[j][k][1]*C(n-j,m-i-k)%Mod)%Mod;
int ff=0;
if(i<K) ff=(f[n][i][0]+f[n][i][1])%Mod;
else ff=h[n][i];
ans=(ans+1ll*ff*gg%Mod)%Mod;
}
printf("%d\n",ans);
}
return 0;
}