题意:
给一棵以1号点为根的有根树(根的深度为0),有Q次询问,每次输入x,y,k,输出树上x到y的路径上所有点的深度的k次方之和,答案对998244353取模。
思路:
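树上两点间路径的权值和很容易想到LCA,然后发现可以预处理深度的k次方的前缀和:sum[i][k] = 1^k + 2^k + … + i^k(代码中k预处理到50)。从x到lca的链上,各点的深度是连续的整数,所以这条链上深度的k次方和(不算lca点)是sum[d[x]][k] - sum[d[lca]][k];y到lca的链同理。最后把lca自身的d[lca]^k加回来,答案即为 sum[d[x]][k] + sum[d[y]][k] - 2*sum[d[lca]][k] + d[lca]^k(对998244353取模)。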
代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<map>
#include<queue>
#include<vector>
#include<string>
#include<bitset>
#include<fstream>
using namespace std;

// Inclusive counting loops: rep walks i = a, a+1, ..., n; per walks i = n, n-1, ..., a.
// There is deliberately no ';' after the for-header: a semicolon there becomes
// the (empty) loop body, so the statement written after the macro would run
// outside the loop, where i is no longer in scope. The arguments are
// parenthesised so that expressions can be passed as bounds.
#define rep(i, a, n) for(int i = (a); i <= (n); ++ i)
#define per(i, a, n) for(int i = (n); i >= (a); -- i)

typedef long long ll;
typedef pair<int,int> PII;

const int N = 1e6 + 105;            // number of slots in every per-node array of this file
const int mod = 998244353;          // modulus applied to all answers
const double Pi = acos(- 1.0);
const int INF = 0x3f3f3f3f;
// G and its inverse modulo `mod` (3 * Gi % mod == 1); neither is used in the code shown here.
const int G = 3, Gi = 332748118;
// Modular exponentiation by repeated squaring: returns a^b modulo `mod`.
//   a - base; reduced into [0, mod) up front so that every product below fits
//       in 64 bits and the result is never negative, whatever value is passed.
//   b - exponent; an exponent <= 0 yields 1 (in particular 0^0 == 1), so a
//       negative exponent can no longer keep the loop spinning.
ll qpow(ll a, ll b) {
    a %= mod;
    if (a < 0) a += mod;
    ll res = 1;
    while (b > 0) {
        if (b & 1) res = res * a % mod;   // this bit of the exponent is set
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}
// Greatest common divisor by the Euclidean algorithm, written as a loop.
// Returns a unchanged when b is 0.
ll gcd(ll a, ll b) {
    while (b) {
        const ll rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
// bool cmp(int a, int b){return a > b;}
//
// Number of tree nodes; read in main.
int n;
// Array-based adjacency lists. head[u] is the index of the most recently added
// arc leaving u (-1 when there is none), to[e] is the endpoint of arc e,
// nxt[e] is the next arc in the same list, and cnt is the next free arc slot.
int head[N];
int cnt = 0;
int to[N << 1];     // two arcs are stored per undirected edge
int nxt[N << 1];
// sum[i][j] = (1^j + 2^j + ... + i^j) % mod; main fills j = 1..50.
// NOTE(review): N * 60 64-bit values is roughly 480 MB of static storage.
ll sum[N][60];
// Inserts the undirected edge {u, v} as two arcs, each prepended to the
// adjacency list of its starting node.
void add(int u, int v){
    // arc u -> v
    to[cnt] = v;
    nxt[cnt] = head[u];
    head[u] = cnt;
    ++cnt;

    // arc v -> u
    to[cnt] = u;
    nxt[cnt] = head[v];
    head[v] = cnt;
    ++cnt;
}
//lca
// Highest binary-lifting level that bfs fills and Lca inspects; set in main.
int t;
// d[v]: depth of v (node 1 has depth 0); main resets it to -1 before bfs runs.
int d[N];
// Not referenced by the code shown in this file.
int dist[N];
// f[v][j]: the 2^j-th ancestor of v; entries never written stay 0.
int f[N][20];
// Work queue of the breadth-first traversal.
queue<int> q;
// Breadth-first traversal starting at node 1. For every node reached it
// stores the depth in d[] and the ancestors at distance 2^0 .. 2^t in f[].
// A node counts as unvisited while its d[] entry is -1, and adjacency lists
// end at index -1.
void bfs(){
    d[1] = 0;
    q.push(1);
    while(!q.empty()){
        const int cur = q.front();
        q.pop();
        for(int e = head[cur]; e != -1; e = nxt[e]){
            const int child = to[e];
            if(d[child] == -1){
                d[child] = d[cur] + 1;
                f[child][0] = cur;
                // cur was dequeued before child is discovered, so the rows of
                // all ancestors of child are already complete.
                for(int lvl = 1; lvl <= t; ++ lvl){
                    const int half = f[child][lvl - 1];
                    f[child][lvl] = f[half][lvl - 1];
                }
                q.push(child);
            }
        }
    }
}
// Lowest common ancestor of x and y, found by binary lifting over f[][] and d[].
// Table entries that bfs never wrote are 0, and main leaves d[0] at -1, so a
// jump that would go above the root fails the depth test and compares equal on
// both sides.
int Lca(int x,int y)
{
    // Make y the endpoint that is at least as deep as x.
    if(d[y] < d[x]) swap(x, y);
    // Lift y until both endpoints are at the same depth.
    for(int lvl = t; lvl >= 0; -- lvl){
        const int up = f[y][lvl];
        if(d[up] >= d[x]) y = up;
    }
    // Special case: x is an ancestor of (or equal to) the other endpoint.
    if(x == y) return x;
    // General case: lift both endpoints together while their ancestors differ;
    // afterwards they are two distinct children of the answer.
    for(int lvl = t; lvl >= 0; -- lvl){
        const int ax = f[x][lvl];
        const int ay = f[y][lvl];
        if(ax == ay) continue;
        x = ax;
        y = ay;
    }
    return f[x][0];
}
int main()
{
scanf("%d",&n);
cnt = 0;
memset(d, -1, sizeof(d));
for(int i = 0; i <= n; ++ i) head[i] = -1;
for(int i = 1; i <= n; ++ i){
ll tp = 1;
for(int j = 1; j <= 50; ++ j){
tp = tp * 1ll * i % mod;
sum[i][j] = (sum[i - 1][j] + tp) % mod;
}
}
for(int i = 1; i < n; ++ i){
int x, y; scanf("%d%d",&x,&y);
add(x, y);
}
t=(int)(log(n)/log(2))+1;
bfs();
int Q; scanf("%d",&Q);
while(Q --){
int x, y; ll k; scanf("%d%d%lld",&x,&y,&k);
int lca = Lca(x, y);
ll res = ((sum[d[x]][k] + sum[d[y]][k] - sum[d[lca]][k] * 2ll % mod + mod) % mod + qpow(d[lca], k)) % mod;
printf("%lld\n",res);
}
return 0;
}