感觉这道题纯动态规划的边界等问题非常麻烦,所以这里采用记忆化搜索。
题目大意
给出 \(n,m,k\) 及 \(val_0\cdots val_m\),定义一个值 \(\in [0,m]\) 的序列 \(a\),其权值为 \(\prod\limits_{i=1}^{n} val_{a_i}\)
我们称 \(S\) 满足条件当且仅当 \(S=\sum\limits_{i=1}^{n} 2^{a_i}\) 的二进制表示中,\(1\) 的个数小于等于 \(k\)。此时,也称序列 \(a\) 为合法序列。
求所有合法序列 \(a\) 的权值和 \(\mod 998244353\) 的结果。
题目分析
令 \(dfs(bit,now,x,y)\) 表示:
\(S\) 从低到高二进制的 \(bit\) 位中,用了序列 \(a\) 的前 \(now\) 个数,此时 \(S\) 二进制下有 \(x\) 个 \(1\),上一位(第 \(bit+1\) 位)进位为 \(y\)。
\(mem[biw][now][x][y]\) 则储存答案。
于是,我们有:
\[mem[bit][now][x][y]=\sum\limits_{i=0}^{n-now}{mem[bit][now+i][x+(y+i)\%2][\left\lfloor\frac{y+i}{2}\right\rfloor])\times sum[bit][i]\times C_{now+i}^{i}} \]其中 \(C_{i}^{j}\) 表示组合数,\(sum[i][j]\) 表示:
for(register int i=0;i<=m;i++)
{
sum[i][0]=1;
for(register int j=1;j<=n;j++)
{
sum[i][j]=sum[i][j-1]*val[i]%mod;
}
}
可以看到,\(sum[i][j]\) 主要作用类似于前缀和,目的是简化计算。
边界部分:
当前转移到 \(dfs(bit,now,x,y)\)。
- 若 \(now=n\):
当 \(x+getcnt(y)>k\) 时,返回 \(0\)。表示不需要继续转移了。
否则返回 \(1\)。
-
若 \(bit>m\) 则直接返回。
-
若 \(mem[bit][now][x][y]\) 有数则直接返回该数。
代码
//2021/11/30
//2021/12/1
//2021/12/2
#define _CRT_SECURE_NO_WARNINGS
#include <iostream>
#include <cstdio>
#include <climits>//need "INT_MAX","INT_MIN"
#include <cstring>
#define int long long
#define enter() putchar(10)
#define debug(c,que) cerr<<#c<<" = "<<c<<que
#define cek(c) puts(c)
#define blow(arr,st,ed,w) for(register int i=(st);i<=(ed);i++)cout<<arr[i]<<w;
#define speed_up() cin.tie(0),cout.tie(0)
#define endl "\n"
#define Input_Int(n,a) for(register int i=1;i<=n;i++)scanf("%d",a+i);
#define Input_Long(n,a) for(register long long i=1;i<=n;i++)scanf("%lld",a+i);
namespace Newstd
{
inline int read()
{
int x=0,k=1;
char ch=getchar();
while(ch<'0' || ch>'9')
{
if(ch=='-')
{
k=-1;
}
ch=getchar();
}
while(ch>='0' && ch<='9')
{
x=(x<<1)+(x<<3)+ch-'0';
ch=getchar();
}
return x*k;
}
inline void write(int x)
{
if(x<0)
{
putchar('-');
x=-x;
}
if(x>9)
{
write(x/10);
}
putchar(x%10+'0');
}
}
using namespace Newstd;
using namespace std;
const int mod=998244353;
const int MA_1=105;
const int MA_2=35;
int val[MA_1];
int C[MA_1][MA_1],sum[MA_1][MA_1];
int mem[MA_1][MA_2][MA_2][MA_2];
int n,m,k;
inline void init()
{
C[0][0]=1;
for(register int i=1;i<=n;i++)
{
C[i][0]=1;
for(register int j=1;j<=i;j++)
{
C[i][j]=(C[i-1][j]+C[i-1][j-1])%mod;
}
}
}
inline int lowbit(int x)
{
return x&-x;
}
inline int getcnt(int x)
{
int ans(0);
while(x!=0)
{
x-=lowbit(x);
ans++;
}
return ans;
}
//dfs(k,now,x,y)
//S从低到高二进制的 bit 位中,用了数列 a 的前 now 项,且此时 S *有 x 个二进制位为 1,第 now+1 位进了 y 过去
inline int dfs(int bit,int now,int x,int y)
{
if(now==n)
{
if(x+getcnt(y)>k)
{
return 0;
}
return 1;
}
if(bit>m)
{
return 0;
}
if(mem[bit][now][x][y]!=-1)
{
return mem[bit][now][x][y];
}
int ans(0);
for(register int i=0;i<=n-now;i++)
{
ans=(ans+dfs(bit+1,now+i,x+(y+i)%2,(y+i)/2)*sum[bit][i]%mod*C[now+i][i]%mod)%mod;
}
return mem[bit][now][x][y]=ans;
}
#undef int
int main(void)
{
#define int long long
memset(mem,-1,sizeof(mem));
n=read(),m=read(),k=read();
init();
for(register int i=0;i<=m;i++)
{
val[i]=read();
}
for(register int i=0;i<=m;i++)
{
sum[i][0]=1;
for(register int j=1;j<=n;j++)
{
sum[i][j]=sum[i][j-1]*val[i]%mod;
}
}
printf("%lld\n",dfs(0,0,0,0));
return 0;
}