题意:给你m个数(m<=100),每个数的素因子仅来自于前t(t<=100)个素数,问这m个数的非空子集里,满足子集里的数的积为完全平方数的有多少个。
一开始就想进去里典型的dp世界观里,dp[n][mask]表示前n个数里为mask的有多少个,但显然这里t太大了。然后又YY了很多很多。像m少的时候应该用的是高消。即对每个因子列一个xor方程,然后高斯消元,其中*元的个数就是可以随便取的,所以答案是2^(*元个数),然后把空集的减掉,就是2^(*元)-1,不过大数是必须的。
#include <iostream>
#include <cstring>
#include <string>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std; #define ll long long
#define maxn 110 int t,m;
int b[maxn]; int p[1000+50];
int tot;
int vis[1000+50]; void getPrime()
{
memset(vis,0,sizeof(vis));
tot=0;
for(int i=2;i<=1000;++i){
if(!vis[i]) p[tot++]=i;
for(int j=0;j<tot&&i*p[j]<=1000;++j){
vis[i*p[j]]=true;
if(!(i%p[j])) break;
}
}
} int a[maxn][maxn]; int gauss()
{
int row=t,col=m;
int fix=0;
int cur=0;
int row_choose;
for(int i=0;i<col&&fix<row;++i){
row_choose=-1;
for(int j=cur;j<row;++j){
if(a[j][i]==1) row_choose=j;
}
if(row_choose==-1) {
continue;
}
++fix;
swap(a[row_choose],a[cur]);
for(int j=0;j<row;++j){
if(j==cur) continue;
if(a[j][i]==1) {
for(int k=i;k<col;++k){
a[j][k]^=a[cur][k];
}
}
}
++cur;
}
return col-fix;
} const int base=10000;
const int width=4;
const int N=100;
const int static ten[width]={1,10,100,1000};
struct bint
{
int ln;
int v[N];
bint(int r=0){
for(ln=0;r>0;r/=base) v[ln++]=r%base;
}
bint & operator = (const bint &r){
memcpy(this,&r,(r.ln+1)*sizeof(int));
return *this;
}
}; bint operator + (const bint &a,const bint &b){
bint res;int i,cy=0;
for(i=0;i<a.ln||i<b.ln||cy>0;i++){
if(i<a.ln) cy+=a.v[i];
if(i<b.ln) cy+=b.v[i];
res.v[i]=cy%base;cy/=base;
}
res.ln=i;
return res;
}
bint operator- (const bint & a, const bint & b){
bint res; int i, cy = 0;
for (res.ln = a.ln, i = 0; i < res.ln; i++) {
res.v[i] = a.v[i] - cy;
if (i < b.ln) res.v[i] -= b.v[i];
if (res.v[i] < 0) cy = 1, res.v[i] += base;
else cy = 0;
}
while (res.ln > 0 && res.v[res.ln - 1] == 0) res.ln--;
return res;
} bint operator* (const bint & a, const bint & b){
bint res; res.ln = 0;
if (0 == b.ln) { res.v[0] = 0; return res; }
long long i, j, cy;
for (i = 0; i < a.ln; i++) {
for (j = cy = 0; j < b.ln || cy > 0; j++, cy /= base) {
if (j < b.ln) cy += a.v[i] * b.v[j];
if (i + j < res.ln) cy += res.v[i + j];
if (i + j >= res.ln) res.v[res.ln++] = cy % base;
else res.v[i + j] = cy % base;
}
}
return res;
} void write(const bint & v){
int i;
printf("%d", v.ln == 0 ? 0 : v.v[v.ln - 1]);
for (i = v.ln - 2; i >= 0; i--)
printf("%04d", v.v[i]); // ! 4 == width
// printf("\n");
} int main()
{
getPrime();
while(~scanf("%d%d",&t,&m)){
memset(a,0,sizeof(a));
for(int i=0;i<m;++i){
scanf("%d",b+i);
for(int j=0;j<t;++j){
int cnt=0;
while(b[i]%p[j]==0){
b[i]/=p[j];cnt^=1;
}
a[j][i]=cnt;
}
}
int res=gauss();
bint x(1);
for(int i=0;i<res;++i){
x=x*2;
}
x=x-1;
write(x);puts("");
}
return 0;
}