上面是pdf题解,下面自己说一下
子任务一:
对于两棵树的比较好求,对于每个联通块,求出每个点到其他点的距离之和
一棵树内部的贡献就是所有的点的\(dis_{sum}\)相加,然后除以2。
树与树之间的贡献:对于上面的每个联通块中的每个点求出的\(dis_{sum}\),取最大值设为,\(max_{dis}[1],max_{dis}[2]\).
然后就是\(max_{dis}[1]*siz[rt[2]]+max_{dis}[2]*siz[rt[1]]+siz[rt[1]]*siz[rt[2]]\)
对于三棵树的,需要考虑哪棵树在中间,假设三棵树编号1,2,3。2在中间,\(1\)和\(2\)通过\(x,y\)相连,2和3通过
\(u,v\)相连,那么树内部的贡献和上面一样.除此之外,
树1和\((x,y)\)这条边贡献 \((dis_{sum}[1,x]+siz[rt[1]])*(n-siz[rt[1]])\)
树3和\((u,v)\)这条边贡献 \((dis_{sum}[3,v]+siz[rt[3]])*(n-siz[rt[3]])\)
树2和\((y,u)\)这条边贡献 \(dis_{sum}[2,y]*siz[rt[1]]+dis[2,u]*siz[rt[3]]+d(y,u)*siz[rt[1]]*siz[rt[3]]\)
发现前两行可以通过两棵树的方法,取max得到答案,最后一行要再用个换根dp,\(dp_{down}[y]\)和\(dp_{up}[y]\)记录,对于固定的y,
它下面和上面的最优的 \(dis[2,u]*siz[rt[3]]+d(y,u)*siz[rt[1]]*siz[rt[3]]\),两遍\(dfs\),\(O(n)\)求出。最后求答案时候再加上一个\(dis_{sum}[2,y]*siz[rt[1]]\)
子任务二:
有一种很简单的方法,我们考虑从叶子节点着手,因为叶子节点限制比较少,要删只能删父边,删完后又会出现新的叶子
叶子是白色肯定得删父边,黑色就不删,这样\(dfs\),不断向上传递,最后如果根节点要删父边肯定不合法,否则合法
而且只有这一种方案,直接输出即可
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <vector>
#define rint register int
using namespace std;
const int maxn=1e5+5;
int n,m,blo_cnt;
int rt[4];
int siz[maxn],fa[maxn];
long long dis_sum[maxn];
long long max_dis[4];
long long in_dis[4];
char col[maxn];
bool vis[maxn],is_del[maxn],du[maxn];
vector < int > vec[maxn];
vector < int > vec_id[maxn];
vector < long long > pre[maxn];
vector < long long > suf[maxn];
void dfs1(int x,int prt,int dep){
vis[x]=1,siz[x]=1,fa[x]=prt;
for(rint i=0;i<(int)vec[x].size();++i){
const rint y=vec[x][i];
if(vis[y]) continue;
dfs1(y,x,dep+1);
siz[x]+=siz[y];
}
dis_sum[rt[blo_cnt]]+=dep;
}
void dfs2(int x,int id){
for(rint i=0;i<(int)vec[x].size();++i){
const rint y=vec[x][i];
if(y==fa[x]) continue;
dis_sum[y]=dis_sum[x]+siz[rt[id]]-2ll*siz[y];
dfs2(y,id);
}
max_dis[id]=max(max_dis[id],dis_sum[x]);
in_dis[id]+=dis_sum[x];
}
long long dp_down[maxn],dp_up[maxn],dp[maxn];
// siz1 * dis2,y + siz3 * dis2,u + d(y,u) * siz1 * siz3
// dp_down or up is to calc [ siz3 * dis2,u + d(y,u) * siz1 * siz3 ]
void dfs_down(int x,int A,int B){
dp_down[x]=dis_sum[x]*B;
long long res=0;
for(rint i=0;i<(int)vec[x].size();++i){
pre[x].push_back(res);
const rint y=vec[x][i];
if(y==fa[x]) continue;
dfs_down(y,A,B);
res=max(res,dp_down[y]+1ll*A*B);
}
dp_down[x]=max(dp_down[x],res);
res=0;
for(rint i=(int)vec[x].size()-1;i>=0;--i){
suf[x].push_back(res);
const rint y=vec[x][i];
if(y==fa[x]) continue;
res=max(res,dp_down[y]+1ll*A*B);
}
}
void dfs_up(int x,int A,int B,int id_pre,int id_suf){
dp_up[x]=dis_sum[x]*B;
if(fa[x]) dp_up[x]=max(dp_up[x],max(dp_up[fa[x]],max(pre[fa[x]][id_pre],suf[fa[x]][id_suf]))+1ll*A*B);
dp[x]=max(dp_down[x],dp_up[x])+dis_sum[x]*A;
for(rint i=0;i<(int)vec[x].size();++i){
const rint y=vec[x][i];
if(y==fa[x]) continue;
dfs_up(y,A,B,i,(int)vec[x].size()-i-1);
dp[x]=max(dp[x],dp[y]);
}
}
void solve1(){
long long ans=0;
if(blo_cnt==2) ans=max_dis[1]*siz[rt[2]]+max_dis[2]*siz[rt[1]]+1ll*siz[rt[1]]*siz[rt[2]];
else{
long long Const1=(max_dis[1]+siz[rt[1]])*(n-siz[rt[1]]);
long long Const2=(max_dis[2]+siz[rt[2]])*(n-siz[rt[2]]);
long long Const3=(max_dis[3]+siz[rt[3]])*(n-siz[rt[3]]);
dfs_down(rt[2],siz[rt[1]],siz[rt[3]]);
dfs_up(rt[2],siz[rt[1]],siz[rt[3]],0,0);
ans=max(ans,Const1+Const3+dp[rt[2]]);
dfs_down(rt[3],siz[rt[1]],siz[rt[2]]);
dfs_up(rt[3],siz[rt[1]],siz[rt[2]],0,0);
ans=max(ans,Const1+Const2+dp[rt[3]]);
dfs_down(rt[1],siz[rt[2]],siz[rt[3]]);
dfs_up(rt[1],siz[rt[2]],siz[rt[3]],0,0);
ans=max(ans,Const2+Const3+dp[rt[1]]);
}
for(rint i=1;i<=blo_cnt;++i) ans+=in_dis[i];
printf("%lld\n",ans);
}
// delete white leaves
// this is to judge if the edge to the father need to be cut
bool dfs_del_leaf(int x,int in_edge){
bool del_cc=0;
for(rint i=0;i<(int)vec[x].size();++i){
const rint y=vec[x][i];
if(y==fa[x]) continue;
del_cc^=dfs_del_leaf(y,vec_id[x][i]);
}
if(col[x]=='B'){
if(du[x]==del_cc) return is_del[in_edge]=1;
return 0;
}
else{
if(du[x]!=del_cc) return is_del[in_edge]=1;
return 0;
}
}
void solve2(){
for(int i=1;i<=blo_cnt;++i) if(dfs_del_leaf(rt[i],0)) return printf("-1\n"),void();
int cc=0;
for(rint i=1;i<=m;++i) if(!is_del[i]) ++cc;
printf("%d\n",cc);
for(rint i=1;i<=m;++i) if(!is_del[i]) printf("%d ",i);
printf("\n");
}
int main(){
freopen("lct.in","r",stdin);
freopen("lct.out","w",stdout);
scanf("%d%d%s",&n,&m,col+1);
for(int i=1,x,y;i<=m;++i){
scanf("%d%d",&x,&y);
vec[x].push_back(y),vec_id[x].push_back(i);
vec[y].push_back(x),vec_id[y].push_back(i);
du[x]^=1,du[y]^=1;
}
for(rint i=1;i<=n;++i){
if(!vis[i]){
rt[++blo_cnt]=i;
dfs1(i,0,0);
}
}
for(int i=1;i<=blo_cnt;++i){
dfs2(rt[i],i);
in_dis[i]/=2;
}
solve1(),solve2();
return 0;
}