题目描述
题解
至少相比一年以前想到了拆y^i,只不过没想到提y^n出来而已(确信)
op=0
块=点-边,hash
op=1
假设一棵红树的块数为j,则贡献为y^j*方案数
方案数直接用prufer算\(n^{m-2}\prod a_i\)会算重,会连上蓝树的边
套路:恰好=-1后的至少
问题是直接把(y-1+1)展开会发现顺序反了,有兴趣可以自己尝试
对于一个树S,设|S|为边数,则贡献为\(y^{n-|S|}\)
把y^n提出来,设z=1/y,变成每选一条边就乘上(z-1),此时把(z-1+1)^j二项式展开后就是对的了
具体来说就是当i<j时,算z^i的多选边的方案会把j算C(j,i)次
于是可以设f[i][j]表示当前到子树i,块大小为j的系数和,O(n^2)
瓶颈在于求\(\prod a_i\),等价于每个连通块里选恰好一个点
所以设\(f[i][0/1]\)表示当前连通块内是否有选点,随便转移即可O(n)
op=2
先提y^n,op=1的展开仍然可以用
枚举树S,贡献为\(\sum_{S}(z-1)^{|S|}*(方案)^2\)
边不好搞,设块m=n-|S|,则原式为\(\sum_S (z-1)^{n-m}*(n^{m-2}\prod a_i)^2\)
直接枚举块大小ai,那么原式变为\(\sum (z-1)^{n-m}*n^{2m-4}*\prod a_i^2*\prod a_i^{a_i-2}*C(n,...)/块数!\),后面的是块内的方案,除以块数!把块变为无序
提掉n和常数,可以设f[i]表示n个点放了i个的答案
\(f[i]=\sum f[i-j]*(z-1)^{-1}*n^2*j^j*\binom{i-1}{j-1}\),即枚举最后一个点所在块
这个用分治ntt可以做到log^2,也可以换一种做法:
设OGF\(F_i=(z-1)^{-1}*n^2*i^i\),\(G(x)\)是\(F(x)\)的EGF,则\([x^n]e^{G(x)}\)就是答案
因为\(e^{G(x)}=\sum \frac{G^i(x)}{i!}\),即枚举段数再除阶乘变为无序
常数略大
注意EGF和exp是两种不同的东西,并且EGF中的i!是形式,最后一定要乘回去变为OGF
code
exp调试方法:
内部的数组可以用static,不要用namespace
#include <bits/stdc++.h>
#define fo(a,b,c) for (a=b; a<=c; a++)
#define fd(a,b,c) for (a=b; a>=c; a--)
#define min(a,b) (a<b?a:b)
#define max(a,b) (a>b?a:b)
#define mod 998244353
#define Mod 998244351
#define G 3
#define ll long long
//#define file
using namespace std;
struct graph{
int a[200001][2],ls[100001],len;
void New(int x,int y) {++len;a[len][0]=y;a[len][1]=ls[x];ls[x]=len;}
} gr;
ll dp[100001][2],jc[262144],Jc[262144],w[262144],y,z,ans;
int n,op,i,j,k,l,Len,N,len;
ll qpower(ll a,int b) {ll ans=1; while (b) {if (b&1) ans=ans*a%mod;a=a*a%mod;b>>=1;} return ans;}
void swap(int &x,int &y) {int z=x;x=y;y=z;}
void work0()
{
static map<pair<int,int>,bool> hs;
static map<pair<int,int>,bool> :: iterator I;
fo(i,1,n-1)
{
scanf("%d%d",&j,&k);
if (j>k) swap(j,k);
hs[pair<int,int>(j,k)]=1;
}
fo(i,1,n-1)
{
scanf("%d%d",&j,&k);
if (j>k) swap(j,k);
I=hs.find(pair<int,int>(j,k));
if (I!=hs.end()) ++ans;
}
ans=qpower(y,n-ans);
}
void dfs(int Fa,int t)
{
ll x,y;
int i;
dp[t][0]=1;
for (i=gr.ls[t]; i; i=gr.a[i][1])
if (gr.a[i][0]!=Fa)
{
dfs(t,gr.a[i][0]);
x=(dp[t][0]*dp[gr.a[i][0]][1]%mod*n+dp[t][0]*dp[gr.a[i][0]][0]%mod*z)%mod;
y=(dp[t][1]*dp[gr.a[i][0]][1]%mod*n+(dp[t][1]*dp[gr.a[i][0]][0]+dp[t][0]*dp[gr.a[i][0]][1])%mod*z)%mod;
dp[t][0]=x,dp[t][1]=y;
}
dp[t][1]=(dp[t][1]+dp[t][0])%mod;
}
void work1()
{
fo(i,1,n-1) scanf("%d%d",&j,&k),gr.New(j,k),gr.New(k,j);
dfs(0,1);
ans=dp[1][1]*qpower(n,Mod)%mod*qpower(y,n)%mod;
}
void init()
{
jc[0]=jc[1]=Jc[0]=Jc[1]=w[1]=1;
fo(i,2,262143) w[i]=mod-w[mod%i]*(mod/i)%mod,jc[i]=jc[i-1]*i%mod,Jc[i]=Jc[i-1]*w[i]%mod;
}
void dft(ll *a,int tp,int N,int len)
{
static ll A[262144];
int i,j,k,l,S=N,s1=2,s2=1;
ll u,v,w,W;
fo(i,0,N-1)
{
j=i,k=0;
fo(l,1,len)
k=k*2+(j&1),j>>=1;
A[i]=a[k];
}
memcpy(a,A,N*8);
fo(i,1,len)
{
W=(tp==1)?qpower(G,(mod-1)/s1):qpower(G,(mod-1)-(mod-1)/s1);
S>>=1;
fo(j,0,S-1)
{
w=1;
fo(k,0,s2-1)
{
u=a[j*s1+k],v=a[j*s1+k+s2]*w;
a[j*s1+k]=(u+v)%mod;
a[j*s1+k+s2]=(u-v)%mod;
w=w*W%mod;
}
}
s1<<=1,s2<<=1;
}
}
void mul(ll *a,ll *b,ll *c,int N,int len)
{
static ll A[262144],B[262144];
int i,N2=qpower(N,Mod);
memset(A,0,N*8),memset(B,0,N*8);
fo(i,0,N-1) A[i]=a[i],B[i]=b[i];
dft(A,1,N,len),dft(B,1,N,len);
fo(i,0,N-1) A[i]=A[i]*B[i]%mod;
dft(A,-1,N,len);
fo(i,0,N-1) c[i]=A[i]*N2%mod;
}
void ny(ll *a,ll *b,int N,int len)
{
static ll A[262144],c[262144];
int i;
memset(b,0,N*8);
if (N==1) {b[0]=qpower(a[0],Mod);return;}
ny(a,b,N/2,len-1);
memset(c,0,N*8*2);
mul(b,b,c,N,len);
memset(A,0,N*8*2),memcpy(A,a,N*8);
mul(c,A,c,N*2,len+1);
fo(i,0,N-1) b[i]=(2*b[i]-c[i])%mod;
}
void dao(ll *a,ll *b,int N,int len)
{
int i;
fo(i,0,N-2) b[i]=a[i+1]*(i+1)%mod;b[N-1]=0;
}
void ji(ll *a,ll *b,int N,int len)
{
int i;
fd(i,N-1,1) b[i]=a[i-1]*w[i]%mod;b[0]=0;
}
void Ln(ll *a,ll *b,int N,int len)
{
static ll A[262144],B[262144];
int i;
memset(A,0,N*8*2),memset(B,0,N*8*2);
dao(a,A,N,len),ny(a,B,N,len);
mul(A,B,b,N*2,len+1);
ji(b,b,N,len);
}
void Exp(ll *a,ll *b,int N,int len)
{
static ll A[262144];
int i;
memset(b,0,N*8*2);
if (N==1) {b[0]=1;return;}
Exp(a,b,N/2,len-1);
memset(A,0,N*8*2);
Ln(b,A,N,len);
fo(i,N,N+N-1) A[i]=0;
fo(i,0,N-1) A[i]=(-A[i]+a[i])%mod;++A[0];
mul(A,b,b,N*2,len+1);
}
void work2()
{
static ll f[262144],F[262144];
fo(i,1,n)
f[i]=qpower(z,Mod)*n%mod*n%mod*qpower(i,i)%mod*Jc[i]%mod;len=ceil(log2(n+1)),N=qpower(2,len);
Exp(f,F,N,len);
ans=(F[n]*jc[n]%mod)*qpower(z,n)%mod*qpower(qpower(n,Mod),4)%mod*qpower(y,n)%mod; //EGF->OGF
}
int main()
{
#ifdef file
freopen("loj2983.in","r",stdin);
#endif
init();
scanf("%d%lld%d",&n,&y,&op),z=qpower(y,Mod)-1;
if (y==1)
{
switch (op)
{
case 0:{printf("%lld\n",1);break;}
case 1:{printf("%lld\n",qpower(n,n-2));break;}
case 2:{printf("%lld\n",qpower(n,2*(n-2)));break;}
}
return 0;
}
switch (op)
{
case 0:{work0();break;}
case 1:{work1();break;}
case 2:{work2();break;}
}
printf("%lld\n",(ans+mod)%mod);
fclose(stdin);
fclose(stdout);
return 0;
}