题意:给定带点权边权的树,定义路径的花费=路径边权和e+起点点权w[s]*终点点权w[t]。N<2e5,e,w<1e6;
思路:首先,需要树分治。 然后得到方程dp[i]=min{ dis[i]+dis[j]+w[i]*w[j] },很显然需要斜率优化。
注意维护凸包的时候是需要保证w[j]是单调的,这样才能用不等式维护队尾。 由于w[i]不是对应的队尾,所以我们还要二分凸包。
还有个问题,怎么确定我们得到的i和j不是在同一个子树呢? 因为如果在一颗子树的时候dp[i]=dis[i]+dis[j]+w[i]*w[j]-2*dis[LCA]。 其实没必要考虑这个问题,因为当LCA为根的时候会更新答案。(这一点想不到估计要很难去维护了)
#include<bits/stdc++.h>
#define ll long long
#define rep(i,a,b) for(int i=a;i<=b;i++)
using namespace std;
const int maxn=;
int Laxt[maxn],Next[maxn],To[maxn],Len[maxn],cnt;
int sz[maxn],son[maxn],rt,all,vis[maxn],S[maxn],tot;
ll ans[maxn],a[maxn],dis[maxn],sum; int q[maxn],top;
bool cmp(int x,int y)
{
int xx=x,yy=y;
if(a[xx]==a[yy]) return dis[xx]<dis[yy];
return a[xx]<a[yy];
}
ll getans(int p,int k)
{
return dis[k]+dis[p]+a[p]*a[k];
}
void add(int u,int v,int w)
{
Next[++cnt]=Laxt[u]; Laxt[u]=cnt; To[cnt]=v; Len[cnt]=w;
}
void dfs1(int u,int f)
{
sz[u]=; son[u]=;
for(int i=Laxt[u];i;i=Next[i]){
if(To[i]!=f&&!vis[To[i]]) {
dfs1(To[i],u);
sz[u]+=sz[To[i]];
son[u]=max(son[u],sz[To[i]]);
}
}
son[u]=max(son[u],all-son[u]);
if(son[u]<son[rt]) rt=u;
}
void cal(int p)
{
if(top==) return ;int L=,R=top-,Mid;
ans[p]=min(ans[p],getans(p,q[top]));
while(L<=R){
Mid=(L+R)>>;
ll tmp1=getans(p,q[Mid]),tmp2=getans(p,q[Mid+]);
if(tmp1<tmp2) R=Mid-,ans[p]=min(ans[p],tmp1);
else L=Mid+,ans[p]=min(ans[p],tmp2);
}
}
void get(int u,int f)
{
cal(u);
for(int i=Laxt[u];i;i=Next[i])
if(To[i]!=f&&!vis[To[i]]) get(To[i],u);
}
bool check(int p){
return (dis[p]-dis[q[top]])*(a[p]-a[q[top-]])<=
(dis[p]-dis[q[top-]])*(a[p]-a[q[top]]);
}
void ADD(int p)
{
if(top&&a[p]==a[q[top]]&&dis[p]<dis[q[top]]) top--;
while(top>&&check(p)) top--;
q[++top]=p;
}
void get(int u,int f,ll D)
{
dis[u]=D; sz[u]=; S[++tot]=u;
for(int i=Laxt[u];i;i=Next[i])
if(To[i]!=f&&!vis[To[i]]){
get(To[i],u,D+Len[i]);
sz[u]+=sz[To[i]];
}
}
void solve(int u,int f)
{
vis[u]=; dis[u]=;
top=; tot=; S[++tot]=u;
for(int i=Laxt[u];i;i=Next[i]){
if(!vis[To[i]]&&To[i]!=f)
get(To[i],u,Len[i]);
}
sort(S+,S+tot+,cmp);
rep(i,,tot) ADD(S[i]);
rep(i,,tot) cal(S[i]);
for(int i=Laxt[u];i;i=Next[i]){
int v=To[i];
if(!vis[v]&&v!=f) {
all=sz[v]; rt=;
dfs1(v,); solve(rt,);
}
}
}
int main()
{
int N,u,v,w;
scanf("%d",&N); son[]=N+;
rep(i,,N) scanf("%lld",&a[i]);
rep(i,,N) ans[i]=a[i]*a[i];
rep(i,,N-){
scanf("%d%d%d",&u,&v,&w);
add(u,v,w); add(v,u,w);
}
all=N; rt=;
dfs1(,); solve(rt,);
rep(i,,N) sum+=ans[i];
printf("%lld\n",sum);
return ;
}