Link
把条件容斥一下,先乘个二,然后都加上一端在\([l,r]\)内一端在\([l,r]\)外的路径条数,题目给的限制就变成了起点在\([l,r]\)内的合法路径条数大于起点在\([l,r]\)外的合法路径条数。
那么点分治算出以每个点为起点的合法路径条数,然后枚举右端点计算合法左端点,易知合法左端点随右端点递增而递增,故可用一单调指针维护。
#include<cstdio>
#include<cctype>
#include<vector>
#include<utility>
#define pi pair<int,int>
#define fi first
#define se second
using std::vector;
using std::pair;
using ll=long long;
const int N=100007;
int max(int a,int b){return a>b? a:b;}
int min(int a,int b){return a<b? a:b;}
int read(){int x=0,c=getchar();while(!isdigit(c))c=getchar();while(isdigit(c))x=x*10+c-48,c=getchar();return x;}
vector<int>e[N];
int n,top,a[N],vis[N],dis[N],son[N],size[N],t[N<<1];pi stk[N];ll g[N],sum,now,ans;
void findroot(int u,int fa,int&root)
{
son[u]=0,size[u]=1;
for(int v:e[u]) if(!vis[v]&&v^fa) findroot(v,u,root),size[u]+=size[v],son[u]=max(son[u],size[v]);
son[u]=max(son[u],son[0]-size[u]),root=son[u]<=son[root]? u:root;
}
void getdis(int u,int fa)
{
dis[u]=dis[fa]+a[u],stk[++top]={u,dis[u]},size[u]=1;
for(int v:e[u]) if(!vis[v]&&v^fa) getdis(v,u),size[u]+=size[v];
}
void calc(int u,int fa,int f)
{
top=0,getdis(u,fa);
for(int i=1;i<=top;++i) ++t[n+stk[i].se];
for(int i=1,s;i<=top;++i) s=f*t[n+(fa? a[fa]:a[u])-stk[i].se],sum+=s,g[stk[i].fi]+=s;
for(int i=1;i<=top;++i) --t[n+stk[i].se];
}
void solve(int u,int S)
{
int root=0;
son[0]=S,findroot(u,0,root);
vis[root]=1,dis[root]=a[root],calc(root,0,1);
for(int v:e[root]) if(!vis[v]) calc(v,root,-1);
for(int v:e[root]) if(!vis[v]) solve(v,size[v]);
}
int main()
{
n=read();
for(int i=1;i<=n;++i) a[i]=read()? 1:-1;
for(int i=1,u,v;i<n;++i) u=read(),v=read(),e[u].push_back(v),e[v].push_back(u);
solve(1,n);
for(int l=1,r=1;r<=n;ans+=l-1,++r) for(now+=g[r];l<=r&&now<<1>sum;now-=g[l++]);
printf("%lld",ans);
}