Description
有些黑点,问你选择不超过 \(k\) 个黑点的路径,路径权值最大是多少.
Sol
点分治.
这是qzc的论文题,不过我感觉他的翻译好强啊...我还是选择了自己去看题目...
点分治每次至少分一半,所以层数不超过过 \(logn\) 每次分治只考虑过根节点的情况.
我们想想如何统计答案.
\(f[i][j]\) 表示 \(i\) 节点的子树拥有 \(j\) 个黑点最大的边权.
\(g[i][j]\) 表示 \(i\) 节点的子树拥有不超过 \(j\) 个黑点的最大边权.
\(d[i]\) 表示 \(i\) 节点的子树中黑点的个数(或者表示成一条链上最多的黑点个数都可以).
然后就可以这样更新在最坏情况下是 \(O(nk)\) 的,但是我们如果按 \(d[i]\) 排一下序,从小到大更新,这样的更新的复杂度就是 \(O(\sum d_i)\) ,因为黑点个数是有限的,所以更新的复杂度就不超过 \(O(n)\) 了.
这样点分治一个 \(log\) ,排序一个 \(log\) ,所以总复杂度为 \(O(nlog^2n)\) .
感觉这道题也好强啊qwq..分治算法各种神奇..
Code
#include <bits/stdc++.h>
using namespace std; #define mpr make_pair
#define debug(a) cout<<#a<<"="<<a<<" " typedef pair< int,int > pr;
const int N = 2e5+50;
const int INF = 0x3f3f3f3f; int n,k,m; pr edge[N<<1];
int h[N],nxt[N<<1],cnte; int rt,ans;
int sz[N],t[N],usd[N];
pr d[N];
int f[N],g[N],bl[N]; inline int in(int x=0,char ch=getchar(),int v=1) {
while(ch>'9' || ch<'0') v=(ch=='-'?-1:v),ch=getchar();
while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();return x*v;
}
void Add_Edge(int fr,int to,int w) {
edge[++cnte]=mpr(to,w),nxt[cnte]=h[fr],h[fr]=cnte;
} void GetRoot(int u,int fa,int nn) {
sz[u]=1,t[u]=0;
for(int i=h[u],v;i;i=nxt[i]) if(!usd[v=edge[i].first] && v!=fa) {
GetRoot(v,u,nn),sz[u]+=sz[v],t[u]=max(t[u],sz[v]);
}t[u]=max(t[u],nn-sz[u]);
if(t[u]<t[rt]) rt=u;
} int GetD(int u,int fa) {
int td=bl[u];
for(int i=h[u],v;i;i=nxt[i]) if(!usd[v=edge[i].first] && v!=fa) {
td+=GetD(v,u);
}return td;
}
void GetF(int u,int fa,int c,int w) {
f[c]=max(f[c],w);
for(int i=h[u],v;i;i=nxt[i]) if(!usd[v=edge[i].first] && v!=fa) {
// cout<<u<<"-->"<<v<<" "<<c+bl[u]<<" "<<w+edge[i].second<<endl;
GetF(v,u,c+bl[v],w+edge[i].second);
}
}
void GetAns(int u,int n) {
usd[u]=1;int c=0;
for(int i=h[u],v;i;i=nxt[i]) if(!usd[v=edge[i].first]){
d[++c]=mpr(GetD(v,u),i);
}
sort(d+1,d+c+1);
int mxd=0,tt=min(k-bl[u],d[c].first),v,mx=0;
for(int i=0;i<=tt;i++) g[i]=-INF;
for(int i=1;i<=c;i++) {
tt=min(d[i].first,k-bl[u]);
for(int j=0;j<=tt;j++) f[j]=-INF;
v=edge[d[i].second].first;
GetF(v,u,bl[v],edge[d[i].second].second); // cout<<v<<" --> f[]"<<endl;
// for(int j=0;j<=tt;j++) cout<<f[j]<<" ";cout<<endl; for(int j=0,pp;j<=tt;j++) pp=min(mxd,k-bl[u]-j),ans=(g[pp]!=-INF&&f[j]!=-INF)?max(ans,g[pp]+f[j]):ans; for(int j=0;j<=tt;j++) g[j]=max(g[j],f[j]); // cout<<"g[]"<<endl;
// for(int j=0;j<=tt;j++) cout<<g[i]<<" ";cout<<endl; mxd=tt,mx=0;
for(int j=0;j<=mxd;j++) mx=max(mx,g[j]),g[j]=mx;
} ans=max(ans,g[mxd]); // debug(u)<<":"<<endl;
// for(int i=1;i<=c;i++) cout<<d[i].first<<" "<<edge[d[i].second].first<<endl;
// debug(ans)<<endl;
// cout<<"----------------------"<<endl; for(int i=h[u],v,nn;i;i=nxt[i]) if(!usd[v=edge[i].first]) {
rt=0,nn=sz[v]>sz[u] ? n-sz[u]: sz[v],GetRoot(v,v,nn),GetAns(rt,nn);
}
} int main() {
n=in(),k=in(),m=in();
for(int i=1,x;i<=m;i++) x=in(),bl[x]=1;
for(int i=1,u,v,w;i<n;i++)
u=in(),v=in(),w=in(),Add_Edge(u,v,w),Add_Edge(v,u,w); t[0]=n+1,rt=0,GetRoot((n+1)>>1,(n+1)>>1,n);
GetAns(rt,n); cout<<ans<<endl;
return 0;
} /*
8 2 3
3
5
7
1 3 1
2 3 10
3 4 -2
4 5 -1
5 7 6
5 6 5
4 8 3 12
*/