hdu7024 Penguin Love Tour(2021杭电暑假多校5)树形dp

题意

给定一棵\(n\)个点的树,树的每个边有个权值\(w\),每个点有个权值\(p\)。每个点可以把相邻的某一条边边权减\(p\)。最小化直径。(\(1\le n,w\le{10}^5,0\le p\le{10}^5\))

思路

考虑二分答案,设为\(limit\)​。那么\(check\)​就是每棵子树最大的两条边之和不能超过\(limit\)​。设\(dp[u][0]\)​为节点\(u\)​这棵子树没有使用\(u\)​时,某个叶子到\(u\)​的最长路径的最小值。\(dp[u][1]\)​为已经使用了\(u\)​的最小值。那么有:

$ dp[u][0]=max_{v}{min(dp[v][0]+max(0,w_{u,v}-p[v]),dp[v][1]+w_{u,v})} \tag{1}$

\(dp[u][1]=min_{v_0}\{max(min(dp[v_0][0]+max(0,w_{u,v_0}-p[v_0]-p[u]),dp[v_0][1]+max(0,w_{u,v_0}-p[u])),\\{max_{v\not=v_0}\{min(dp[v][0]+max(0,w_{u,v}-p[v]),dp[v][1]+w_{u,v})\})}\} \tag{2}\)

然后又因为对儿子用了\(p[u]\)​后最长的儿子一定会在\((1)\)​中最长的三个中取,那么求\(dp[u][1]\)​只需要枚举\(v_0\)​为\((1)\)​中最大的三个即可。

代码

#include <bits/stdc++.h>
using namespace std;
using ll=long long;
using pii=pair<int,int>;
using pli=pair<ll,int>;
constexpr ll inf=1e18;

inline char gc() {
  static constexpr int BufferSize = 1 << 22 | 5;
  static char buf[BufferSize], *p, *q;
  static std::streambuf *i = std::cin.rdbuf();
  return p == q ? p = buf, q = p + i->sgetn(p, BufferSize), p == q ? EOF : *p++ : *p++;
}
struct Reader {
  template <class T>
  Reader &operator>>(T &w) {
    char c, p = 0;
    for (; !std::isdigit(c = gc());) if (c == '-') p = 1;
    for (w = c & 15; std::isdigit(c = gc()); w = w * 10 + (c & 15)) ;
    if (p) w = -w;
    return *this;
  }
} fin;

template<int N>
struct Max{
  int n=0;
  array<pli,N> a;
  void insert(pli x) {
    if(n!=0)
      for(int i=0;i<n;i++) {
        if(a[i]<x)
          swap(a[i],x);
      }
    if(n<N) a[n++]=x;
  }
  void erase(int id) {
    for(int i=0;i<n;i++) {
      if(a[i].second==id) {
        for(int j=i;j<n-1;j++) a[j]=a[j+1];
        n--;
        break;
      }
    }
  }
  ll sum(int cnt) {
    cnt=min(cnt,N);
    ll ans=0;
    for(int i=0;i<cnt;i++) ans+=a[i].first;
    return ans;
  }
  bool vis(int id) {
    for(int i=0;i<n;i++)
      if(a[i].second==id) return true;
    return false;
  }
};

void solve() {
  int n;
  ll L=0,R=0;
  fin>>n;
  vector<int> p(n+1);
  vector<vector<pii>> g(n+1);
  for(int i=1;i<=n;i++) fin>>p[i];
  for(int i=1,u,v,w;i<=n-1;i++) {
    fin>>u>>v>>w;
    g[u].push_back({v,w});
    g[v].push_back({u,w});
    R+=w;
  }

  vector<ll>dp[2];
  ll mid;
  bool flag;
  function<void(int,int)> dfs=[&](int u,int f) {
    int son=0;
    Max<3>s;
    for(int i=0;i<g[u].size();i++) {
      int v=g[u][i].first;
      int w=g[u][i].second;
      if(v==f) continue;
      dfs(v,u);
      if(!flag)return;
      son++;
      s.insert({min(dp[0][v]+max(w-p[v],0),dp[1][v]+w),i});
    }
    if(son==0) {
      dp[0][u]=0;
      return;
    }

    if(s.sum(2)<=mid)
      dp[0][u]=s.sum(1);
    dp[1][u]=inf;
    for(int i=0;i<g[u].size();i++) {
      int v=g[u][i].first;
      int w=g[u][i].second;
      if(v==f || !s.vis(i)) continue;
      Max<3> s1=s;
      s1.erase(i);
      s1.insert({min(dp[0][v]+max(w-p[v]-p[u],0),dp[1][v]+max(w-p[u],0)),i});
      if(s1.sum(2)<=mid)
        dp[1][u]=min(dp[1][u],s1.sum(1));
    }
    if(dp[0][u]==inf && dp[1][u]==inf)
      flag=false;
  };
 while(L<R) {
    mid=(L+R)/2;
    flag=true;
    dp[0]=dp[1]=vector<ll>(n+1,inf);
    dfs(1,0);
    if(flag && (dp[0][1]<=mid || dp[1][1]<=mid)) R=mid;
    else L=mid+1;
 }
  cout<<L<<'\n';
}

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  int T;
  fin>>T;
  while(T--) solve();

  return 0;
}
上一篇:【java】Function.identity()的含义


下一篇:切片学习转载文档-易懂好学