城市旅行
题目链接:ybt金牌导航5-4-4 / luogu P4842
题目大意
给你一棵树,要你维护一些操作:
删除某条边(如果两点间不联通就不管)
添加某条边(如果两点间已联通就不管)
给某条路径上的点点权加一个值(如果两点不连通就不管)
询问某条路径上任选两个点,这两个点之间路径的权值和的期望。(如果两点不连通就输出 -1)
思路
看到加边删边找路径,自然想到 LCT。
然后我们考虑如何维护输出的值。
那容易想到,我们可以补考虑选的点,而是考虑一个点,有多少个路径会经过它。
那容易想到对于长度为
s
z
sz
sz 的路径,对于第
i
i
i 个点,有
i
×
(
s
z
−
i
+
1
)
i\times(sz-i+1)
i×(sz−i+1) 个路径经过了它。(左右各选一个)
那它的贡献就是
i
×
(
s
z
−
i
+
1
)
×
a
i
i\times(sz-i+1)\times a_i
i×(sz−i+1)×ai。
那总贡献就是:
∑
i
=
1
s
z
i
×
(
s
z
−
i
+
1
)
×
a
i
C
s
z
+
1
2
\dfrac{\sum\limits_{i=1}^{sz}i\times(sz-i+1)\times a_i}{C_{sz+1}^2}
Csz+12i=1∑szi×(sz−i+1)×ai
(下面就是
C
s
z
+
1
2
C_{sz+1}^2
Csz+12 因为选的两个点可以是同一个点)
那我们要维护的就是它了,分母很好搞,直接每次算就行,问题是分子。
那我们考虑 DP,已经求出了左右子树,要怎么搞到它。
我们设左子树的大小是
b
0
b_0
b0,然后序列是
b
1
,
b
2
,
.
.
.
,
b
b
0
b_1,b_2,...,b_{b_0}
b1,b2,...,bb0,右子树大小是
c
0
c_0
c0,序列是
c
1
,
c
2
,
.
.
.
,
c
c
0
c_1,c_2,...,c_{c_0}
c1,c2,...,cc0。
那对于左子树里面的第
i
i
i 个点,它在左子树里面的贡献就是
i
×
(
b
0
−
i
+
1
)
×
b
i
i\times(b_0-i+1)\times b_i
i×(b0−i+1)×bi,它在这里的贡献就是
i
×
(
b
0
+
c
0
+
1
−
i
+
1
)
×
b
i
i\times(b_0+c_0+1-i+1)\times b_i
i×(b0+c0+1−i+1)×bi。作差,就是
i
×
b
i
×
(
c
0
+
1
)
i\times b_i\times (c_0+1)
i×bi×(c0+1)。
那左子树的贡献就是它原本的贡献加上
(
c
0
+
1
)
×
∑
i
=
1
b
0
i
×
b
i
(c_0+1)\times\sum\limits_{i=1}^{b_0}i\times b_i
(c0+1)×i=1∑b0i×bi
那我们发现右边的部分(
∑
i
=
1
b
0
i
×
b
i
\sum\limits_{i=1}^{b_0}i\times b_i
i=1∑b0i×bi)我们也可以 DP,我是用
l
s
u
m
lsum
lsum 数组记录,这里就不讲了,不会的自己看代码。
那接着右子树用同样的方法:
原本:
i
×
(
c
0
−
i
+
1
)
×
c
i
i\times(c_0-i+1)\times c_i
i×(c0−i+1)×ci
现在:
(
b
0
+
1
+
i
)
×
(
b
0
+
c
0
+
1
−
(
b
0
+
1
+
i
)
+
1
)
×
c
i
(b_0+1+i)\times(b_0+c_0+1-(b_0+1+i)+1)\times c_i
(b0+1+i)×(b0+c0+1−(b0+1+i)+1)×ci
=
(
b
0
+
1
+
i
)
×
(
c
0
−
i
+
1
)
×
c
i
=(b_0+1+i)\times(c_0-i+1)\times c_i
=(b0+1+i)×(c0−i+1)×ci
差:
(
b
0
+
1
)
×
(
c
0
−
i
+
1
)
×
c
i
(b_0+1)\times(c_0-i+1)\times c_i
(b0+1)×(c0−i+1)×ci
那左子树的贡献就是它原本的贡献加上
(
b
0
+
1
)
×
∑
i
=
1
b
0
(
c
0
−
i
+
1
)
×
c
i
(b_0+1)\times\sum\limits_{i=1}^{b_0}(c_0-i+1)\times c_i
(b0+1)×i=1∑b0(c0−i+1)×ci
然后右边部分(
∑
i
=
1
b
0
(
c
0
−
i
+
1
)
×
c
i
\sum\limits_{i=1}^{b_0}(c_0-i+1)\times c_i
i=1∑b0(c0−i+1)×ci)继续 DP,我是用
r
s
u
m
rsum
rsum 数组记录。
接着就是新的点,那这个其实容易,就直接暴力算: a × ( b 0 + 1 ) × ( c 0 + 1 ) a\times(b_0+1)\times(c_0+1) a×(b0+1)×(c0+1)。(记得加一)
那查询我们就搞定了,接着,就是修改了。(加边删边就是普通 LCT,不用搞)
那我们也是懒标记,那每次要怎么改呢?
首先权值就普通的加,权值和就加上它乘大小。
接着是
l
s
u
m
,
r
s
u
m
lsum,rsum
lsum,rsum,容易看到你每个数每加
x
x
x,值就会多
x
+
2
x
+
3
x
+
4
x
+
.
.
.
x+2x+3x+4x+...
x+2x+3x+4x+...,那就是
x
×
(
1
+
s
z
)
×
s
z
/
2
x\times (1+sz)\times sz / 2
x×(1+sz)×sz/2
那接着就是
a
n
s
ans
ans,即期望的分子,那我们可以列出式子:
a
n
s
+
=
∑
i
=
1
s
z
i
×
(
s
z
−
i
+
1
)
×
d
ans+=\sum\limits_{i=1}^{sz}i\times(sz-i+1)\times d
ans+=i=1∑szi×(sz−i+1)×d
然后我们由化简可以得到
a
n
s
+
=
s
z
(
s
z
+
1
)
(
s
z
+
2
)
6
×
d
ans+=\dfrac{sz(sz+1)(sz+2)}{6}\times d
ans+=6sz(sz+1)(sz+2)×d
然后就可以搞啦!
化简过程
知道的可以不看。
要搞的东西:
∑
i
=
1
s
z
i
×
(
s
z
−
i
+
1
)
=
s
z
(
s
z
+
1
)
(
s
z
+
2
)
6
\sum\limits_{i=1}^{sz}i\times(sz-i+1)=\dfrac{sz(sz+1)(sz+2)}{6}
i=1∑szi×(sz−i+1)=6sz(sz+1)(sz+2)
首先考虑让其中一项固定:
∑
i
=
1
s
z
i
×
(
s
z
−
i
+
1
)
=
∑
i
=
1
s
z
i
×
s
z
−
∑
i
=
1
s
z
i
×
(
i
−
1
)
\sum\limits_{i=1}^{sz}i\times(sz-i+1)=\sum\limits_{i=1}^{sz}i\times sz-\sum\limits_{i=1}^{sz}i\times(i-1)
i=1∑szi×(sz−i+1)=i=1∑szi×sz−i=1∑szi×(i−1)
然后右边部分考虑去括号:
∑
i
=
1
s
z
i
×
s
z
−
∑
i
=
1
s
z
(
i
2
−
i
)
\sum\limits_{i=1}^{sz}i\times sz-\sum\limits_{i=1}^{sz}(i^2-i)
i=1∑szi×sz−i=1∑sz(i2−i)
分别拿出来:
s
z
×
∑
i
=
1
s
z
i
−
∑
i
=
1
s
z
i
2
+
∑
i
=
1
s
z
i
sz\times\sum\limits_{i=1}^{sz}i-\sum\limits_{i=1}^{sz}i^2+\sum\limits_{i=1}^{sz}i
sz×i=1∑szi−i=1∑szi2+i=1∑szi
然后都可以去掉
∑
\sum
∑:
s
z
×
(
s
z
+
1
)
×
s
z
2
−
s
z
(
s
z
+
1
)
(
2
×
s
z
+
1
)
6
+
(
s
z
+
1
)
×
s
z
2
sz\times\frac{(sz+1)\times sz}{2}-\frac{sz(sz+1)(2\times sz+1)}{6}+\frac{(sz+1)\times sz}{2}
sz×2(sz+1)×sz−6sz(sz+1)(2×sz+1)+2(sz+1)×sz
合并一下:
3
(
s
z
+
1
)
(
s
z
+
1
)
s
z
6
−
s
z
(
s
z
+
1
)
(
2
s
z
+
1
)
6
\frac{3(sz+1)(sz+1)sz}{6}-\frac{sz(sz+1)(2sz+1)}{6}
63(sz+1)(sz+1)sz−6sz(sz+1)(2sz+1)
(
3
s
z
+
3
)
(
s
z
+
1
)
s
z
6
−
(
2
s
z
+
1
)
(
s
z
+
1
)
s
z
6
\frac{(3sz+3)(sz+1)sz}{6}-\frac{(2sz+1)(sz+1)sz}{6}
6(3sz+3)(sz+1)sz−6(2sz+1)(sz+1)sz
(
s
z
+
2
)
(
s
z
+
1
)
s
z
6
\frac{(sz+2)(sz+1)sz}{6}
6(sz+2)(sz+1)sz
然后就好啦!
可能有人(指我自己)会不知道为什么
∑
i
=
1
s
z
i
2
=
s
z
(
s
z
+
1
)
(
2
s
z
+
1
)
6
\sum\limits_{i=1}^{sz}i^2=\dfrac{sz(sz+1)(2sz+1)}{6}
i=1∑szi2=6sz(sz+1)(2sz+1)
然后这里也讲讲,这个是用立方差来搞的。
x
3
−
(
x
−
1
)
3
=
x
3
−
(
x
3
−
3
x
2
+
3
x
−
1
)
=
3
x
2
−
3
x
+
1
x^3-(x-1)^3=x^3-(x^3-3x^2+3x-1)=3x^2-3x+1
x3−(x−1)3=x3−(x3−3x2+3x−1)=3x2−3x+1
然后根据这个,我们把
(
n
3
−
(
n
−
1
)
3
)
+
(
(
n
−
1
)
3
−
(
n
−
2
)
3
)
+
.
.
.
+
(
2
3
−
1
3
)
(n^3-(n-1)^3)+((n-1)^3-(n-2)^3)+...+(2^3-1^3)
(n3−(n−1)3)+((n−1)3−(n−2)3)+...+(23−13) 每个都转。
那互相消掉,就是
n
3
−
1
3
=
(
3
n
2
−
3
n
+
1
)
+
(
3
(
n
−
1
)
2
−
3
(
n
−
1
)
+
1
)
+
.
.
.
+
(
3
×
2
2
−
3
×
2
+
1
)
n^3-1^3=(3n^2-3n+1)+(3(n-1)^2-3(n-1)+1)+...+(3\times2^2-3\times2+1)
n3−13=(3n2−3n+1)+(3(n−1)2−3(n−1)+1)+...+(3×22−3×2+1)
拆开:
n
3
−
1
=
3
n
2
+
3
(
n
−
1
)
2
+
.
.
.
+
3
×
2
2
−
(
3
n
+
3
(
n
−
1
)
+
.
.
.
+
3
×
2
+
(
n
−
1
)
)
n^3-1=3n^2+3(n-1)^2+...+3\times2^2-(3n+3(n-1)+...+3\times2+(n-1))
n3−1=3n2+3(n−1)2+...+3×22−(3n+3(n−1)+...+3×2+(n−1))
然后继续搞:
n
3
−
1
=
3
(
n
2
+
(
n
−
1
)
2
+
.
.
.
+
2
2
)
−
3
(
n
+
(
n
−
1
)
+
.
.
.
+
2
)
+
(
n
−
1
)
n^3-1=3(n^2+(n-1)^2+...+2^2)-3(n+(n-1)+...+2)+(n-1)
n3−1=3(n2+(n−1)2+...+22)−3(n+(n−1)+...+2)+(n−1)
移项:
3
(
n
2
+
(
n
−
1
)
2
+
.
.
.
+
2
2
+
1
2
)
=
n
3
−
1
−
(
n
−
1
)
+
3
(
n
+
2
)
(
n
−
1
)
2
+
3
×
1
2
3(n^2+(n-1)^2+...+2^2+1^2)=n^3-1-(n-1)+\frac{3(n+2)(n-1)}{2}+3\times1^2
3(n2+(n−1)2+...+22+12)=n3−1−(n−1)+23(n+2)(n−1)+3×12
3
(
n
2
+
(
n
−
1
)
2
+
.
.
.
+
2
2
+
1
2
)
=
n
3
−
n
+
3
+
3
(
n
+
2
)
(
n
−
1
)
2
3(n^2+(n-1)^2+...+2^2+1^2)=n^3-n+3+\frac{3(n+2)(n-1)}{2}
3(n2+(n−1)2+...+22+12)=n3−n+3+23(n+2)(n−1)
(
n
2
+
(
n
−
1
)
2
+
.
.
.
+
2
2
+
1
2
)
=
2
n
3
−
2
n
+
6
+
3
(
n
+
2
)
(
n
−
1
)
6
(n^2+(n-1)^2+...+2^2+1^2)=\frac{2n^3-2n+6+3(n+2)(n-1)}{6}
(n2+(n−1)2+...+22+12)=62n3−2n+6+3(n+2)(n−1)
=
2
n
3
−
2
n
+
6
+
3
(
n
2
+
n
−
2
)
6
=
2
n
3
+
3
n
2
+
n
6
=
n
(
2
n
2
+
3
n
+
1
)
6
=
n
(
n
+
1
)
(
2
n
+
1
)
6
=\frac{2n^3-2n+6+3(n^2+n-2)}{6}=\frac{2n^3+3n^2+n}{6}=\frac{n(2n^2+3n+1)}{6}=\frac{n(n+1)(2n+1)}{6}
=62n3−2n+6+3(n2+n−2)=62n3+3n2+n=6n(2n2+3n+1)=6n(n+1)(2n+1)
然后就有了。
代码
#include<cstdio>
#include<algorithm>
#define ll long long
using namespace std;
int n, m, sz[50001], d;
int l[50001], r[50001], fa[50001];
ll ans[50001], val[50001], lz[50001];
ll lsum[50001], rsum[50001], sum[50001];
bool lzs[50001];
int op, x, y;
//LCT
bool nrt(int x) {
return l[fa[x]] == x || r[fa[x]] == x;
}
bool ls(int x) {
return l[fa[x]] == x;
}
void up(int x) {//把推公式推出来的放上去
sz[x] = sz[l[x]] + sz[r[x]] + 1;
sum[x] = sum[l[x]] + sum[r[x]] + val[x];
//DP 维护 lsum rsum
lsum[x] = lsum[l[x]] + val[x] * (sz[l[x]] + 1) + (lsum[r[x]] + sum[r[x]] * (sz[l[x]] + 1));
rsum[x] = rsum[r[x]] + val[x] * (sz[r[x]] + 1) + (rsum[l[x]] + sum[l[x]] * (sz[r[x]] + 1));
ans[x] = ans[l[x]] + ans[r[x]] + (sz[r[x]] + 1) * lsum[l[x]] + (sz[l[x]] + 1) * rsum[r[x]] + val[x] * (sz[l[x]] + 1) * (sz[r[x]] + 1);
}
void downa(int x, ll Val) {
val[x] += Val;
lz[x] += Val;
sum[x] += Val * sz[x];
lsum[x] += Val * (1 + sz[x]) * sz[x] / 2;
rsum[x] += Val * (sz[x] + 1) * sz[x] / 2;
ans[x] += Val * sz[x] * (sz[x] + 1) * (sz[x] + 2) / 6;
}
void downs(int x) {
swap(l[x], r[x]);
swap(lsum[x], rsum[x]);//记得这个也要 swap
lzs[x] ^= 1;
}
void down(int x) {
if (lzs[x]) {
if (l[x]) downs(l[x]);
if (r[x]) downs(r[x]);
lzs[x] = 0;
}
if (lz[x]) {
if (l[x]) downa(l[x], lz[x]);
if (r[x]) downa(r[x], lz[x]);
lz[x] = 0;
}
}
void down_line(int x) {
if (nrt(x)) down_line(fa[x]);
down(x);
}
void rotate(int x) {
int y = fa[x];
int z = fa[y];
int b = (ls(x) ? r[x] : l[x]);
if (z && nrt(y)) (ls(y) ? l[z] : r[z]) = x;
if (ls(x)) r[x] = y, l[y] = b;
else l[x] = y, r[y] = b;
fa[x] = z;
fa[y] = x;
if (b) fa[b] = y;
up(y);
}
void Splay(int x) {
down_line(x);
while (nrt(x)) {
if (nrt(fa[x])) {
if (ls(x) == ls(fa[x])) rotate(fa[x]);
else rotate(x);
}
rotate(x);
}
up(x);
}
void access(int x) {
int lst = 0;
for (; x; x = fa[x]) {
Splay(x);
r[x] = lst;
up(x);
lst = x;
}
}
void make_root(int x) {
access(x);
Splay(x);
downs(x);
}
int find_root(int x) {
access(x);
Splay(x);
while (l[x]) {
down(x);
x = l[x];
}
Splay(x);
return x;
}
int split(int x, int y) {
make_root(x);
if (find_root(y) != x) return -1;
access(y);
Splay(y);
return y;
}
void cut(int x, int y) {//连和断的时候都要判断连通
make_root(x);
if (find_root(y) != x) return ;
access(y);
Splay(y);
l[y] = 0;
fa[x] = 0;
}
void link(int x, int y) {
make_root(x);
if (find_root(y) != x)
fa[x] = y;
}
ll gcd(ll x, ll y) {
if (!y) return x;
return gcd(y, x % y);
}
void write(ll x, ll y) {
ll GCD = gcd(x, y);
x /= GCD; y /= GCD;
printf("%lld/%lld\n", x, y);
}
int main() {
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &val[i]), sz[i] = 1, sum[i] = lsum[i] = rsum[i] = ans[i] = val[i];
for (int i = 1; i < n; i++) {
scanf("%d %d", &x, &y);
link(x, y);
}
while (m--) {
scanf("%d %d %d", &op, &x, &y);
if (op == 1) {
cut(x, y);
continue;
}
if (op == 2) {
link(x, y);
continue;
}
if (op == 3) {
scanf("%d", &d);
if (find_root(x) != find_root(y)) continue;//记得操作前要判断是否连通
x = split(x, y);
downa(x, d);
continue;
}
if (op == 4) {
if (find_root(x) != find_root(y)) {printf("-1\n");continue;}
x = split(x, y);
write(ans[x], 1ll * sz[x] * (sz[x] + 1) / 2);
continue;
}
}
return 0;
}