题目
点这里看题目。
分析
直接来做这个有趣的问题似乎显得太过棘手,不妨考虑一个较弱的问题:
\[\sum_{u=1}^n s_u \]假如当前根确定为 \(r\) ,那么就有:
\[\sum_{u=1}^ns_u=\sum_{u=1}^n(\operatorname{dist}(u,r)+1)\times a_u=\sum_{u=1}^n\operatorname{dist}(u,r)\times a_u+\sum_{u=1}^na_u \]后一个东西可以很好地维护,前一个东西可以用点分树或者其他乱七八糟的数据结构来维护。
以下即设 \(S=\sum a_u\) 。
下面介绍一个极其优雅的构造:
考虑 \(\sum_{u=1}^ns_u^2\) 会出现在哪里?比如在下面这个构造里面:
\[W=\sum_{u=1}^ns_u(S-s_u)=S\times \sum_{u=1}^n s_u-\sum_{u=1}^ns_u^2 \]为什么偏要选它?因为 \(\sum_{u=1}^ns_u(S-s_u)\) 有一个很漂亮的性质:它的值不随 \(r\) 的改变而改变!
不难说明下列等式成立:
\[\sum_{u=1}^n\sum_{v=1}^na_ua_v\times \operatorname{dist}(u,v)=\sum_{u=1}^ns_u(S-s_u) \]
现在我们只需要解决 \(\Delta W\) 即可,可以发现:
\[\Delta W=\Delta a_x\times \left(\sum_{u=1}^n a_u\times \operatorname{dist}(u,x)\right) \]最终后面的内容也转化为了前面的问题,那么这道题就解决了。时间为 \(O(n\log_2n)\) 。
小结:
-
这里的构造简直太漂亮了!
我甚至想不出来该怎么说它总之这也说明,在变化过程中的许多不变量是值得关注的(这样的内容其实也值得积累)。
代码
#include <cstdio>
#define rep( i, a, b ) for( int i = (a) ; i <= (b) ; i ++ )
#define per( i, a, b ) for( int i = (a) ; i >= (b) ; i -- )
typedef long long LL;
const int MAXN = 2e5 + 5, MAXLOG = 18;
template<typename _T>
void read( _T &x )
{
x = 0; char s = getchar(); int f = 1;
while( s < '0' || '9' < s ) { f = 1; if( s == '-' ) f = -1; s = getchar(); }
while( '0' <= s && s <= '9' ) { x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar(); }
x *= f;
}
template<typename _T>
void write( _T x )
{
if( x < 0 ) putchar( '-' ), x = - x;
if( 9 < x ) write( x / 10 );
putchar( x % 10 + '0' );
}
template<typename _T>
_T MAX( const _T a, const _T b )
{
return a > b ? a : b;
}
struct Edge
{
int to, nxt;
}Graph[MAXN << 1];
LL su[MAXN], wSu[MAXN], faWSu[MAXN];
int fath[MAXN][MAXLOG], dist[MAXN][MAXLOG];
int inter[MAXN], dep[MAXN];
int head[MAXN], A[MAXN], siz[MAXN];
int N, Q, cnt; LL W = 0, S = 0;
bool vis[MAXN];
void AddEdge( const int from, const int to )
{
Graph[++ cnt].to = to, Graph[cnt].nxt = head[from];
head[from] = cnt;
}
int GetCen( const int u, const int fa, const int all )
{
int mx = 0, ret = 0; siz[u] = 1;
for( int i = head[u], v ; i ; i = Graph[i].nxt )
if( ! vis[v = Graph[i].to] && v ^ fa )
{
ret |= GetCen( v, u, all );
siz[u] += siz[v], mx = MAX( mx, siz[v] );
}
if( ( MAX( mx, all - siz[u] ) << 1 ) <= all ) ret = u;
return ret;
}
void DFS( const int u, const int fa, const int fr, const int d, const int dis )
{
dist[u][d] = dis, fath[u][d] = fr;
su[fr] += A[u], wSu[fr] += 1ll * dis * A[u];
for( int i = head[u], v ; i ; i = Graph[i].nxt )
if( ! vis[v = Graph[i].to] && v ^ fa )
DFS( v, u, fr, d, dis + 1 );
}
void Divide( const int u, const int all, const int d )
{
fath[u][d] = u;
dep[u] = d, su[u] = A[u], vis[u] = true; LL tmp;
for( int i = head[u], v, t, nxt ; i ; i = Graph[i].nxt )
if( ! vis[v = Graph[i].to] )
{
tmp = wSu[u], DFS( v, u, u, d, 1 );
t = siz[v] > siz[u] ? all - siz[u] : siz[v];
inter[nxt = GetCen( v, u, t )] = v, faWSu[nxt] = wSu[u] - tmp;
Divide( nxt, t, d + 1 );
}
}
void InitW( const int u, const int fa )
{
su[u] = A[u];
for( int i = head[u], v ; i ; i = Graph[i].nxt )
if( ( v = Graph[i].to ) ^ fa )
InitW( v, u ), su[u] += su[v];
W += su[u] * ( S - su[u] );
}
void Update( const int u, const int delt )
{
per( i, dep[u], 0 )
{
su[fath[u][i]] += delt, wSu[fath[u][i]] += 1ll * delt * dist[u][i];
if( i ^ dep[u] ) faWSu[fath[u][i + 1]] += 1ll * delt * dist[u][i];
}
}
LL Calc( const int u )
{
LL ret = wSu[u];
for( int i = dep[u] - 1, v, w ; ~ i ; i -- )
{
v = fath[u][i], w = fath[u][i + 1];
ret += wSu[v] - faWSu[w] + ( su[v] - su[w] ) * dist[u][i];
}
return ret;
}
int main()
{
read( N ), read( Q );
rep( i, 1, N - 1 ) { int a, b;
read( a ), read( b );
AddEdge( a, b ), AddEdge( b, a );
}
rep( i, 1, N ) read( A[i] ), S += A[i];
InitW( 1, 0 );
rep( i, 1, N ) su[i] = 0;
Divide( GetCen( 1, 0, N ), N, 0 );
for( int opt, x, y ; Q -- ; )
{
read( opt ), read( x );
if( opt == 1 )
{
read( y );
int delt = y - A[x];
W += Calc( x ) * delt;
Update( x, delt );
S += delt, A[x] = y;
}
if( opt == 2 ) write( S * ( Calc( x ) + S ) - W ), putchar( '\n' );
}
return 0;
}