这题卡常……而且目前还没有卡过去
首先以原树重心为根,递归地向每个剩余连通块的重心连边,即可建立一棵点分树
点分树有两个性质:
一是点分树的树高只有 \(O(\log n)\) 层
二是任意两点在点分树上的 LCA 一定位于原树中这两点之间的路径上
所以在原树上不断删点,并统计当前子树中的信息就好
至于如何统计,令 \(dp[i][j][k][l]\) 表示分治中心为 \(i\) ,到点 \(j\) ,第一条边颜色为 \(k\) ,最后一条边颜色为 \(l\) 的最大得分
转移挺好写的,询问时暴力枚举相关连边的颜色
留个坑,纯点分治还不会写呢
Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 100010
#define ll long long
// Loop-index shorthand. `register` was removed as a storage-class specifier in
// C++17 (clang rejects it, GCC warns), so this is now a plain int.
#define reg int
#define fir first
#define sec second
#define make make_pair
#define pb push_back
// NOTE: both arguments may be evaluated twice; only pass side-effect-free expressions.
#define min2(a, b) ((a)<(b)?(a):(b))
#define max2(a, b) ((a)>(b)?(a):(b))
//#define int long long
// Buffered input: refill a 2 MiB buffer with fread() whenever it is exhausted and
// hand out one character at a time. The getchar() macro below shadows the stdio
// function so read() uses this buffer; it yields EOF once fread() returns nothing.
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
// Reads one decimal integer from the (buffered) input.
// Characters before the first digit are skipped; every '-' among them flips the
// sign. Returns 0 if the input ends before any digit is found.
// `c` is an int so that EOF stays distinguishable from data bytes: with a
// `char`, EOF was never detected and the skip loop spun forever at end of input
// (and isdigit() on a negative char is undefined behaviour).
inline int read() {
    int ans = 0, f = 1;
    int c = getchar();
    while (c != EOF && !(c >= '0' && c <= '9')) {
        if (c == '-') f = -f;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        ans = ans * 10 + (c - '0');
        c = getchar();
    }
    return ans * f;
}
// n: vertex count, m: edge count, q: query count (all read in main / solve()).
int n, m, q;
// Adjacency lists in linked ("forward star") form: head[u] is the index of the
// most recently added edge leaving u (-1 = none; main() fills head with -1),
// e[i].next chains to the previous edge of the same source, and `size` counts
// the edges used so far (index 0 is never used). task::add reuses head/size.
// NOTE(review): e[] holds N*3 entries and every input edge is stored twice,
// so this assumes 2*m < N*3 — confirm against the problem bounds.
int head[N], size;
struct edge{int to, next, val;}e[N*3];
// Appends the directed edge s -> t with colour w.
// `::size` is written qualified because, with `using namespace std;` under
// C++17, an unqualified `size` is ambiguous with std::size and fails to compile.
inline void add(int s, int t, int w) {
    ++::size;
    e[::size].to = t;
    e[::size].val = w;
    e[::size].next = head[s];
    head[s] = ::size;
}
// Brute-force reference solution: answers every query with a DFS over the
// original graph stored in head/e (filled by the global add()).
namespace force{
// none[v]: v produced no positive score during the current query's DFS and is
// skipped from then on. Cleared before each query.
// NOTE(review): a genuine score of 0 is indistinguishable from "target not
// reachable", and 0 is also the "no previous edge" value of `now`, so this is
// only exact if edge colours are never 0 — confirm against the input format.
bool none[N];
// Best score of a walk from u to `to` that never steps straight back to `fa`.
// `now` is the colour of the edge used to enter u (0 before the first edge) and
// `sum` the score so far; each edge whose colour differs from the previous
// edge's colour adds 1.
int dfs(int u, int to, int fa, int now, int sum) {
    if (u == to) return sum;
    int best = 0;
    for (int id = head[u]; id != -1; id = e[id].next) {
        const int nxt = e[id].to;
        if (nxt == fa || none[nxt]) continue;
        const int colour = e[id].val;
        const int got = dfs(nxt, to, u, colour, sum + (colour != now ? 1 : 0));
        if (got > best) best = got;
    }
    if (best == 0) none[u] = 1;
    return best;
}
// Reads the m edges and q queries, printing one answer per line, then exits.
void solve() {
    for (int k = 1; k <= m; ++k) {
        const int a = read();
        const int b = read();
        const int w = read();
        add(a, b, w);
        add(b, a, w);
    }
    q = read();
    if (q == 0) exit(0);
    for (int k = 1; k <= q; ++k) {
        const int x = read();
        const int y = read();
        memset(none, 0, sizeof(bool) * (n + 5));
        printf("%d\n", dfs(x, y, 0, 0, 0));
    }
    exit(0);
}
}
// Depth-limited variant of the brute force: the same DFS over head/e, but any
// vertex reached more than 60 edges away from the query start is treated as a
// dead end. NOTE(review): presumably aimed at a subtask with short query paths;
// when the x-y path has more than 60 edges it prints 0 instead of the answer.
namespace task1{
// none[v]: v is pruned for the rest of the current query (see dfs below).
bool none[N];
// Like force::dfs, with `dis` = number of edges walked from the query start.
int dfs(int u, int to, int fa, int now, int sum, int dis) {
    // Depth cut-off: give up on (and permanently prune) vertices past depth 60.
    if (dis > 60) {
        none[u] = 1;
        return 0;
    }
    if (u == to) return sum;
    int best = 0;
    for (int id = head[u]; id != -1; id = e[id].next) {
        const int nxt = e[id].to;
        if (nxt == fa || none[nxt]) continue;
        const int colour = e[id].val;
        const int got = dfs(nxt, to, u, colour, sum + (colour != now ? 1 : 0), dis + 1);
        if (got > best) best = got;
    }
    if (best == 0) none[u] = 1;
    return best;
}
// Reads the m edges and q queries, printing one answer per line, then exits.
void solve() {
    for (int k = 1; k <= m; ++k) {
        const int a = read();
        const int b = read();
        const int w = read();
        add(a, b, w);
        add(b, a, w);
    }
    q = read();
    if (q == 0) exit(0);
    for (int k = 1; k <= q; ++k) {
        const int x = read();
        const int y = read();
        memset(none, 0, sizeof(bool) * (n + 5));
        printf("%d\n", dfs(x, y, 0, 0, 0, 0));
    }
    exit(0);
}
}
namespace task{
int fa[N][25], dep[N], lg[N], rot, siz[N], msiz[N], sumsiz;
bool rm[N];
struct st{int i, j, k; st() {} st(int a, int b, int c):i(a),j(b),k(c){}};
inline bool operator == (st a, st b) {return a.i==b.i&&a.j==b.j&&a.k==b.k;}
inline bool operator < (st a, st b) {return a.k<b.k;}
struct pair_hash{inline size_t operator () (pair<int, int> p) const {return hash<int>()(p.fir*p.sec+p.fir);}};
unordered_map<pair<int, int>, vector<int>*, pair_hash> mp{5000, pair_hash()};
unordered_map<pair<int, int>, vector<st>*, pair_hash> dp{20000, pair_hash()};
unordered_map<pair<int, int>, int, pair_hash> tem{50, pair_hash()};
struct edge{int to, next;}e2[N<<1];
inline void add(int s, int t) {e2[++size].to=t; e2[size].next=head[s]; head[s]=size;}
vector<int> e[N];
void getrt(int u, int fa) {
//cout<<"getrt"<<u<<' '<<fa<<endl;
siz[u]=1; msiz[u]=0;
for (register auto v:e[u]) {
if (!rm[v] && v!=fa) {
getrt(v, u);
siz[u]+=siz[v];
msiz[u]=max(msiz[u], siz[v]);
}
}
//cout<<"sumsiz: "<<sumsiz<<' '<<sumsiz-siz[u]<<' '<<msiz[u]<<' '<<msiz[rot]<<' '<<rot<<endl;
msiz[u]=max(msiz[u], sumsiz-siz[u]);
if (msiz[u]<msiz[rot]) rot=u;
}
void info(int u, int fa) {
//cout<<"info "<<u<<' '<<fa<<endl;
for (register auto v:e[u]) {
if (v==fa || rm[v]) continue;
if (rot==u) {
if (dp.find(make(u, v))==dp.end()) dp[make(u, v)]=new vector<st>;
auto t1=mp[make(u, v)]; auto t2=dp[make(u, v)];
for (register auto it:*t1)
t2->pb(st(it, it, 1));
}
else {
if (dp.find(make(rot, v))==dp.end()) dp[make(rot, v)]=new vector<st>;
for (register auto it:*dp[make(rot, u)])
for (register auto t:*mp[make(u, v)])
tem[make(it.i, t)] = max(tem[make(it.i, t)], it.k+(it.j!=t));
register auto t1=dp[make(rot, v)];
for (register auto it:tem) t1->pb(st(it.fir.fir, it.fir.sec, it.sec));
//assert(tem.size()<=9);
tem.clear();
}
info(v, u);
//cout<<"size: "<<dp.size()<<endl;
}
}
void build(int u) {
//cout<<"build "<<u<<endl;
rm[u]=1;
info(u, 0);
for (register auto v:e[u]) {
rot=0;
if (!rm[v]) {
sumsiz=siz[v];
getrt(v, u);
add(u, rot), add(rot, u);
//info(rot, u, t);
build(rot);
}
}
}
void dfs(int u, int pa) {
//cout<<"dfs "<<u<<' '<<pa<<endl;
for (reg i=1; i<25; ++i)
if (dep[u]>=(1<<i)) fa[u][i] = fa[fa[u][i-1]][i-1];
else break;
for (reg i=head[u],v; ~i; i=e2[i].next) {
v = e2[i].to;
if (v!=pa) dep[v]=dep[u]+1, fa[v][0]=u, dfs(v, u);
}
}
int lca(int a, int b) {
if (dep[a]<dep[b]) swap(a, b);
while (dep[a]>dep[b]) a=fa[a][lg[dep[a]-dep[b]]-1];
if (a==b) return a;
for (reg i=lg[dep[a]]-1; ~i; --i)
if (fa[a][i]!=fa[b][i])
a=fa[a][i], b=fa[b][i];
return fa[a][0];
}
void solve() {
for (reg i=1; i<=n; ++i) lg[i]=lg[i-1]+(1<<lg[i-1]==i);
pair<int, int> t1, t2;
for (reg i=1,u,v,w; i<=m; ++i) {
u=read(); v=read(); w=read();
e[u].pb(v); e[v].pb(u);
t1=make(u, v), t2=make(v, u);
if (mp.find(t1)==mp.end()) mp[t1]=new vector<int>;
mp[t1]->pb(w);
if (mp.find(t2)==mp.end()) mp[t2]=new vector<int>;
mp[t2]->pb(w);
}
for (reg i=1,siz; i<=n; ++i) {
sort(e[i].begin(), e[i].end());
siz=unique(e[i].begin(), e[i].end())-e[i].begin();
e[i].resize(siz);
}
for (register auto it:mp) {
sort(it.sec->begin(), it.sec->end());
int siz=unique(it.sec->begin(), it.sec->end())-it.sec->begin();
it.sec->resize(min2(siz, 3));
}
msiz[0]=sumsiz=n;
getrt(1, 0);
int root=rot;
//cout<<"root: "<<root<<endl;
build(rot);
dep[root]=1;
dfs(root, 0);
//for (auto it:dp) cout<<it.fir.i<<' '<<it.fir.j<<' '<<it.fir.k<<' '<<it.fir.h<<' '<<it.sec<<endl;
q=read(); if (!q) exit(0);
for (reg i=1,x,y,t,ans; i<=q; ++i) {
x=read(); y=read();
if (x==y) {puts("0"); continue;}
t=lca(x, y); ans=0;
//cout<<"lca: "<<x<<' '<<y<<' '<<t<<endl;
if (t==x || t==y) {
//puts("pos1");
if (t!=x) swap(x, y);
for (register auto it:*dp[make(t, y)]) ans=max(ans, it.k);
printf("%d\n", ans);
continue;
}
for (register auto i:*dp[make(t, x)])
for (register auto j:*dp[make(t, y)])
ans = max(ans, i.k+j.k-(i.i==j.i));
printf("%d\n", ans);
}
//cout<<"size: "<<dp.size()<<endl;
exit(0);
}
}
signed main()
{
memset(head, -1, sizeof(head));
n=read(); m=read();
//if (n<100000) force::solve();
//else task1::solve();
task::solve();
return 0;
}