HDU-4747 Mex 线段树应用 Mex性质
题意
给定长度为\(n\)的数组\(a\),求
\[\sum \sum mex(i,j) \]其中\(mex(i,j)\)表示区间\(mex(a_i...a_j)的值\)
\[1\leq n \leq 2\times 10^5\\ 1\leq a_i \leq 10^9 \]分析
此题我认为还是不太好想到的
首先如果只求一维,由于单调性,求\(\sum mex(1,i)\)是可以在\(O(n)\)下完成的。
然后注意到第二维即\(\sum mex(2,i)\)该如何计算,这个时候\(1\)相当于没有了,1在这一维上产生的影响就是当前下一个等于\(a[1]\)的元素之前的一段。后面的显然和原来的保持不变,这就让我们想到了用区间维护。
那么\(a[1]\)会如何影响\([2,next[a[1]] - 1]\)呢?
再次想到\(mex\)在“前缀”意义上的单调性,我们只需要把其中大于\(a[1]\)的部分变为\(a[1]\)即可,其他部分的\(mex\)并不会受影响
最后由于递推,不要忘记单点修改。
所以问题就转化成了
- 求出每个数的下一个等于它的数出现的位置
- 求出第一个大于等于\(a[i]\)的位置
- 修改某一段区间的值
这些都可以用线段树实现,当然要注意一些细节,比如\(lazy\)标记应该设置\(-1\),否则\(mx\)会无法下传,以及下一个位置数组应该在最后加上\(n + 1\)点
代码
struct Tree {
int lazy;
int sum;
int mx;
int l, r;
};
int n;
Tree node[maxn << 2];
int a[maxn];
int mex[maxn];
int nxt[maxn];
void push_up(int i) {
node[i].sum = node[i << 1].sum + node[i << 1 | 1].sum;
node[i].mx = max(node[i << 1].mx, node[i << 1 | 1].mx);
}
void build(int i, int l, int r) {
node[i].l = l;
node[i].r = r;
if (l == r) {
node[i].sum = mex[l];
node[i].mx = mex[l];
return;
}
int mid = l + r >> 1;
build(i << 1, l, mid);
build(i << 1 | 1, mid + 1, r);
push_up(i);
}
void push_down(int i, int m) {
if (node[i].lazy >= 0) {
node[i << 1].lazy = node[i].lazy;
node[i << 1 | 1].lazy = node[i].lazy;
node[i << 1].sum = node[i].lazy * (m - (m >> 1));
node[i << 1 | 1].sum = node[i].lazy * (m >> 1);
node[i << 1].mx = node[i << 1 | 1].mx = node[i].lazy;
node[i].lazy = -1;
}
}
void update(int i, int l, int r, int val) {
if (node[i].l > r || node[i].r < l) return;
if (node[i].l >= l && node[i].r <= r) {
//bug;
node[i].lazy = val;
node[i].sum = (node[i].r - node[i].l + 1) * val;
node[i].mx = val;
//cout << node[i].mx << ' ' << i << '\n';
return;
}
push_down(i, node[i].r - node[i].l + 1);
update(i << 1, l, r, val);
update(i << 1 | 1, l, r, val);
push_up(i);
}
int query(int i, int x) {
if (node[i].l == node[i].r) {
return node[i].l;
}
if (node[i << 1].mx > x) return query(i << 1, x);
else return query(i << 1 | 1, x);
}
signed main() {
while (scanf("%lld", &n)) {
if (!n) break;
for (int i = 1; i <= n; i++)
a[i] = readint(), nxt[i] = n + 1;
for (int i = 1; i < 4 * n; i++) {
node[i].l = node[i].r = node[i].sum = node[i].lazy = -1, node[i].mx = 0;
}
unordered_map<int, int> mp;
int cur = 0;
for (int i = 1; i <= n; i++) {
if (mp[a[i]]) nxt[mp[a[i]]] = i, mp[a[i]] = i;
else mp[a[i]] = i;
while (mp[cur]) cur++;
mex[i] = cur;
}
build(1, 1, n);
ll ans = 0;
for (int i = 1; i <= n; i++) {
ans += node[1].sum;
if (node[1].mx > a[i]) {
int l = query(1,a[i]);
int r = nxt[i] - 1;
if (l <= r) update(1, l, r, a[i]);
}
update(1, i, i, 0);
}
cout << ans << '\n';
}
}