2019牛客暑期多校训练营(第五场)基础DP+组合数 矩阵快速幂

  1. G题:
    题目链接:https://ac.nowcoder.com/acm/contest/885/G
    大意: 给你两个由数字组成的字符串(S),(T) 长度为(1e3),问你S中有多少个子序列的值大于字符串T
  2. 思路:开dp[i][j]二维数组,i维护的是t串长度为i的前缀,j维护s串中长度为j的前缀。存储的值是s串的前j缀中大于t串前i缀且长度也为i的子序列的数目。记住,这里维护的是长度也为i的子序列数量。每次状态转移的时候就是要先算上s串中以第j位为结尾的且满足其他条件的子序列数量,再加上之前的(加上之前的就和前缀和一样),其中由于对应位大于小于等于的情况,得记录大于等于和大于两种情况的子序列数量。
  3. 代码有很多细节,看看注释:
#include<bits/stdc++.h>
#define ll long long
#define ms0(x) memset(x,0,sizeof(x))
#define ms-1(x) memset(x,-1,sizeof(x))
using namespace std;
ll mod= 998244353;
const int maxn = 3e3+5;
ll dp1[maxn][maxn]; //大于等于 
ll dp2[maxn][maxn]; //大于
char s[maxn],t[maxn];
int pos0[maxn];
void add(ll &a,ll b) //※※※※这样做应该快一点???  
{
    a+=b;
    if(a>=mod)a-=mod;
}
void sub(ll &a,ll b)
{
    a-=b;
    if(a<0)a+=mod;
}
ll C[maxn][maxn]; 
void predo()    {
    for (int i=C[0][0]=1;i<maxn;++i)
        for (int j=C[i][0]=1;j<=i;++j)
            C[i][j]=(C[i-1][j-1]+C[i-1][j])%mod;
}//预处理 存储所需的组合数的答案 
int main()
{
    predo();
    int T;
    cin>>T;
    while(T--)
    {
        int m,n;
        for(int j=0;j<=n;j++)
            for(int i=0;i<=m;i++)
                dp1[j][i]=dp2[j][i]=0;   //(+_+)?这里要注意 原来用memset居然会TLE。。。直接用for两层循环清空 
        scanf("%d%d",&m,&n);
        scanf("%s%s",s,t);
        int cnt=0;
        for(int j=0;j<m;j++)			//记录每个0在串s的位置 
        {
            if(s[j]=='0')
            {
                pos0[cnt++]=j+1;
            }
        }
        ll ans=0;
        for(int j=1;j<=m;j++)			//对dp[1]预处理,否者下面的转移方程过程中,dp全是0 
        {
            if(s[j-1]>t[0])
            {
                dp2[1][j]++;
                dp1[1][j]++;      
            } 
            if(s[j-1]==t[0])
            {
                dp1[1][j]++;
            }
            add(dp2[1][j],dp2[1][j-1]);
            add(dp1[1][j],dp1[1][j-1]);
        }
        for(int i=2;i<=n;i++)
        {
            for(int j=min(i,m);j<=m;j++)
            {
                if(s[j-1]>t[i-1])
                {
                    dp2[i][j]=dp1[i-1][j-1];
                    dp1[i][j]=dp1[i-1][j-1];      
                } 
                if(s[j-1]==t[i-1])
                {
                    dp2[i][j]=dp2[i-1][j-1];
                    dp1[i][j]=dp1[i-1][j-1];
                }
                if(s[j-1]<t[i-1])
                {
                    dp2[i][j]=dp2[i-1][j-1];
                    dp1[i][j]=dp2[i-1][j-1];
                }
                add(dp2[i][j],dp2[i][j-1]);
                add(dp1[i][j],dp1[i][j-1]);
            }
        }
        add(ans,dp2[n][m]);
        for(int j=n+1;j<=m;j++)			//最后要加组合数,前面只是加了s串中和t串长度相同的且大于t的子序列的数量 
        {
            add(ans,C[m][j]);
            for(int i=0;i<cnt;i++)
            {
                sub(ans,C[m-pos0[i]][j-1]);	//排除含前导0的子序列 
            }
        }
        printf("%lld\n",ans);
    }
    return 0;
      
}
上一篇:Codeforces Round #552 (Div. 3) F题


下一篇:luogu 同花顺