Solution
首先考虑转化题意,以下用\(scc\)指代强连通分量
活动的每一步相同于:如果\(y,z\)在同一个\(scc\)中,\(x\)向\(y\)有连边,那么\(x\)就可以向\(z\)连边。
也就是说,对于一个\(scc\),如果\(x\)向\(scc\)中连了一条边,那么\(x\)就可以在活动后向\(scc\)中的任何一个点连边。
那么分别考虑每个\(scc\)的贡献就是内部的贡献\(siz(siz-1)+\)向\(scc\)连了边的点的贡献\(siz\cdot x\),其中\(x\)时向该\(scc\)连边的点的数量。
于是,对于每一个\(scc\)开一个\(set\) \(pt[i]\)维护内部的点,一个\(set\) \(in[i]\)维护向该\(scc\)连边的点。
对于加入的每一条边\((x,y)\):
\(1.\)\((x,y)\)在同一个\(scc\)中:直接跳过
\(2.\)如果\(y\)所在\(scc\)本来就向\(x\)所在\(scc\)连了边,那么需要合并这两个\(scc\)
\(3.\)否则,直接在\(in[y]\)中加入\(x\)
为了判断是否可以合并\(scc\),还需要额外维护两个\(set\):\(ins[s]\)和\(outs[s]\)表示向\(s\)连边的\(scc\)以及\(s\)连出去的\(scc\),合并时使用启发式合并,依次考虑每一个\(set\)即可
注意到合并后可能还会导致新的需要合并的\(scc\),于是开一个·队列先后合并即可。
Code
#include<bits/stdc++.h>
using namespace std;
const int N=3e5+10;
typedef long long ll;
int n,m,fa[N];
inline int find(int x){return fa[x]==x?x:fa[x]=find(fa[x]);}
set<int> pt[N];//点
set<int> ins[N];//连入的强连通分量
set<int> outs[N];//连出的强连通分量
set<int> in[N];//连入的点
inline ll calc(int x){
return 1ll*pt[x].size()*(pt[x].size()-1+in[x].size());
}
typedef pair<int,int> pii;
#define mp make_pair
#define it set<int>::iterator
ll ans=0,ret=0;
queue<pii>q;
inline void merge(int x,int y){
int fx=find(x),fy=find(y);
if(fx==fy) return ;
ans-=calc(fx)+calc(fy);
if(pt[fx].size()>pt[fy].size()) swap(fx,fy),swap(x,y);
fa[fx]=fy;
if(ins[fx].count(fy)) ins[fx].erase(fy),outs[fy].erase(fx);
if(outs[fx].count(fy)) outs[fx].erase(fy),ins[fy].erase(fx);
for(it i=pt[fx].begin();i!=pt[fx].end();i++){
int s=*i;
if(in[fy].count(s)) in[fy].erase(s);
pt[fy].insert(*i);
}
for(it i=ins[fx].begin();i!=ins[fx].end();i++){
int s=*i;
if(outs[fy].count(s)) q.push(mp(s,fy));
else ins[fy].insert(s);outs[s].erase(fx);outs[s].insert(fy);
}
for(it i=outs[fx].begin();i!=outs[fx].end();i++){
int s=*i;
if(ins[fy].count(s)) q.push(mp(s,fy));
else outs[fy].insert(s);ins[s].erase(fx);ins[s].insert(fy);
}
for(it i=in[fx].begin();i!=in[fx].end();i++)
if(!pt[fy].count(*i)) in[fy].insert(*i);
ans+=calc(fy);
}
inline void work(int x,int y){
q.push(mp(x,y));
while(!q.empty())
merge(q.front().first,q.front().second),q.pop();
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i) fa[i]=i,pt[i].insert(i);
for(int i=1;i<=n;++i) ans+=calc(i);
for(int i=1,a,b;i<=m;++i){
scanf("%d%d",&a,&b);
int fx=find(a),fy=find(b);
if(fx==fy){
printf("%lld\n",ans);
continue;
}
if(ins[fx].count(fy)) work(a,b);
else{
ans-=calc(fy);
ins[fy].insert(fx); outs[fx].insert(fy);
in[fy].insert(a);
ans+=calc(fy);
}
printf("%lld\n",ans);
}
return 0;
}