Little W and Contest
思路:首先很显然是并查集去维护答案。
一开始,所有点都是独立的。那么设CF1 = 1的总个数。CF2 = 2的总个数
那么一开始ans = C(CF1,1)*C(CF2,2)+C(CF2,3).
那么考虑合并后怎么维护答案。
这里运用了容斥思想。
当我们合并后。我们只需要将之前合法的,现在不合法的答案删去即可。
那么对于不合法的答案:肯定是这两个连通块各选一个,然后和剩下的所有组合。
设合并的两个连通块为x,y。
那么不合法的答案为。
1.C(x2,1)*C(y2,1)*C(除x,y之外的1,1)
2.C(x2,1)*C(y2,1)*C(除x,y之外的2,1)
3.C(x2,1)*C(y1,1)*C(除x,y之外的2,1)
4.C(x1,1)*C(y2,1)*C(除x,y之外的2,1)
删完之后注意合并连通块里的数目
Code:
#include<bits/stdc++.h> using namespace std; typedef long long LL; typedef long double ld; typedef pair<int,int> pii; const int N = 1e5+5; const int M = 250005; const LL Mod = 1e9+7; #define pi acos(-1) #define INF 1e8 #define INM INT_MIN #define dbg(ax) cout << "now this num is " << ax << endl; inline int read() { int x = 0,f = 1;char c = getchar(); while(c < '0' || c > '9'){if(c == '-') f = -1;c = getchar();} while(c >= '0' && c <= '9'){x = (x<<1)+(x<<3)+(c^48);c = getchar();} return x*f; } LL a[N],x[N],y[N],f[N];//x - 2 ,y - 1. int fa[N]; void init() { f[0] = 1;for(int i = 1;i < N;++i) f[i] = f[i-1]*i%Mod; } int Find(int x) { return x == fa[x] ? x : fa[x] = Find(fa[x]); } LL quick_mi(LL a,LL b) { LL re = 1; while(b) { if(b&1) re = (re*a)%Mod; a = (a*a)%Mod; b >>= 1; } return re; } LL inv(LL n){return quick_mi(n,Mod-2)%Mod;} LL C(LL n,LL m) { return f[n]*inv(f[n-m])%Mod*inv(f[m])%Mod; } int main() { init(); int ca;ca = read(); while(ca--) { int n;n = read(); memset(x,0,sizeof(x));//2 memset(y,0,sizeof(y));//1 int cnt1 = 0,cnt2 = 0;//1,2 for(int i = 1;i <= n;++i) { a[i] = read(),fa[i] = i; if(a[i] == 1) y[i]++,cnt1++; else x[i]++,cnt2++; } LL ans = (C(cnt2,2)*C(cnt1,1)%Mod+C(cnt2,3))%Mod; printf("%lld\n",ans); for(int i = 1;i < n;++i) { int u,v;u = read(),v = read(); int xx = Find(u),yy = Find(v); int z1 = cnt1-y[xx]-y[yy]; int z2 = cnt2-x[xx]-x[yy]; LL ma1,ma2,ma3,ma4; if(x[xx] > 0 && x[yy] > 0)//2 2 1 { ma1 = C(x[xx],1)*C(x[yy],1)%Mod*z1%Mod; ans = ((ans-ma1)%Mod+Mod)%Mod; } if(x[xx] > 0 && y[yy] > 0)//2 1 2 { ma2 = C(x[xx],1)*C(y[yy],1)%Mod*z2%Mod; ans = ((ans-ma2)%Mod+Mod)%Mod; } if(x[xx] > 0 && x[yy] > 0)//2 2 2 { ma3 = C(x[xx],1)*C(x[yy],1)%Mod*z2%Mod; ans = ((ans-ma3)%Mod+Mod)%Mod; } if(y[xx] > 0 && x[yy] > 0)///1 2 2 { ma4 = C(y[xx],1)*C(x[yy],1)%Mod*z2%Mod; ans = ((ans-ma4)%Mod+Mod)%Mod; } fa[xx] = yy;//合并 x[yy] += x[xx]; y[yy] += y[xx]; printf("%lld\n",ans); } } system("pause"); return 0; }View Code