Description
九条可怜是一个热爱阅读的女孩子。
这段时间,她看了一本非常有趣的小说,这本小说的架空世界引起了她的兴趣。
这个世界有 \(n\) 个城市,这 \(n\) 个城市被恰好 \(n − 1\) 条双向道路联通,即任意两个城市都可以互相到达。同时城市 \(1\) 坐落在世界的中心,占领了这个城市就称霸了这个世界。
在最开始,这 \(n\) 个城市都不在任何国家的控制之下,但是随着社会的发展,一些城市会崛起形成国家并夺取世界的霸权。为了方便,我们标记第 \(i\) 个城市崛起产生的国家为第 \(i\) 个国家。
在第 \(i\) 个城市崛起的过程中,第 \(i\) 个国家会取得城市 \(i\) 到城市 \(1\) 路径上所有城市的控制权。新的城市的崛起往往意味着战争与死亡,若第 \(i\) 个国家在崛起中,需要取得一个原本被国家 \(j(j \ne i)\) 控制的城市的控制权,那么国家 \(i\) 就必须向国家 \(j\) 宣战并进行战争。
现在,可怜知道了,在历史上,第 \(i\) 个城市一共崛起了 \(a_i\) 次。但是这些事件发生的相对顺序已经无从考究了,唯一的信息是,在一个城市崛起称霸世界之前,新的城市是不会崛起的。
战争对人民来说是灾难性的。可怜定义一次崛起的灾难度为崛起的过程中会和多少不同的国家进行战争(和同一个国家进行多次战争只会被计入一次)。可怜想要知道,在所有可能的崛起顺序中,灾难度之和最大是多少。
同时,在考古学家的努力下,越来越多的历史资料被发掘了出来,根据这些新的资料,可
怜会对 \(a_i\) 进行一些修正。具体来说,可怜会对 \(a_i\) 进行一些操作,每次会将 \(a_x\) 加上 \(w\)。她希望
在每次修改之后,都能计算得到最大的灾难度。
然而可怜对复杂的计算并不感兴趣,因此她想让你来帮她计算一下这些数值。
对题面的一些补充:
- 同一个城市多次崛起形成的国家是同一个国家,这意味着同一个城市连续崛起两次是不会和任何国家开战的:因为这些城市原来就在它的控制之下。
- 在历史的演变过程中,第 \(i\) 个国家可能会有一段时间没有任何城市的控制权。但是这并不意味着第 \(i\) 个国家灭亡了,在城市 \(i\) 崛起的时候,第 \(i\) 个国家仍然会取得 \(1\) 到 \(i\) 路径上的城市的控制权。
测试点 | \(n\) | \(m\) | 其他约定 |
---|---|---|---|
1 | \(\le 10\) | \(=0\) | \(a_i=1\) |
2-3 | \(\le 150000\) | \(\le 150000\) | 第 \(i\) 条道路连接 \(i\) 和 \(i + 1\) |
4-5 | \(\le 150000\) | \(=0\) | - |
6-8 | \(\le 150000\) | \(\le 150000\) | - |
9-10 | \(\le 4 \times 10^5\) | \(\le 4\times 10^5\) | - |
对于 \(100\%\) 的数据,\(1\le a_i,w_i\le 10^7,1\le x_i\le n\)。
Solution
真的很难啊。/kk
首先考虑没有修改的时候我们怎么做。我们可以考虑计算一个点会产生的贡献,可以想到的是,城市 \(i\) 对城市 \(j\) 产生的贡献仅在 \(\text{lca}(i,j)\) 处产生,所以我们枚举当前节点为 lca,那么每个子树就相当于一种颜色,子树内 access 操作总数就是该颜色个数,最多产生的贡献就是确定一种颜色序列相邻两个不同色个数,即:
\[\min\{t-1,2(t-h)\} \]其中 \(t\) 是颜色总数,\(h\) 是颜色个数最多的个数。其意义是,若出现次数最多的颜色个数超过一半,就把其他颜色往该颜色交错放,否则就可以达到上届 \(t-1\) 。
可以想到的是,我的 lca 不会影响子树放的方法,所以每个点都可以达到最优情况。
考虑增加了修改操作,我们可以使用 lct 进行维护。我们可以设 \(f_i\) 表示以 \(i\) 为根的子树内 access 操作总数,那么我们存在边 \((u,fa_u)\) 当且仅当 \(2f_u\ge f_{fa_u}+1\) 。可以想到这个时候 \(fa_u\) 的贡献一定是 \(2(f_{fa_u}-f_u)\) 。
那么考虑增值操作 \((u,w)\),可以发现只会影响 \(1\to u\) 上的点,而且链上实边仍会是实边,因为 \(2f_u\ge f_{fa_u}+1\rightarrow 2(f_u+w)\ge f_{fa_u}+w+1\) 。而这也满足 lct 所需性质,所以我们可以做到复杂度 \(\Theta(n\log n)\) 。
Code
#include <bits/stdc++.h>
using namespace std;
#define Int register int
#define int long long
#define MAXN 400005
template <typename T> inline void read (T &t){t = 0;char c = getchar();int f = 1;while (c < '0' || c > '9'){if (c == '-') f = -f;c = getchar();}while (c >= '0' && c <= '9'){t = (t << 3) + (t << 1) + c - '0';c = getchar();} t *= f;}
template <typename T,typename ... Args> inline void read (T &t,Args&... args){read (t);read (args...);}
template <typename T> inline void write (T x){if (x < 0){x = -x;putchar ('-');}if (x > 9) write (x / 10);putchar (x % 10 + '0');}
template <typename T> inline void chkmax (T &a,T b){a = max (a,b);}
template <typename T> inline void chkmin (T &a,T b){a = min (a,b);}
int n,m;
vector <int> g[MAXN];
#define ls(x) son[x][0]
#define rs(x) son[x][1]
int ans,fa[MAXN],val[MAXN],sum[MAXN],fak[MAXN],son[MAXN][2];
bool rnk (int x){return son[fa[x]][1] == x;}
bool isroot (int x){return son[fa[x]][rnk(x)] != x;}
void pushup (int x){sum[x] = sum[ls(x)] + sum[rs(x)] + val[x] + fak[x];}
void rotate (int x){
int y = fa[x],z = fa[y],k = rnk(x),w = son[x][!k];
if (!isroot (y)) son[z][rnk(y)] = x;son[x][!k] = y,son[y][k] = w;
if (w) fa[w] = y;fa[y] = x,fa[x] = z;
pushup (y),pushup (x);
}
void splay (int x){
while (!isroot (x)){
int y = fa[x];
if (!isroot (y)) rotate (rnk(x) == rnk(y) ? y : x);
rotate (x);
}
}
int calc (int u,int t,int h){return rs(u) ? 2 * (t - h) : (val[u] * 2 > t ? 2 * (t - val[u]) : t - 1);}
void access (int x,int w){
int t,h,tmp = x;
for (Int y = 0;x;x = fa[y = x]){
splay (x),t = sum[x] - sum[ls(x)],h = sum[rs(x)];
ans -= calc (x,t,h),(x == tmp) ? (val[x] += w) : (fak[x] += w),sum[x] += w,t += w;
if (h * 2 <= t) rs(x) = 0,fak[x] += h,h = (x == tmp ? h : 0);
if (sum[y] * 2 > t) rs(x) = y,fak[x] -= sum[y],h = sum[y];
ans += calc (x,t,h),pushup (x);
}
}
void init (int u,int par){
sum[u] = val[u];int p = 0,mx = val[u];
for (Int v : g[u]) if (v ^ par){
fa[v] = u,init (v,u),sum[u] += sum[v];
if (sum[v] > mx) mx = sum[v],p = v;
}
ans += min (sum[u] - 1,2 * (sum[u] - mx));
if (mx * 2 > sum[u]) rs(u) = p;
fak[u] = sum[u] - sum[rs(u)] - val[u];
}
signed main(){
read (n,m);
for (Int i = 1;i <= n;++ i) read (val[i]);
for (Int i = 2,u,v;i <= n;++ i) read (u,v),g[u].push_back (v),g[v].push_back (u);
init (1,0),write (ans),putchar ('\n');
while (m --> 0){
int u,w;read (u,w);
access (u,w),write (ans),putchar ('\n');
}
return 0;
}