题解
为啥我写个线段树还得调 1h 啊?
考虑枚举每一种颜色 \(c\)。设 \(S_i\) 为 \(a_{1\dots i}\) 中 \(c\) 的出现次数,那么一个区间 \((l,r]\) 是合法的当且仅当 \(2S_r-r>2S_l-l\)。设 \(f(x)=2S_x-x\)。按顺序枚举 \(c\) 的每一个出现位置,设这个位置为 \(p\),下一个出现位置为 \(q\)。那么,对于 \([p,q)\) 中的每个位置 \(i\),\(f(i)=f(p)+p-i\)。因此,\([p,q)\) 之间的位置不会互相产生贡献。并且,\([0,p)\) 之间的位置对 \([p,q)\) 的贡献系数是一段斜率为 \(0\) 的线段加上一段斜率为 \(-1\) 的线段。由此我们便可以快速计算以 \([p,q)\) 中的某个位置为右端点的合法线段个数了。我们需要支持的操作是:
- 区间 \(+1\);
- 询问区间的和以及二阶和。
线段树或树状数组都可以胜任。
代码
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
#define For(Ti,Ta,Tb) for(int Ti=(Ta);Ti<=(Tb);++Ti)
#define Dec(Ti,Ta,Tb) for(int Ti=(Ta);Ti>=(Tb);--Ti)
typedef long long ll;
const int N=5e5+5;
int n,tp;
struct SegmentTree{
struct Node{
int l,r,clean;ll s[2],Add;
}t[N<<3];
void Pushup(int p){
t[p].s[0]=t[p*2].s[0]+t[p*2+1].s[0];
t[p].s[1]=t[p*2].s[1]+t[p*2+1].s[1]+(t[p*2+1].l-t[p].l)*t[p*2+1].s[0];
}
void PushAdd(int p,ll k){
t[p].s[0]+=k*(t[p].r-t[p].l+1);
t[p].s[1]+=1LL*(t[p].r-t[p].l+2)*(t[p].r-t[p].l+1)/2*k;
t[p].Add+=k;
}
void PushClean(int p){
t[p].clean=1,t[p].Add=0;
t[p].s[0]=t[p].s[1]=t[p].Add=0;
}
void Pushdown(int p){
if(t[p].clean){
PushClean(p*2),PushClean(p*2+1);
t[p].clean=0;
}
PushAdd(p*2,t[p].Add),PushAdd(p*2+1,t[p].Add);
t[p].Add=0;
}
void Build(int p,int l,int r){
t[p].l=l,t[p].r=r;
if(l==r) return;
Build(p*2,l,(l+r)/2),Build(p*2+1,(l+r)/2+1,r);
}
void Add(int p,int l,int r,ll k){
if(l<=t[p].l&&t[p].r<=r) return PushAdd(p,k);
Pushdown(p);
int mid=(t[p].l+t[p].r)>>1;
if(l<=mid) Add(p*2,l,r,k);
if(r>mid) Add(p*2+1,l,r,k);
Pushup(p);
}
pair<ll,ll> Query(int p,int l,int r){
if(l>t[p].r||r<t[p].l) return {0,0};
if(l<=t[p].l&&t[p].r<=r) return {t[p].s[0],t[p].s[0]*(t[p].l-l)+t[p].s[1]};
Pushdown(p);
auto resl=Query(p*2,l,r),resr=Query(p*2+1,l,r);
return {resl.first+resr.first,resl.second+resr.second};
}
}seg;
int a[N];
vector<int> occ[N];
int main(){
ios::sync_with_stdio(false),cin.tie(nullptr);
cin>>n>>tp;
int mx=0;
For(i,1,n) cin>>a[i],++a[i],mx=max(mx,a[i]),occ[a[i]].push_back(i);
seg.Build(1,1,(n+3)*2);
ll ans=0;
const int delt=n+3;
For(i,1,n){
if(!occ[i].size()) continue;
seg.PushClean(1);
seg.Add(1,delt-occ[i].front()+1,delt,1);
occ[i].push_back(n+1);
for(auto it=occ[i].begin();it!=prev(occ[i].end());++it){
int cnt=it-occ[i].begin()+1,l=2*cnt-*next(it)+1,r=2*cnt-*it;
ans+=seg.Query(1,1,delt+r).first*(r-l+1)-seg.Query(1,l+delt,r+delt).second;
seg.Add(1,l+delt,r+delt,1);
}
}
cout<<ans;
return 0;
}