https://ac.nowcoder.com/acm/contest/5278/G
题意很好理解。而且很容易发现树中同一深度的松鼠才会打架。
预处理出节点的深度和dfs序,然后枚举树的深度,同一深度的所有节点和根节点s去建虚树,每建好一次就从根节点s出发跑一次树型dp
dp的转移方程比较好想,设当前节点为x,si为x的子树,则dp[x] = ∑max(1,dp[si] + (depth[si] - depth[x]) ) (dp[si]需要>0)
depth[]表示节点在原树的深度,最终每棵虚树对答案的贡献是 max(1,dp[x] - 1 )(dp[x]>0)
复杂度O(nlogn)
1 #include<bits/stdc++.h> 2 typedef long long ll; 3 using namespace std; 4 const int maxbit = 20; 5 const int maxn = 2e5+5; 6 vector<int> G[maxn],vt[maxn],p[maxn];//vt为虚树 7 ll a[maxn],dp[maxn]; 8 int depth[maxn],fa[maxn][maxbit],Log[maxn],in[maxn];//in数组为dfs序 9 int n,cnt,s; 10 void add(int u,int v){G[u].push_back(v),G[v].push_back(u);} 11 bool cmp(int u,int v) {return in[u]<in[v];} 12 void pre(){ 13 Log[0] = -1; 14 Log[1] = 0,Log[2] = 1; 15 for(int i = 3;i<maxn;i++) Log[i] = Log[i/2] + 1; 16 } 17 void dfs(int cur,int father){//dfs预处理 18 in[cur] = ++cnt;//处理dfs序 19 depth[cur] = depth[father] + 1;//当前结点的深度为父亲结点+1 20 fa[cur][0] = father;//更新当前结点的父亲结点 21 for(int j = 1;(1<<j)<=n;j++){//倍增更新当前结点的祖先 22 fa[cur][j] = fa[fa[cur][j-1]][j-1]; 23 } 24 for(int i = 0;i<G[cur].size() ;i++){ 25 if(G[cur][i] != father) {//dfs遍历 26 dfs(G[cur][i],cur); 27 } 28 } 29 } 30 int LCA(int u,int v){ 31 if(depth[u]<depth[v]) swap(u,v); 32 int dist = depth[u] - depth[v];//深度差 33 while(depth[u]!=depth[v]){//把较深的结点u倍增到与v高度相等 34 u = fa[u][Log[depth[u]-depth[v]]]; 35 } 36 if(u == v) return u;//如果u倍增到v,说明v是u的LCA 37 for(int i = Log[depth[u]];i>=0;i--){//否则两者同时向上倍增 38 if(fa[u][i]!=fa[v][i]){//如果向上倍增的祖先不同,说明是可以继续倍增 39 u = fa[u][i];//替换两个结点 40 v = fa[v][i]; 41 } 42 } 43 return fa[u][0];//最终结果为u v向上一层就是LCA 44 } 45 46 void build (int indx){//传虚树和深度为indx的数组 47 stack<int> st; 48 st.push(s);//入栈s节点 49 vt[s].clear(); 50 int tmp ; 51 for(int i = 0;i<p[indx].size();i++){ 52 tmp = 0; 53 int cur = LCA(p[indx][i],st.top());//求出栈顶元素和p[i]的LCA 54 while(!st.empty() && LCA(cur,st.top()) != st.top()){//如果LCA和栈顶元素不一样,建栈中虚树的边 55 if(tmp) vt[st.top()].push_back(tmp); 56 tmp = st.top(); 57 st.pop(); 58 } 59 if(st.empty() || st.top()!=cur){ 60 st.push(cur);//如果栈为空或者栈顶元素不等于LCA,入栈 61 vt[cur].clear(); 62 } 63 if(tmp) vt[st.top()].push_back(tmp);//把最后的tmp节点加入链中 64 st.push(p[indx][i]);//当前p[i]入栈 65 vt[p[indx][i]].clear(); 66 } 67 tmp = 0; 68 while(!st.empty()){//加入最后一条链 69 if(tmp) vt[st.top()].push_back(tmp); 70 tmp = st.top(); 71 st.pop(); 72 } 73 } 74 void getdp(int cur){//传需要dp的子虚树 75 dp[cur] = 0; 76 if(vt[cur].size() == 0) { 77 dp[cur] = a[cur]; 78 return; 79 } 80 for(int i = 0;i<vt[cur].size();i++){ 81 int v = vt[cur][i]; 82 getdp(v); 83 if(dp[v]!=0){ 84 dp[cur]+=max((ll)1,dp[v]-(depth[v]-depth[cur])); 85 } 86 } 87 } 88 int main(){ 89 scanf("%d%d",&n,&s); 90 for(int i = 1;i<=n;i++){ 91 scanf("%lld",&a[i]); 92 } 93 for (int i = 0; i < n-1; ++i) 94 { 95 int u,v; 96 scanf("%d%d",&u,&v); 97 add(u,v); 98 /* code */ 99 } 100 pre(); 101 dfs(s,0); 102 ll ans = 0; 103 if(a[s]>1) ans+=(a[s]-1); 104 else ans+=a[s]; 105 for(int i = 1;i<=n;i++) p[depth[i]].push_back(i); 106 for(int i = 2;i<=n;i++){ 107 sort(p[i].begin(),p[i].end(),cmp);//深度一样的节点按dfs序排序 108 if(p[i].size() == 0) continue; 109 build(i);//同一深度的节点建虚树 110 getdp(s); 111 if(dp[s]>1) ans+=dp[s]-1; 112 else ans+=dp[s]; 113 vt[s].clear(); 114 for(int j = 0;j<p[i].size();j++) vt[p[i][j]].clear(); 115 } 116 printf("%lld",ans); 117 return 0; 118 }