本来以为这道题会非常难调,但是没想到调了不到 5 分钟就 A 了.
由于基于多项式的运算都可以方便地进行封装,所以细节就不是很多(或者说几乎没有细节)
题意:给定一棵树,每个点有点权,求对于所有大小为 $m$ 的独立集的点权之积的和.
数据范围:$n,m \leqslant 8 \times 10^4$.
先考虑一个十分显然的 $O(n^2)$ 暴力:
令 $f[x][i],g[x][i]$ 分别表示点 $x$ 选/不选的情况下独立集大小为 $i$ 的点积 之和.
考虑将 $x$ 与 $x$ 的一个儿子 $y$ 合并:$f[x][i+j]=f[x][i] \times f[y][j]$,$g$ 同理.
然后 $x$ 的初始值是:$f[x][1]=w[x],g[x][0]=1$.
树形DP 卡一下上界复杂度是 $O(n^2)$ 的.
不难发现,上述 $f[x][i+j] = f[x][i] \times f[y][j]$ 是一个卷积的形式.
如果是菊花图或者链的话可以直接用 NTT/分治NTT 来做.
正解的话考虑进行轻重路径剖分:
对于一条重链来说,先求出该重链中每个点轻儿子为根的多项式 $f,g$,然后对于重链中每个点都将其轻儿子与该点合并.
最后对于一条重链进行分治,求出该重链链顶为根的多项式.
分析一下时间复杂度:
考虑一条重链链顶为根的子树会被卷多少次:其祖先中每一条重链都会将其贡献一次.
那么树链剖分中一个点有 $O(\log n)$ 个祖先,而每次卷积的时候对链分治的复杂度是 $O(n \log^2 n)$.
总复杂度就是 $O(n \log^3 n)$,但是由于树链剖分的常数比较小,跑的并不慢.
code:
#include <queue> #include <cstdio> #include <vector> #include <cstring> #include <algorithm> #define N 1000009 #define ll long long #define mod 998244353 #define pb push_back #define setIO(s) freopen(s".in","r",stdin) using namespace std; int m; int A[N<<2],B[N<<2]; int tim,edges,n; int size[N],son[N],top[N],hd[N],to[N<<1],nex[N<<1],fa[N],dep[N]; int dfn[N],bu[N],si[N],val[N]; void add(int u,int v) { nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; } int ADD(int x,int y) { return (ll)(x+y)%mod; } int DEC(int x,int y) { return (ll)(x-y+mod)%mod; } int MUL(int x,int y) { return (ll)x*y%mod; } int qpow(int x,int y) { int tmp=1; for(;y;y>>=1,x=(ll)x*x%mod) { if(y&1) tmp=(ll)tmp*x%mod; } return tmp; } int get_inv(int x) { return qpow(x,mod-2); } void NTT(int *a,int len,int op) { for(int i=0,k=0;i<len;++i) { if(i>k) { swap(a[i],a[k]); } for(int j=len>>1;(k^=j)<j;j>>=1); } for(int l=1;l<len;l<<=1) { int wn=qpow(3,(mod-1)/(l<<1)); if(op==-1) wn=get_inv(wn); for(int i=0;i<len;i+=l<<1) { int w=1; for(int j=0;j<l;++j) { int x=a[i+j],y=(ll)w*a[i+j+l]%mod; a[i+j]=(ll)(x+y)%mod; a[i+j+l]=(ll)(x-y+mod)%mod; w=(ll)w*wn%mod; } } } if(op==-1) { int iv=get_inv(len); for(int i=0;i<len;++i) { a[i]=(ll)a[i]*iv%mod; } } } struct poly { int len; vector<int>a; poly() { len=0,a.clear(); } void push(int x) { a.pb(x),++len; } void resize(int x) { a.resize(x),len=x; } poly operator*(const poly &b) const { int lim; for(lim=1;lim<len+b.len-1;lim<<=1); for(int i=0;i<lim;++i) A[i]=B[i]=0; for(int i=0;i<len;++i) A[i]=a[i]; for(int i=0;i<b.len;++i) B[i]=b.a[i]; NTT(A,lim,1),NTT(B,lim,1); for(int i=0;i<lim;++i) { A[i]=(ll)A[i]*B[i]%mod; } NTT(A,lim,-1); poly c; for(int i=0;i<len+b.len-1;++i) { c.push(A[i]); } if(c.len>m+1) c.resize(m+1); return c; } poly operator+(const poly &b) const { poly c; c.resize(max(len,b.len)); for(int i=0;i<c.len;++i) c.a[i]=0; for(int i=0;i<c.len;++i) { if(i<len) c.a[i]=ADD(c.a[i],a[i]); if(i<b.len) c.a[i]=ADD(c.a[i],b.a[i]); } return c; } poly operator-(const poly &b) const { poly c; c.resize(max(len,b.len)); for(int i=0;i<c.len;++i) c.a[i]=0; for(int i=0;i<c.len;++i) { if(i<len) c.a[i]=ADD(c.a[i],a[i]); if(i<b.len) c.a[i]=DEC(c.a[i],b.a[i]); } return c; } }f0[N],f1[N],g[2][N]; struct data { poly f00,f01,f10,f11; data operator+(const data &b) const { data c; c.f00=(f01*b.f00)+(f00*(b.f00+b.f10)); c.f11=(f11*b.f01)+(f10*(b.f11+b.f01)); c.f01=(f01*b.f01)+(f00*(b.f01+b.f11)); c.f10=(f11*b.f00)+(f10*(b.f10+b.f00)); return c; } }tmp; void dfs1(int x,int ff) { fa[x]=ff,dep[x]=dep[ff]+1,size[x]=1; for(int i=hd[x];i;i=nex[i]) { int y=to[i]; if(y==ff) continue; dfs1(y,x); size[x]+=size[y]; if(size[y]>size[son[x]]) son[x]=y; } } void dfs2(int x,int tp) { top[x]=tp; dfn[x]=++tim; bu[tim]=x; ++si[tp]; if(son[x]) { dfs2(son[x],tp); } for(int i=hd[x];i;i=nex[i]) { if(to[i]!=fa[x]&&to[i]!=son[x]) { dfs2(to[i],to[i]); } } } poly calc(int l,int r,int d) { if(l==r) { return g[d][l]; } int mid=(l+r)>>1; return calc(l,mid,d)*calc(mid+1,r,d); } data solve(int l,int r) { if(l==r) { int u=bu[l]; data e; e.f00=f0[u]; e.f11=f1[u]; return e; } int mid=(l+r)>>1; return solve(l,mid)+solve(mid+1,r); } int main() { // setIO("input"); int x,y,z; scanf("%d%d",&n,&m); for(int i=1;i<=n;++i) scanf("%d",&val[i]); for(int i=1;i<n;++i) { scanf("%d%d",&x,&y); add(x,y),add(y,x); } dfs1(1,0),dfs2(1,1); for(int i=1;i<=n;++i) { f0[i].push(1); f1[i].push(0); f1[i].push(val[i]); } for(int i=n;i>=1;--i) { int p=bu[i]; if(top[p]==p) { for(int j=dfn[p];j<=dfn[p]+si[p]-1;++j) { x=bu[j]; int p0=0,p1=0; for(int e=hd[x];e;e=nex[e]) { y=to[e]; if(y==son[x]||y==fa[x]) continue; g[0][++p0]=f0[y]+f1[y]; g[1][++p1]=f0[y]; } if(p0) f0[x]=calc(1,p0,0); if(p1) f1[x]=f1[x]*calc(1,p1,1); } tmp=solve(dfn[p],dfn[p]+si[p]-1); f0[p]=tmp.f01+tmp.f00; f1[p]=tmp.f10+tmp.f11; } } f0[1].resize(m+1); f1[1].resize(m+1); printf("%d\n",(ll)(f0[1].a[m]+f1[1].a[m])%mod); return 0; }