这位巨佬的博客还是比我好多了
#include<bits/stdc++.h>
using namespace std;

// Reads n, m, a, b, then n values g_1..g_n and m values h_1..h_m, and prints (mod 998244353)
//     sum_i g_i * C(n+m-1-i, m-1) * a^m     * b^(n-i)
//   + sum_j h_j * C(n+m-1-j, n-1) * a^(m-j) * b^n .
// NOTE(review): presumably this is F(n,m) of F(i,j) = a*F(i,j-1) + b*F(i-1,j) expanded over
// lattice paths from the boundary values -- confirm against the problem statement.
// Assumes n, m >= 1 and n + m < N.
const int N=11000000,mod=998244353;

// jc[i] = i!, inc[i] = 1/i!, am[i] = a^i, bm[i] = b^i (all mod `mod`), filled for 0 <= i <= n+m.
long long a,b,n,m,inc[N],jc[N],am[N],bm[N],ans;

// Reduces x into [0, mod). A bare `x % mod` keeps the sign of x, so negative input would
// otherwise propagate into a negative printed answer.
long long norm(long long x)
{
    x%=mod;
    return x<0?x+mod:x;
}

// Returns x^y mod `mod` by binary exponentiation (x is expected to already be in [0, mod)).
long long quick_pow(long long x,long long y)
{
    long long res=1;
    while(y)
    {
        if(y&1) res=res*x%mod;
        y>>=1;
        x=x*x%mod;
    }
    return res;
}

// Returns C(n, m) mod `mod` from the factorial tables; requires 0 <= m <= n and the tables filled up to n.
long long C(int n,int m)
{
    return jc[n]*inc[m]%mod*inc[n-m]%mod;
}

int main()
{
    scanf("%lld%lld%lld%lld",&n,&m,&a,&b);
    a=norm(a);
    b=norm(b);

    am[0]=bm[0]=jc[0]=inc[0]=1;
    for(int i=1;i<=n+m;i++)
    {
        jc[i]=jc[i-1]*i%mod;
        am[i]=am[i-1]*a%mod;
        bm[i]=bm[i-1]*b%mod;
    }
    // mod is prime, so (n+m)!^(mod-2) is the inverse of (n+m)! (Fermat);
    // walking down, 1/i! = (i+1) * 1/(i+1)!.
    inc[n+m]=quick_pow(jc[n+m],mod-2);
    for(int i=n+m-1;i>=1;i--) inc[i]=inc[i+1]*(i+1)%mod;

    // Each boundary value is used exactly once, in input order, so it is folded into the
    // answer as soon as it is read instead of being kept in two more N-sized arrays.
    for(int i=1;i<=n;i++)
    {
        long long v;
        scanf("%lld",&v);
        ans=(ans+norm(v)*C(n+m-1-i,m-1)%mod*am[m]%mod*bm[n-i])%mod;
    }
    for(int i=1;i<=m;i++)
    {
        long long v;
        scanf("%lld",&v);
        ans=(ans+norm(v)*C(n+m-1-i,n-1)%mod*am[m-i]%mod*bm[n])%mod;
    }
    printf("%lld\n",ans);
    return 0;
}
#include<bits/stdc++.h>
using namespace std;

// Reads n, aa, bb and n edges "u v" on vertices 1..n. Picking vertex x costs
//     sam[x] = aa * (#edges listed with x first) + bb * (#edges listed with x second).
// Prints the minimum total cost of a vertex set that touches every edge.
// Every connected component is solved independently and the results are summed. Each
// component is assumed to contain at most one cycle (exactly one when the whole graph is
// connected, since it has n vertices and n edges).
const int N=1100000,M=N<<1;

// Adjacency list. Input edge u v is stored as the pair (k, k^1) with k even (cnt starts at 1),
// so k^1 is always the reverse copy of edge k.
struct bian
{
    int nxt,to;
}b[M];

int n,aa,bb,cnt=1,head[N];
int parEdge[N];     // id of the edge by which a vertex was reached in the latest traversal (0 = root)
int stk[N],ord[N];  // explicit traversal stack / visit order, so deep graphs cannot overflow the call stack
int seen[N],tick;   // seen[x]==tick <=> x already reached by the current treeDP traversal
bool vis[N];        // vertex already assigned to a processed component
long long sam[N];   // cost of picking vertex x
long long f[N][2];  // f[x][0] / f[x][1]: cheapest cover of x's subtree with x not picked / picked

inline int read()
{
    int s=0,w=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')w=-1;ch=getchar();}
    while(isdigit(ch)){s=s*10+ch-'0';ch=getchar();}
    return s*w;
}

// Appends the directed adjacency entry from -> to.
void add(int from,int to)
{
    b[++cnt].to=to;
    b[cnt].nxt=head[from];
    head[from]=cnt;
}

// Traverses the whole component containing `root`, marking vis[].
// Returns the id of an edge that closes a cycle, or 0 if the component is a tree.
// The edge we arrived by is skipped by *id*, not by parent vertex, so two parallel edges u-v
// (a 2-cycle) are still detected, as is a self-loop.
int findCycleEdge(int root)
{
    int top=0,found=0;
    stk[++top]=root;
    vis[root]=1;
    parEdge[root]=0;
    while(top)
    {
        int u=stk[top--];
        for(int i=head[u];i;i=b[i].nxt)
        {
            if((i^1)==parEdge[u]) continue;
            int v=b[i].to;
            if(vis[v])
            {
                if(!found) found=i;
                continue;
            }
            vis[v]=1;
            parEdge[v]=i;
            stk[++top]=v;
        }
    }
    return found;
}

// Runs the cover DP on root's component with both copies of edge `cut` deleted
// (cut==0 deletes nothing); afterwards f[root][0] / f[root][1] hold the results.
void treeDP(int root,int cut)
{
    int top=0,len=0;
    ++tick;
    stk[++top]=root;
    seen[root]=tick;
    parEdge[root]=0;
    while(top)
    {
        int u=stk[top--];
        ord[++len]=u;
        for(int i=head[u];i;i=b[i].nxt)
        {
            int v=b[i].to;
            if(i==cut||(i^1)==cut||seen[v]==tick) continue;
            seen[v]=tick;
            parEdge[v]=i;
            stk[++top]=v;
        }
    }
    // ord[] lists every parent before its children, so sweeping it backwards is a post-order.
    for(int k=len;k>=1;k--)
    {
        int u=ord[k];
        f[u][0]=0;
        f[u][1]=sam[u];
        for(int i=head[u];i;i=b[i].nxt)
        {
            int v=b[i].to;
            if(parEdge[v]!=i) continue;   // only tree edges u -> child
            f[u][0]+=f[v][1];             // u not picked: every child must be
            f[u][1]+=min(f[v][0],f[v][1]);
        }
    }
}

int main()
{
    n=read();aa=read();bb=read();
    for(int i=1,u,v;i<=n;i++)
    {
        u=read();v=read();
        add(u,v);
        add(v,u);
        sam[u]+=aa;
        sam[v]+=bb;
    }
    long long total=0;
    for(int r=1;r<=n;r++)
    {
        if(vis[r]) continue;
        int e=findCycleEdge(r);
        if(!e)
        {
            treeDP(r,0);
            total+=min(f[r][0],f[r][1]);
            continue;
        }
        // One endpoint of the cycle edge e must be picked: cut e and force each endpoint in turn.
        int c1=b[e^1].to,c2=b[e].to;
        treeDP(c1,e);
        long long best=f[c1][1];
        treeDP(c2,e);
        total+=min(best,f[c2][1]);
    }
    printf("%lld\n",total);
    return 0;
}
#include<bits/stdc++.h>
using namespace std;

// Reads n and m and prints  sum_{i=1..n} (-1)^k(i),  where k(i) = floor(sqrt(m / s(i))) and
// s(i) is the square-free part of i (i divided by its largest square divisor).
// k(i) equals the number of j in [1, m] for which i*j is a perfect square.
// Assumes 1 <= n < N.
inline long long read()
{
    long long s=0,w=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')w=-1;ch=getchar();}
    while(isdigit(ch)){s=s*10+ch-'0';ch=getchar();}
    return s*w;
}

const int N=11000000;

long long n,m,ans;
int p[N];   // p[k] = square-free part of k (always <= k, so int suffices)

// Exact floor(sqrt(x)), or 0 for x <= 0. A plain (int)sqrt(x) goes through double, which
// cannot represent every long long above 2^53, so it can be off by one near perfect squares.
long long isqrt(long long x)
{
    if(x<=0) return 0;
    long long r=(long long)sqrtl((long double)x);
    while(r>0&&r>x/r) r--;       // r*r > x
    while(r+1<=x/(r+1)) r++;     // (r+1)^2 <= x
    return r;
}

int main()
{
    // n and m are read exactly once (the original followed read() with a second scanf).
    n=read();
    m=read();
    // i runs upward, so the last write to p[k] comes from the largest i with i*i | k,
    // leaving the square-free part.
    for(int i=1;(long long)i*i<=n;i++)
        for(int j=1;(long long)j*i*i<=n;j++)
            p[i*i*j]=j;
    for(int i=1;i<=n;i++)
    {
        // Perfect squares have p[i]==1, so they need no special case.
        long long x=isqrt(m/p[i]);
        if(x&1) ans--;
        else ans++;
    }
    printf("%lld\n",ans);
    return 0;
}