题意:给你一个序列,让你求出对于所有区间<i, j>的mex和,mex表示该区间没有出现过的最小的整数。
思路:从时限和点数就可以看出是线段树,并且我们可以枚举左端点i, 然后求出所有左端点为i的区间内mex值的和。
先把数插满,然后先询问后删除当前最左边的断点i。而且显然线段树里面保存的是mex值,而且这个序列是非递减的。
分析:我们先预处理出对于右端点为i的所有<1,i>的mex,分别插入线段树的i位置。然后每次删除最左边的左端点i
,假如当前我们要删除a[i] ,我们找到它之后第一个位置j满足a[i] == a[j], 那么区间i------j-1里面的所有mex都要更新,取线段树内的值和a[i]的最小值。 实际操作我们只要找到第一个比a[i]的位置l, r = j-1, 更新<l,r>之间的mex为a[i]即可。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
#define lson l, m, rt<<1
#define rson m+1, r, rt<<1|1
#define ls rt<<1
#define rs rt<<1|1
#define Mid int m = l+r>>1
const int maxn = 2000006;
int next[maxn], pre[maxn], n;
int a[maxn], mex;
bool vis[maxn];
ll sum[maxn<<2];
int mx[maxn<<2], col[maxn<<2];
void build(int l=1, int r=n, int rt=1) {
col[rt] = -1;
sum[rt] = 0;
mx[rt] = 0;
if(l == r) return;
Mid;
build(lson);
build(rson);
}
inline void down(int l, int r, int rt) {
if(~col[rt]) {
col[ls] = col[rs] = col[rt];
Mid;
sum[ls] = (ll)(m-l+1)*col[rt];
mx[ls] = mx[rs] = col[rt];
sum[rs] = (ll)(r-m)*col[rt];
col[rt] = -1;
}
}
inline void up(int rt) {
sum[rt] = sum[ls] + sum[rs];
mx[rt] = max(mx[ls], mx[rs]);
}
void update(int L, int R, int v, int l=1, int r=n, int rt=1) {
if(L <= l && r <= R) {
col[rt] = mx[rt] = v;
sum[rt] = (ll)(r-l+1)*v;
return;
}
Mid; down(l, r, rt);
if(L <= m) update(L, R, v, lson);
if(R > m) update(L, R, v, rson);
up(rt);
}
ll query(int L, int R, int l=1, int r=n, int rt=1) {
if(L <= l && r <= R)
return sum[rt];
Mid; down(l, r, rt);
ll ret = 0;
if(L <= m) ret += query(L, R, lson);
if(R > m) ret += query(L, R, rson);
up(rt);
return ret;
}
int find(int v, int l=1, int r=n, int rt=1) {
if(mx[rt] <= v) return n+1;
if(l == r) return l;
Mid; down(l, r, rt);
if(mx[ls] > v) return find(v, lson);
else return find(v, rson);
}
int main() {
int i, j;
while(~scanf("%d", &n) && n) {
for(i = 1; i <= n; i++) {
scanf("%d", &a[i]);
pre[i] = vis[i] = 0;
next[i] = n+1;
}
pre[0] = vis[0] = 0;
for(i = 1; i <= n; i++)
if(a[i] <= n) {
if(pre[a[i]])
next[pre[a[i]]] = i;
pre[a[i]] = i;
}
build();
mex = 0;
for(i = 1; i <= n; i++) {
if(a[i] <= n){
vis[a[i]] = 1;
while(vis[mex]) mex++;
}
update(i, i, mex);
}
ll ans = 0;
for(i = 1; i <= n; i++) {
ans += query(i, n);
if(a[i] <= mex) {
int l = max(find(a[i]), i);
int r = next[i]-1;
if(l <= r) update(l, r, a[i]);
}
}
printf("%I64d\n", ans);
}
return 0;
}
/*
3
0 10000 20000
*/