\(n\) 个点被 \(n-1\) 条边连接成了一颗树,给出 \(a \to b\) 和 \(c\to d\) 两个区间,表示点的标号请你求出两个区间内各选一点之间的最大距离,即你需要求出 \(\max(dis(i,j) |a<=i<=b,c<=j<=d)\)
Solution
如果我们有 \(S_1,S_2\) 两个点集合,那么 \(S_1,S_2\) 各取一个点获得最大距离时,所选取的点一定在各自的直径端点中选取
于是我们可以对每个区间记录它的直径,然后根据上面的性质写一个 merge
这样就可以用线段树维护了
求距离用 ST 表 LCA,每次询问时间 \(O(1)\)
于是总体复杂度 \(O(n\log n)\),常数略大
注意要把和 \(null\) 相关的距离设为小于零的数
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 1000005;
int n,m,t1,t2,t3,t4;
struct edge{int v,w;};
vector <edge> g[N];
int s[N],ind,dfn[N],lg2[N],dis[N],dep[N],st[N][20];
struct node{
int p,q;
};
void dfs(int p) {
dfn[p]=++ind;
s[ind]=p;
for(edge e:g[p]) {
int q=e.v,w=e.w;
if(dfn[q]==0) {
dep[q]=dep[p]+1;
dis[q]=dis[p]+w;
dfs(q);
s[++ind]=p;
}
}
}
int lca(int p,int q) { //cout<<"begin"<<endl;
p=dfn[p]; q=dfn[q];
if(p>q) swap(p,q);
int l=lg2[q-p+1];
//cout<<p<<", "<<q<<", "<<l<<endl;
if(dep[s[st[p][l]]]<dep[s[st[q-(1<<l)+1][l]]])
return s[st[p][l]];
else return s[st[q-(1<<l)+1][l]];
}
int dist(int p,int q) {
if(p==0 || q==0) return -1; //!
int l=lca(p,q);
return dis[p]+dis[q]-2*dis[l];
}
node merge(node x,node y) {
node ret={0,0};
int ans=-1;
if(dist(x.p,x.q)>ans) ret=x,ans=dist(x.p,x.q);
if(dist(y.p,y.q)>ans) ret=y,ans=dist(y.p,y.q);
if(dist(x.p,y.p)>ans) ret={x.p,y.p},ans=dist(x.p,y.p);
if(dist(x.p,y.q)>ans) ret={x.p,y.q},ans=dist(x.p,y.q);
if(dist(x.q,y.p)>ans) ret={x.q,y.p},ans=dist(x.q,y.p);
if(dist(x.q,y.q)>ans) ret={x.q,y.q},ans=dist(x.q,y.q);
return ret;
}
node a[N*4];
void build(int p,int l,int r) {
if(l==r) {
a[p]={l,l};
}
else {
build(p*2,l,(l+r)/2);
build(p*2+1,(l+r)/2+1,r);
a[p]=merge(a[p*2],a[p*2+1]);
}
}
node query(int p,int l,int r,int ql,int qr) {
if(l>qr || r<ql) return {0,0};
if(l>=ql && r<=qr) return a[p];
return merge(query(p*2,l,(l+r)/2,ql,qr),
query(p*2+1,(l+r)/2+1,r,ql,qr));
}
signed main() {
scanf("%lld",&n);
for(int i=1;i<=2*n;i++) lg2[i]=log2(i);
for(int i=1;i<n;i++) {
scanf("%lld%lld%lld",&t1,&t2,&t3);
g[t1].push_back({t2,t3});
g[t2].push_back({t1,t3});
}
dfs(1);
for(int i=1;i<=2*n;i++) st[i][0]=i;
for(int j=1;j<=18;j++) {
for(int i=1;i<=ind;i++) {
if(dep[s[st[i][j-1]]]<dep[s[st[i+(1<<(j-1))][j-1]]])
st[i][j]=st[i][j-1];
else st[i][j]=st[i+(1<<(j-1))][j-1];
}
}
build(1,1,n);
scanf("%lld",&m);
for(int i=1;i<=m;i++) {
scanf("%lld%lld%lld%lld",&t1,&t2,&t3,&t4);
node p=query(1,1,n,t1,t2);
node q=query(1,1,n,t3,t4);
int ans=0;
ans=max(ans,dist(p.p,q.p));
ans=max(ans,dist(p.p,q.q));
ans=max(ans,dist(p.q,q.p));
ans=max(ans,dist(p.q,q.q));
printf("%lld\n",ans);
}
}