树上的最远点对
【来源】
【题目描述】
n个点被n-1条边连接成了一颗树,给出ab和cd两个区间,表示点的标号请你求出两个区间内各选一点之间的最大距离,即你需要求出max{dis(i,j) |a<=i<=b,c<=j<=d}
(PS 建议使用读入优化)
【输入格式】
第一行一个数字 n n<=100000。 第二行到第n行每行三个数字描述路的情况,x,y,z (1<=x,y<=n,1<=z<=10000)表示x和y之间有一条长度为z的路。 第n+1行一个数字m,表示询问次数 m<=100000。 接下来m行,每行四个数a,b,c,d。
【输出格式】
共m行,表示每次询问的最远距离
【样例输入】
5
1 2 1
2 3 2
1 4 3
4 5 4
1
2 3 4 5
【样例输出】
10
【解析】
线段树+LCA
给你两个区间,问各从一个区间选择一个点,两个点之间的最长路是多少,这里需要注意就是如果第一个区间是a和b最远,第二个区间是c和d最远,那么答案一定是ab,cd,ac,ad,bc,bd,其中一个,于是我们只要用线段树维护合并,外加LCA求两个点的距离即可。
【代码】
#pragma GCC optimize(3,"Ofast","inline")
#pragma G++ optimize(3,"Ofast","inline")
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#define RI register int
#define re(i,a,b) for(RI i=a; i<=b; i++)
#define ms(i,a) memset(a,i,sizeof(a))
#define MAX(a,b) (((a)>(b)) ? (a):(b))
#define MIN(a,b) (((a)<(b)) ? (a):(b))
using namespace std;
typedef long long LL;
namespace IO {
template <typename T>
inline void read(T &x){
x=0;
char c=0;
T w=0;
while (!isdigit(c)) w|=c=='-',c=getchar();
while (isdigit(c)) x=(x<<3)+(x<<1)+(c^48),c=getchar();
if(w) x=-x;
}
template <typename T>
inline void write(T x) {
if(x<0) putchar('-'),x=-x;
if(x<10) putchar(x+'0');
else write(x/10),putchar(x%10+'0');
}
template <typename T>
inline void writesp(T x) {
write(x);
putchar(' ');
}
template <typename T>
inline void writeln(T x) {
write(x);
putchar('\n');
}
}
using IO::read;
using IO::write;
using IO::writesp;
using IO::writeln;
const int N=1e5+5;
const int inf=1e9;
struct Edge {
int to,nt,w;
} e[N<<1];
struct Ans {
int len;
int a[2];
} t[N<<2];
int n,m,cnt,sum;
int h[N],dep[N],f[N][20],tin[N],tout[N];
inline void add(int a,int b,int c) {
e[cnt]=(Edge){b,h[a],c};
h[a]=cnt++;
}
void dfs(int x,int fa,int d) {
f[x][0]=fa;
tin[x]=++sum;
dep[x]=d;
for(int i=h[x]; i!=-1; i=e[i].nt) {
int v=e[i].to;
if(v==fa) continue;
dfs(v,x,d+e[i].w);
}
tout[x]=++sum;
}
inline int ancestor(int x,int y) {
return tin[x]<=tin[y] && tout[y]<=tout[x];
}
inline int lca(int x,int y) {
if(ancestor(x,y)) return x;
if(ancestor(y,x)) return y;
for(int i=16; i>=0; i--)
if(!ancestor(f[x][i],y)) x=f[x][i];
return f[x][0];
}
inline int dist(int x,int y) {
int k=lca(x,y);
return dep[x]+dep[y]-(dep[k]<<1);
}
#define lch (o<<1)
#define rch (o<<1|1)
#define mid ((l+r)>>1)
void pushup(int o) {
t[o]=t[lch];
if(t[o].len<t[rch].len) t[o]=t[rch];
for(int i=0; i<=1; i++) for(int j=0; j<=1; j++) {
int tmp=dist(t[lch].a[i],t[rch].a[j]);
if(tmp>t[o].len) {
t[o].len=tmp;
t[o].a[0]=t[lch].a[i];
t[o].a[1]=t[rch].a[j];
}
}
}
void build(int o,int l,int r) {
if(l==r) {
t[o].len=0;
t[o].a[0]=l;
t[o].a[1]=r;
return;
}
build(lch,l,mid);
build(rch,mid+1,r);
pushup(o);
}
Ans query(int o,int l,int r,int ll,int rr) {
if(l==ll && r==rr) return t[o];
if(rr<=mid) return query(lch,l,mid,ll,rr);
else if(ll>mid) return query(rch,mid+1,r,ll,rr);
else {
Ans la=query(lch,l,mid,ll,mid);
Ans ra=query(rch,mid+1,r,mid+1,rr);
Ans ta;
if(la.len>ra.len) ta=la;
else ta=ra;
for(int i=0; i<=1; i++) for(int j=0; j<=1; j++) {
int tmp=dist(la.a[i],ra.a[j]);
if(tmp>ta.len) {
ta.len=tmp;
ta.a[0]=la.a[i];
ta.a[1]=ra.a[j];
}
}
return ta;
}
}
int main() {
read(n);
memset(h,-1,sizeof(h));
for(int i=1; i<n; i++) {
int x,y,z;
read(x);
read(y);
read(z);
add(x,y,z);
add(y,x,z);
}
dfs(1,1,0);
for(int j=1; j<=16; j++) for(int i=1; i<=n; i++)
f[i][j]=f[f[i][j-1]][j-1];
build(1,1,n);
read(m);
while(m--) {
int a,b,c,d;
read(a);
read(b);
read(c);
read(d);
Ans la,ra;
int ans=0;
la=query(1,1,n,a,b);
ra=query(1,1,n,c,d);
for(int i=0; i<=1; i++) for(int j=0; j<=1; j++) {
int tmp=dist(la.a[i],ra.a[j]);
ans=MAX(ans,tmp);
}
writeln(ans);
}
return 0;
}