正题
题目链接:https://ac.nowcoder.com/acm/contest/11193/G
题目大意
给出\(n\)个长度为\(m\)的数组,然后你每次可以进行差分(不会改变数组长度那种)和前缀和。
如果两个数组可以在模\(998244353\)意义下操作成同一个数组,那么这两个同源,求所有的同源数组。
\(1\leq n\leq 100,1\leq m\leq 1000\)
解题思路
考虑将所有同源的操作成同一种形式。
注意到对于差分来说数组的第一个位置是保持不变的,更具体地说,其实是数组中从前往后第一个不是\(0\)的数字是不会变的。
设为\(a_k\),然后此时每次差分都会令\(a_{k+1}=a_{k+1}-a_k\)。首先对于两个数组来说肯定得有\(k=k'\)且\(a_{k}=a'_{k'}\),然后再考虑后面的。
为了方便比较我们之间让\(a_{k+1}\)一直差分直到其等于\(0\),此时我们就可以直接拿两个数组比较了。
快速处理\(k\)阶差分的做法就直接上\(NTT\)乘上一个\((1-x)^k\)(二项式展开)就好了
时间复杂度:\(O(nm\log m)\)
Hard Version要任意模加Lucas先润了
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
#define ull unsigned long long
using namespace std;
const ll N=4e3+10,P=998244353;
struct poly{
ll a[N];ll n;
}F,G;
ll n,m,cnt,a[N][N],r[N],col[N],num[N];
bool v[N];
ll power(ll x,ll b){
ll ans=1;
while(b){
if(b&1)ans=ans*x%P;
x=x*x%P;b>>=1;
}
return ans;
}
void NTT(ll *f,ll n,ll op){
for(ll i=0;i<n;i++)
if(i<r[i])swap(f[i],f[r[i]]);
for(ll p=2;p<=n;p<<=1){
ll len=p>>1,tmp=power(3,(P-1)/p);
if(op==-1)tmp=power(tmp,P-2);
for(ll k=0;k<n;k+=p){
ll buf=1;
for(ll i=k;i<k+len;i++){
ll tt=f[i+len]*buf%P;
f[i+len]=(f[i]-tt+P)%P;
f[i]=(f[i]+tt)%P;
buf=buf*tmp%P;
}
}
}
if(op==-1){
ll invn=power(n,P-2);
for(ll i=0;i<n;i++)
f[i]=f[i]*invn%P;
}
return;
}
void mul(poly &a,poly &b){
ll n=1;
while(n<=a.n+b.n)n<<=1;
for(ll i=0;i<n;i++)
r[i]=(r[i>>1]>>1)|((i&1)?(n>>1):0);
NTT(a.a,n,1);NTT(b.a,n,1);
for(ll i=0;i<n;i++)
a.a[i]=a.a[i]*b.a[i]%P;
NTT(a.a,n,-1);return;
}
void Diff(ll *a,ll n,ll k){
if(!k)return;
memset(F.a,0,sizeof(F.a));
memset(G.a,0,sizeof(G.a));
for(ll i=0;i<n;i++)F.a[i]=a[i];
for(ll i=0,ans=1;i<=min(k,n-1);i++){
if(i)ans=ans*(k-i+1)%P*power(i,P-2)%P;
G.a[i]=(i&1)?(P-ans):ans;
}
F.n=n;G.n=min(k+1,n);mul(F,G);
for(ll i=0;i<n;i++)
a[i]=F.a[i];
return;
}
signed main()
{
scanf("%lld%lld%lld",&n,&m,&a[0][0]);
for(ll i=1;i<=n;i++){
memset(v,0,sizeof(v));
for(ll j=0;j<m;j++)
scanf("%lld",&a[i][j]);
ll z;
for(z=1;z<m;z++)
if(a[i][z-1])break;
ll k=a[i][z]*power(a[i][z-1],P-2)%P;
Diff(a[i],m,k);
for(ll j=1;j<i;j++){
bool flag=0;
for(ll k=0;k<m;k++)
if(a[i][k]!=a[j][k]){flag=1;break;}
if(!flag){col[i]=col[j];num[col[i]]++;break;}
}
if(!col[i]){col[i]=++cnt;num[cnt]=1;}
}
printf("%lld\n",cnt);
for(ll i=1;i<=cnt;i++){
printf("%lld\n",num[i]);
for(ll j=1;j<=n;j++)
if(col[j]==i)printf("%lld ",j-1);
putchar('\n');
}
return 0;
}