【LCA】求和VII @北京OI2018

求和VII

PROBLEM

时间限制: 2 Sec 内存限制: 256 MB

题目描述

master对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的k次方和,而且每次的k可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。他把这个问题交给了pupil,但pupil并不会这么复杂的操作,你能帮他解决吗?

输入

第一行包含一个正整数n,表示树的节点数。

之后n−1行每行两个空格隔开的正整数i,j,表示树上的一条连接点i和点j的边。

之后一行一个正整数m,表示询问的数量。

之后每行三个空格隔开的正整数i,j,k,表示询问从点i到点j的路径上所有节点深度的k次方和。由于这个结果可能非常大,输出其对998244353取模的结果。

树的节点从1开始标号,其中1号节点为树的根。

输出

对于每组数据输出一行一个正整数表示取模后的结果。

样例输入

5

1 2

1 3

2 4

2 5

2

1 4 5

5 4 45

样例输出

33

503245989

提示

以下用d(i)表示第i个节点的深度。

对于样例中的树,有d(1)=0,d(2)=1,d(3)=1,d(4)=2,d(5)=2。

因此第一个询问答案为(25+15+05) mod 998244353=33,第二个询问答案为(245+145+245) mod 998244353=503245989。

对于30%的数据,1≤n,m≤100;

对于60%的数据,1≤n,m≤1000;

对于100%的数据,1≤n,m≤300000,1≤k≤50。

SOLUTION

预处理每个点到根节点的50个和(k<=50) sum[i][k]

对每个询问x,y,求la = lca(x,y)。

答案就是sum[x][k]+sum[y][k]-sum[la][k]-sum[anc[la][0]][k];

CODE

#define IN_PC() freopen("C:\\Users\\hz\\Desktop\\in.txt","r",stdin)
#define IN_LB() freopen("C:\\Users\\acm2018\\Desktop\\in.txt","r",stdin)
#define OUT_PC() freopen("C:\\Users\\hz\\Desktop\\out.txt","w",stdout)
#define OUT_LB() freopen("C:\\Users\\acm2018\\Desktop\\out.txt","w",stdout)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 3e5 + 5;
const int INF = 0x3f3f3f3f;
const ll MOD = 998244353; int anc[MAXN][20],deep[MAXN],t;
ll sum[MAXN][55]; struct edge {
int v,nex;
} ed[MAXN*2]; int head[MAXN],cnt; void addedge(int u,int v) {
cnt++;
ed[cnt].v = v;
ed[cnt].nex = head[u];
head[u] = cnt;
} queue<int> q; void bfs() {
q.push(1);
deep[1] = 0;
while(q.size()) {
int x = q.front();
q.pop();
for(int i=head[x]; i; i=ed[i].nex) {
int y = ed[i].v;
if(deep[y]||y==1)continue;
deep[y] = deep[x]+1;
ll base = deep[y];
for(int i=1;i<=50;i++){
sum[y][i] = (sum[x][i]+base)%MOD;
base=base*deep[y]%MOD;
}
anc[y][0] = x;
for(int j=1;j<=t;j++){
anc[y][j] = anc[anc[y][j-1]][j-1];
}
q.push(y);
}
}
} int lca(int x,int y) {
if(deep[x]<deep[y])swap(x,y);
for(int i=t; i>=0; i--) //to same deep;
if(deep[y]<=deep[anc[x][i]])
x = anc[x][i];
if(x==y)return x;
for(int i=t; i>=0; i--)
if(anc[x][i]!=anc[y][i]) {
x = anc[x][i];
y = anc[y][i];
}
return anc[x][0];
} int main() {
// IN_LB();
int n;
scanf("%d",&n);
t = (int)(log(n)/log(2))+1;
for(int i=0; i<n-1; i++) {
int u,v;
scanf("%d%d",&u,&v);
addedge(u,v);
addedge(v,u);
}
bfs();
int m;
scanf("%d",&m);
for(int i=0; i<m; i++) {
int x,y,k;
scanf("%d%d%d",&x,&y,&k);
int la = lca(x,y);
printf("%lld\n",(sum[x][k]+sum[y][k]-sum[la][k]-sum[anc[la][0]][k]+MOD+MOD)%MOD);
}
return 0;
}
上一篇:How do you build a database?


下一篇:Spring 从零開始-05