题目
链接:https://ac.nowcoder.com/acm/contest/15644/B
来源:牛客网
Like the trees ܶT1T_1T1 and ܶT2T_2T2 above, gene trees are modeled as unrooted trees where each internal node (non-leaf node) has degree three. A path-length between two leaf nodes is the sum of the phylogenetic lengths of the edges along the unique path between them. In ܶT1T_1T1, the path-length between Human and Cow is 2 + 3 = 5 and the path-length between Human and Goldfish is 2 + 4 + 8 + 10 = 24. These lengths indicate that Human is much closer to Cow than to Goldfish genetically. From ܶT2T_2T2, we can guess that the primate closest to Human is Chimpanzee.
Researchers are interested in measuring the distance between genes in the tree. A famous distance measure is the sum of squared path-lengths of all unordered leaf pairs. More precisely, such a distance ݀d(ܶT) is defined as follows:
d(T)=∑unordered pair(u,v)pu,v2d(T)=\sum_{unordered\,pair(u,v)}p^2_{u,v}d(T)=∑unorderedpair(u,v)pu,v2
where pu,vp_{u,v}pu,v is a path-length between two leaf nodes u and v in ܶT. Note that ݀d(ܶT) is the sum of the squared path-lengths pu,v2p^2_{u,v}pu,v2 over all unordered leaf pairs u and v in ܶT. For the gene tree ܶT2T_2T2 in Figure B.1, there are six paths over all unordered leaf pairs, (Human, Chimpanzee), (Human, Gorilla), (Human, Orangutan), (Chimpanzee, Gorilla), (Chimpanzee, Orangutan), and (Gorilla, Orangutan). The sum of squared path-lengths is 22+42+52+42+52+52=1112^2 + 4^2 + 5^2 + 4^2 + 5^2 + 5^2 = 11122+42+52+42+52+52=111, so ݀d(ܶT2)d(ܶT_2)d(ܶT2) = 111.
Given an unrooted gene tree T, write a program to output ݀d(T).
输入描述:
Your program is to read from standard input. The input starts with a line containing an integer n (4 ≤ n ≤ 100,000), where n is the number of nodes of the input gene tree ܶT. Then ܶT has n − 1 edges. The nodes of ܶT are numbered from 1 to n. The following n − 1 lines represent n − 1 edges of ܶT, where each line contains three non-negative integers ܽa,b, and ݈l (1 ≤ ܽa ≠ ܾb ≤ n, 1 ≤ ݈l ≤ 50) where two nodes ܽa and ܾb form an edge with phylogenetic length ݈l.
输出描述:
Your program is to write to standard output. Print exactly one line. The line should contain one positive integer d(ܶT)示例1
输入
输出
示例2输入
输出
示例3输入
输出
题意
给你一个无根树,求任意两叶节点路径和的平方和。题解
正解好像是换根dp,但我因为比赛时昨天看了半小时点分治,一直以为是点分治,当时比赛时点分治学的不行,最后改完bug交完后tle,补题时才知道,点分治是每一个子树都找一次重心,才能达到nlogn的复杂度。
我不是dp选手所以不懂换根dp怎么搞,就讲讲点分治吧。
点分治,实际上是树上分治算法,它可以很好的处理树上路径问题。它把一颗树看成根节点与他的子树,同时它每一个子树也可以分成一个根节点和子树。以这个为分治的单位。
树上的所有路径按这种分法,实际上就两种情况:
1.路径经过根节点。
2.路径不经过根节点。
就考虑这两种情况,然后我们一步步分治下去,就可以找到所有答案。
第二种情况由分治来解决,我们就只要处理第一种情况。
两叶节点的路径长度可以表示为两个叶节点到根节点距离的和,所以我们只需要求。数组dis[x]表示节点x到根节点的距离,dfs一遍就可以求出所有的dis,这样我们利用dis就可以在O(1)的复杂度中求出任意两叶节点的长度。当然只有这个还是不够,这样两两匹配复杂度是O(n^2)是数据不能容忍的复杂度。但是我们很容易想到,我们能用组合数学的方法成组的找到答案,如有3个叶节点,a1,a2,a3,任意两叶节点路径和的平方和是,a1-a2,a1-a3,a2-a3,这3条路径的平方和,即(dis[a1]+dis[a2])^2+(dis[a1]+dis[a3])^2+(dis[a2]+dis[a3])^2,显然,化简该公式得到,
设dis[ai]=di
2*d1^2+2*d1*(d2+d3)+d2^2+d3^2 +(d2+d3)^2
我们发现先不考虑a2-a3的情况,就从a1出发到其他节点的值为
设n为叶节点个数,sum(i,j)为di到dj的和,ssum(i,j)为di到dj的平方和
(n-1)*d1^2+2*d1*sum(2,n)+ssum(2,n)
其他的路径,如a2-a3,也可以表示为去掉a1剩下的从a2开始的节点的路径的平方和
所以这个公式就可以推广为
(n-1)*d1^2+2*d1*sum(2,n)+ssum(2,n)+(n-2)*d2^2+2*d2*sum(3,n)+ssum(3,n)+...
然后sum和ssum可以使用前缀和维护,这样我们就可以在O(n)的复杂度中求出任意两点的平方和
上面我们讨论的都是子树只有单个叶节点的情况,如果子树有多个叶节点,那我们就会把同子树的叶节点也算上,但同子树的叶节点路径不通过根节点,所以我们需要改动下,最简单的方法就是单个单个计数,计数时不考虑同子树的,也容易实现,只要dfs求出bt[X],表是节点X在根节点的哪个子树,然后使用bt[X]来划分叶节点就可行,通过一些预处理,也能达到O(n)的复杂度。
但实际上有种更优的方法,
很容易发现同子树的连接的节点都是相同的,我们可以从这点优化,
设a1,a2,a3为同子树的叶节点,m为除去这3节点的剩下节点的个数,sum为剩下节点的和,ssum为剩下节点的平方和,则有
m*d1^2+2*d1*sum+ssum+m*d2^2+2*d2*sum+ssum+m*d3^2+2*d3*sum+ssum'
变形得
m*(d1^2+d2^2+d3^2)+2*(d1+d2+d3)*sum+3*ssum
推广得
设n为同子树叶节点个数,sum1为同子树叶节点和,ssum1为平方和,sum2为剩下节点和,ssum2为平方和
m*ssum1+2*sum1*sum2+n*sum2
这样就可以成块的处理节点,并且使用前缀和可以非常方便快速的维护
由于点分治每一次递归都会重新寻找一次重心,所以每一次分治都会减少一半的大小,所以最终的复杂度是O(nlogn)
代码
#include<iostream> #include<algorithm> #include<cmath> #include<cstdio> #include<queue> #include<cstring> #include<ctime> #include<string> #include<vector> #include<map> #include<list> #include<set> #include<stack> #include<bitset> using namespace std; typedef long long ll; typedef unsigned long long ull; typedef pair<ll, ll> pii; typedef pair<ll, ll> pll; const ll N = 1e5 + 5; const ll mod = 1e9 + 7; const double gold = (1 + sqrt(5)) / 2.0; const double PI = acos(-1); const double eps = 1e-7; const ll dx[] = { 0,1,0,-1 }; const ll dy[] = { 1,0,-1,0 }; ll gcd(ll a, ll b) { return b == 0 ? a : gcd(b, a%b); } ll pow(ll x, ll y, ll mod) { ll ans = 1; while (y) { if (y & 1)ans = (ans* x) % mod; x = (x*x) % mod; y >>= 1; }return ans; } ll pow(ll x, ll y) { ll ans = 1; while (y) { if (y & 1)ans = (ans* x) % mod; x = (x*x) % mod; y >>= 1; }return ans; } struct node { ll to, w; node() {} node(ll a, ll b) :to(a), w(b) {} }; vector<node> e[N]; ll Gsize[N]; ll n; ll Gans, root; ll vis[N]; //长链缩边 ll from, Pid; ll CDSPsum; ll CDSPcnt; ll CDSPnum; void cdsp(ll x, ll f) { if (e[x].size() == 2) { CDSPcnt++; vis[x] = 1; if (e[x][0].to != f) { CDSPsum += e[x][0].w; cdsp(e[x][0].to, x); } else { CDSPsum += e[x][1].w; cdsp(e[x][1].to, x); } } else { ll a = from, b = Pid, d = CDSPcnt; ll c = CDSPsum; for (ll i = 0; i < e[x].size(); i++) { ll y = e[x][i].to; if (y == f) { if (CDSPsum&&from&&from != f) { e[a][b].to = x; e[a][b].w = c; e[x][i].to = a; e[x][i].w = c; CDSPnum -= d; } continue; } from = x; Pid = i; CDSPsum = e[x][i].w; CDSPcnt = 0; cdsp(y, x); } } } //计数 ll tnum; void getnum(ll x) { tnum++; for (int i = 0; i < e[x].size(); i++) { ll y = e[x][i].to; if (vis[y])continue; vis[y] = 1; getnum(y); vis[y] = 0; } } //找重心 void Gdfs(ll x) { Gsize[x] = 1; ll mp = 0; for (ll i = 0; i < e[x].size(); i++) { ll y = e[x][i].to; if (vis[y])continue; vis[y] = 1; Gdfs(y); Gsize[x] += Gsize[y]; if (mp < Gsize[y]) mp = Gsize[y]; vis[y] = 0; } mp = max(mp, tnum - Gsize[x]); if (mp < Gans) { Gans = mp; root = x; } } ll dis[N]; ll bt[N]; ll leaf[N]; ll llen; void dfs(ll x) { if (e[x].size() == 1 && x != root) { leaf[++llen] = x; } for (ll i = 0; i < e[x].size(); i++) { ll y = e[x][i].to; if (vis[y])continue; if (x != root)bt[y] = bt[x]; vis[y] = 1; dis[y] = dis[x] + e[x][i].w; dfs(y); vis[y] = 0; } } ll ans; ll sf[N], ssf[N]; //点分治 ll L[N], R[N]; ll slen; void calc(ll x) { Gans = 1e9; tnum = 0; vis[x] = 1; getnum(x); Gdfs(x); vis[x] = 0; x = root; bt[x] = x; for (ll i = 0; i < e[x].size(); i++) { bt[e[x][i].to] = e[x][i].to; } llen = 0; //for(ll i=0;i<=n;i++){ // dis[i]=0; //} dis[x] = 0; vis[x] = 1; dfs(x); vis[x] = 0; sf[0] = ssf[0] = 0; for (ll i = 1; i <= llen; i++) { sf[i] = sf[i - 1] + dis[leaf[i]]; ssf[i] = ssf[i - 1] + dis[leaf[i]] * dis[leaf[i]]; } ll l = 1, r = 1, tip = 0; slen = 0; for (; r <= llen; r++) { if (tip == 0) { tip = bt[leaf[r]]; } if (tip != bt[leaf[r + 1]]) { L[slen] = l; R[slen++] = r; tip = 0; l = r + 1; } } if (tip) { L[slen] = l; R[slen++] = r-1; } for (ll i = 0; i < slen - 1; i++) { ll suma = sf[R[i]] - sf[L[i] - 1]; ll ssuma = ssf[R[i]] - ssf[L[i] - 1]; ll sumb = sf[R[slen - 1]] - sf[L[i + 1] - 1]; ll ssumb = ssf[R[slen - 1]] - ssf[L[i + 1] - 1]; ans += (R[slen - 1] - L[i + 1] + 1)*ssuma + (R[i] - L[i] + 1)*ssumb + 2 * suma*sumb; } vis[root] = 1; for (ll i = 0; i < e[x].size(); i++) { ll y = e[x][i].to; if (vis[y])continue; calc(y); } } inline ll read() { ll s = 0, w = 1; char ch = getchar(); while (ch<'0' || ch>'9') { if (ch == '-')w = -1; ch = getchar(); } while (ch >= '0'&&ch <= '9') s = s * 10 + ch - '0', ch = getchar(); return s * w; } int main() { scanf("%lld", &n); ll a, b, v; ll lf; for (ll i = 0; i < n - 1; i++) { a = read(); b = read(); v = read(); e[a].emplace_back(node(b, v)); e[b].emplace_back(node(a, v)); } for (ll i = 1; i <= n; i++) { if (e[i].size() == 1) { lf = i; break; } } //这缩边实际上速度影响不大,快了3ms。 CDSPnum = n; cdsp(lf, 0); if (CDSPnum == 2) { ans = e[lf][0].w*e[lf][0].w; } calc(lf); printf("%lld\n", ans); scanf(" "); return 0; }