http://acm.hdu.edu.cn/showproblem.php?pid=4747
设我们输入的数组为 a[],我们需要从 1 到 n 遍历, 假设遍历到 i 时, 遍历的过程中用b[j]表示从 i 到 j 没出现的最小自然数
先从 n 到 1 扫一遍求出从 1 到各个点的b[j]值
然后遍历a[] 实际上就是不断的把当前a[i] 去掉,比如说去掉a[3]时,剩下的b[4]---b[n] 就表示从4到其他后续点形成的区间中没出现的最小自然数
要知道从 i 到 n ,b[]的值始终是单调递增的
我们每去掉当前a[i]会对b[]数组产生影响,
设下一个和a[i]相等的数出现的位置是 r 那么去掉a[i] 对 r 以及 r 以后的b[] 没有影响
在 i 和 r 之间受影响的段b[]是大于等于a[i]的那一段 假设是(l,r), 这个段内的b[]都大于等于a[i]
去掉a[i]的影响就是这个段内的b[] 都要等于 a[i]
找到r可以事先标记,找 l 和更新段 (l,r) 有两种方法
1,二分找到 l ,然后遍历更新段 (l,r) 这样代码比较短,也比较易懂,但比较耗时,不过可以过
2,线段树维护 这样代码量会比较大,不过耗时少,线段树的解法应该比较标准
两种代码:
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<string>
#include<cstring>
#include<cmath>
#include<set>
#include<vector>
#include<list>
#include<stack>
#include<queue>
#include<map> using namespace std; typedef long long ll;
typedef pair<int,int> pp; const int INF=0x3f3f3f3f; const int N=200002;
bool exist[N];
int a[N],next[N],f[N];
int b[N];
int bsh(int l,int r,int k)
{
while(l<=r)
{
int mid=(l+r)>>1;
if(b[mid]<=k) l=mid+1;
else r=mid-1;
}
return r;
}
int main()
{
//freopen("data.in","r",stdin);
int n;
while(scanf("%d",&n)!=EOF)
{
if(n==0) break;
for(int i=1;i<=n;++i)
scanf("%d",&a[i]);
for(int i=0;i<=n;++i)
f[i]=n+1;
for(int i=n;i>=1;--i)
if(a[i]<n)
{
next[i]=f[a[i]];
f[a[i]]=i;
}
ll ans=0;
memset(exist,false,sizeof(exist));
ll tmp=0;int l=0;
for(int i=1;i<=n;++i)
{
if(a[i]<n)
{
exist[a[i]]=true;
while(exist[l]) ++l;
}
b[i]=l;
tmp+=b[i];
}
ans=tmp;
for(int i=1;i<n;++i)
{
if(a[i]<n)
{
int r=next[i];
int l=bsh(i,r-1,a[i]);
for(int j=l+1;j<r;++j)
{
tmp-=(b[j]-a[i]);
b[j]=a[i];
}
}
tmp-=b[i];
ans+=tmp;
}
cout<<ans<<endl;
}
return 0;
} #include<iostream>
#include<cstdio>
#include<algorithm>
#include<string>
#include<cstring>
#include<cmath>
#include<set>
#include<vector>
#include<list>
#include<stack>
#include<queue>
#include<map> using namespace std; typedef long long ll;
typedef pair<int,int> pp; const int INF=0x3f3f3f3f; const int N=200002;
bool exist[N];
int a[N],next[N],f[N];
int b[N];
struct node
{
int l,r,k,least;
ll sum;
}tr[N*4];
void build(int x,int l,int r)
{
tr[x].l=l;tr[x].r=r;tr[x].k=-1;
if(l==r)
{
tr[x].least=b[l];
tr[x].sum=b[l];
return ;
}
int mid=(l+r)>>1;
build((x<<1),l,mid);
build((x<<1)|1,mid+1,r);
tr[x].least=min(tr[x<<1].least,tr[(x<<1)|1].least);
tr[x].sum=(tr[x<<1].sum+tr[(x<<1)|1].sum);
}
void update(int x,int l,int r,int k)
{
if(l>r) return ;
if(tr[x].l==l&&tr[x].r==r)
{
tr[x].least=k;
tr[x].k=k;
tr[x].sum=(ll)k*(tr[x].r-tr[x].l+1);
return ;
}
if(tr[x].k!=-1)
{
tr[x<<1].k=tr[x].k;tr[x<<1].least=tr[x<<1].k;
tr[x<<1].sum=(ll)tr[x<<1].k*(tr[x<<1].r-tr[x<<1].l+1);
tr[(x<<1)|1].k=tr[x].k;tr[(x<<1)|1].least=tr[(x<<1)|1].k;
tr[(x<<1)|1].sum=(ll)tr[(x<<1)|1].k*(tr[(x<<1)|1].r-tr[(x<<1)|1].l+1);
tr[x].k=-1;
}
int mid=(tr[x].l+tr[x].r)>>1;
if(r<=mid)
update(x<<1,l,r,k);
else if(l>mid)
update((x<<1)|1,l,r,k);
else
{
update(x<<1,l,mid,k);
update((x<<1)|1,mid+1,r,k);
}
tr[x].least=min(tr[x<<1].least,tr[(x<<1)|1].least);
tr[x].sum=(tr[x<<1].sum+tr[(x<<1)|1].sum);
tr[x].k=-1;
}
int get(int x,int l,int r,int w)
{
if(tr[x].l==tr[x].r)
{
if(tr[x].least>w)
return (l-1);
return l;
}
if(tr[x].k!=-1)
{
tr[x<<1].k=tr[x].k;tr[x<<1].least=tr[x<<1].k;
tr[x<<1].sum=(ll)tr[x<<1].k*(tr[x<<1].r-tr[x<<1].l+1);
tr[(x<<1)|1].k=tr[x].k;tr[(x<<1)|1].least=tr[(x<<1)|1].k;
tr[(x<<1)|1].sum=(ll)tr[(x<<1)|1].k*(tr[(x<<1)|1].r-tr[(x<<1)|1].l+1);
tr[x].k=-1;
}
int mid=(tr[x].l+tr[x].r)>>1;
if(r<=mid)
return get(x<<1,l,r,w);
else if(l>mid)
return get((x<<1)|1,l,r,w);
else
{
if(tr[(x<<1)|1].least<=w)
return get((x<<1)|1,mid+1,r,w);
else
return get(x<<1,l,mid,w);
}
}
ll gsum(int x,int l,int r)
{
if(l>r) return 0; if(tr[x].l==l&&tr[x].r==r)
return tr[x].sum;
if(tr[x].k!=-1)
{
tr[x<<1].k=tr[x].k;tr[x<<1].least=tr[x<<1].k;
tr[x<<1].sum=(ll)tr[x<<1].k*(tr[x<<1].r-tr[x<<1].l+1);
tr[(x<<1)|1].k=tr[x].k;tr[(x<<1)|1].least=tr[(x<<1)|1].k;
tr[(x<<1)|1].sum=(ll)tr[(x<<1)|1].k*(tr[(x<<1)|1].r-tr[(x<<1)|1].l+1);
tr[x].k=-1;
}
int mid=(tr[x].l+tr[x].r)>>1;
if(r<=mid)
return gsum(x<<1,l,r);
else if(l>mid)
return gsum((x<<1)|1,l,r);
else
return gsum(x<<1,l,mid)+gsum((x<<1)|1,mid+1,r);
}
int main()
{
int n;
while(scanf("%d",&n)!=EOF)
{
if(n==0) break;
for(int i=1;i<=n;++i)
scanf("%d",&a[i]);
for(int i=0;i<=n;++i)
f[i]=n+1;
for(int i=n;i>=1;--i)
if(a[i]<n)
{
next[i]=f[a[i]];
f[a[i]]=i;
}
ll ans=0;
memset(exist,false,sizeof(exist));
int l=0;
for(int i=1;i<=n;++i)
{
if(a[i]<n)
{
exist[a[i]]=true;
while(exist[l]) ++l;
}
b[i]=l;
}
build(1,1,n);
ans+=gsum(1,1,n);
for(int i=1;i<n;++i)
{
if(a[i]<n)
{
int r=next[i];
int l=get(1,i,r-1,a[i]);
update(1,l+1,r-1,a[i]);
}
ans+=gsum(1,i+1,n);
}
cout<<ans<<endl;
}
return 0;
}