【题意】
给一个基环树森林,求每个基环树的最长链之和
【分析】
对于每个基环树,我们可以把环先当成一个点看待,然后作为树的根节点
这时,直径有两种情况:
1.完全在根的一个子树内,不经过根
2.跨过根,位于两个子树内
可以先dfs一波找到环,然后计算第一种情况,对于根的每个子树进行以此树形dp,取max就是第一种的答案
然后把每个点子树的结果记录到环对应的点上,这时我们就是要计算第二种情况
考虑我们可以取环上的一段+两个端点的子树的结果(刚刚计算出来的部分)
显然暴力取枚举是$O(n^2)$的,我们考虑断环为链,接上长度为2倍的部分,然后在上面利用单调队列计算
这个dfs找环的时候稍微注意一下就好了
#include<bits/stdc++.h> using namespace std; const int maxn=1e6+5; int n,tot,l,que[maxn],tim,id[maxn],vis[maxn],fa[maxn],mark[maxn],times,inin[maxn],dep[maxn],head[maxn]; long long b[maxn],maxdp,dp[maxn],a[maxn]; vector <int> cir[maxn]; struct edge { int to,nxt,v; }e[maxn<<1]; void add(int x,int y,int z) { e[++tot].nxt=head[x]; e[tot].to=y; head[x]=tot; e[tot].v=z; inin[y]++; } void dfs(int x,int deep,int laste,int from) { fa[x]=from; dep[x]=deep; for(int i=head[x];i;i=e[i].nxt) { int to=e[i].to; if(!dep[to]) dfs(to,deep+1,i,x); else if((laste!=i-1 && laste!=i+1) && dep[to]<dep[x]) { ++times; for(int now=x;now!=to;now=fa[now]) cir[times].push_back(now); cir[times].push_back(to); } } } void getw(int x,long long d,int laste) { vis[x]++; id[++tim]=x; b[tim]=d; for(int i=head[x];i;i=e[i].nxt) { int to=e[i].to; if(laste!=i-1&& laste!=i+1 && vis[to]<2 && mark[to]) getw(to,d+e[i].v,i); } } void DP(int x,int fa) { for(int i=head[x];i;i=e[i].nxt) { int to=e[i].to; if(to==fa || mark[to] ) continue; DP(to,x); maxdp=max(maxdp,dp[x]+dp[to]+e[i].v); dp[x]=max(dp[x],dp[to]+e[i].v); } } int q[maxn]; long long solve(int x) { long long res=0; for(int k=0;k<cir[x].size();k++) mark[cir[x][k]]=1; tim=0,getw(cir[x][0],0,0); for(int i=1;i<=cir[x].size();i++) { int x=id[i]; maxdp=0,DP(x,0),res=max(res,maxdp); a[i]=dp[x]; } int len=cir[x].size(); for(int i=1;i<=cir[x].size();i++) a[i+len]=a[i]; int l=1,r=1; q[l]=1; for (int i=2;i<=tim;i++){ while (l<=r&&i-q[l]>=len)l++; int j=q[l]; if (l<=r)res=max(res,a[i]+a[j]+b[i]-b[j]); while (l<=r&&a[i]-b[i]>a[q[r]]-b[q[r]])r--; q[++r]=i; } return res; } int main() { // freopen("a.in","r",stdin); // freopen("a.out","w",stdout); scanf("%d",&n); int x,y; for(int i=1;i<=n;i++) { scanf("%d%d",&x,&y); add(i,x,y); add(x,i,y); } for(int i=1;i<=n;i++) if(!dep[i]) dfs(i,1,0,0); long long ans=0; for(int i=1;i<=times;i++) ans+=solve(i); printf("%lld",ans); return 0; }