题意及思路看这篇博客就行了,讲得很详细。
下面是我自己的理解:
如果只有2,没有3的话,做法就很简单了,只需要对数组排个序,然后从小到大枚举最大的那个数。那么它对答案的贡献为(假设这个数排序后的位置是pos)2 ^ (pos - 1) * 2 ^ a[pos]。意思是a[pos]这个数必选,其它比它小的数可选可不选,有2^(pos - 1)种情况。现在相当于变成了一个二维的问题。对于这种问题,我们常见的做法是确定一维,在从前往后扫描某一维时加上另一维对答案的贡献。对于这个题,我们可以按数组b从小到大排序,去计算a的贡献。假设现在扫描到的第pos个位置(二元组(a[i], b[i])已经按数组b排序),我们考虑来计算a[i]对答案的贡献。a对答案的贡献分为2部分,一部分是之前已经出现过的,小于等于a[i]的值,假设一共有x个,那么这部分的贡献为(2 ^ x * 2 ^ a[i]),那么大于a[i]的部分呢?其实和这个式子差不多。对于每个已经出现过,并且大于a[i]的a[j],假设已经出现过的比a[j]小的数有y个,那么贡献为2 ^ (y - 1) * 2 * a[j]。为什么是y - 1? 因为a[i]是必选的。通过观察,我们可以发现,每一个a[j]对答案的贡献,取决当前已经出现过的数中有多少个比它小的数,所以我们可以这样维护:在每次插入一个值时,先询问在这个数之前出现了多少个数(假设有x个),然后插入2 ^ x * 2 ^ a[i],询问[i,n]的区间和,就是这一阶段的答案。之后,要把[i + 1,n]中的数乘2,因为他们的前面都多了一个a[i]。
代码:
#include<bits/stdc++.h> #define ls(x) (x << 1) #define rs(x) ((x << 1) | 1) #define LL long long using namespace std; const int maxn = 100010; const LL mod = 1000000007; struct node{ int x, y, rank; }; bool cmp1(node x, node y) { return x.x == y.x ? x.y < y.y : x.x < y.x; } bool cmp2(node x, node y) { return x.y == y.y ? x.x < y.x : x.y < y.y; } node a[maxn]; struct SegementTree { LL sum, cnt, lz; }; SegementTree tr[maxn * 4]; LL qpow(LL x, LL y) { LL ans = 1; for (; y; y >>= 1) { if(y & 1) ans = (ans * x) % mod; x = (x * x) % mod; } return ans; } void pushup(int x) { tr[x].sum = (tr[ls(x)].sum +tr[rs(x)].sum) % mod; tr[x].cnt = (tr[ls(x)].cnt + tr[rs(x)].cnt) % mod; } void maintain(int x, int y) { tr[x].sum = (tr[x].sum * qpow(2, y)) % mod; tr[x].lz += y; } void pushdown(int x) { if(tr[x].lz) { if(tr[ls(x)].cnt) maintain(ls(x), tr[x].lz); if(tr[rs(x)].cnt) maintain(rs(x), tr[x].lz); tr[x].lz = 0; } } void build(int x, int l, int r) { if(l == r) { tr[x].sum = tr[x].cnt = 0; return; } int mid = (l + r) >> 1; build(ls(x), l, mid); build(rs(x), mid + 1, r); pushup(x); } void update_cnt(int x, int l, int r, int pos, int y, int z) { if(l == r) { tr[x].cnt = 1; tr[x].sum = (qpow(2, y) * qpow(2, z)) % mod; return; } pushdown(x); int mid = (l + r) >> 1; if(pos <= mid) update_cnt(ls(x), l, mid, pos, y, z); else update_cnt(rs(x), mid + 1, r, pos ,y, z); pushup(x); } void update_sum(int x, int l, int r, int ql, int qr) { if(l >= ql && r <= qr) { tr[x].lz++; tr[x].sum = (tr[x].sum * 2) % mod; return; } pushdown(x); int mid = (l + r) >> 1; if(ql <= mid) update_sum(ls(x), l, mid, ql, qr); if(qr > mid) update_sum(rs(x), mid + 1, r, ql, qr); pushup(x); } LL query_cnt(int x, int l, int r, int ql, int qr) { if(l >= ql && r <= qr) { return tr[x].cnt; } int mid = (l + r) >> 1; pushdown(x); LL ans = 0; if(ql <= mid) ans += query_cnt(ls(x), l, mid, ql, qr); if(qr > mid) ans += query_cnt(rs(x), mid + 1, r, ql, qr); return ans; } LL query_sum(int x, int l, int r, int ql, int qr) { if(l >= ql && r <= qr) { return tr[x].sum; } int mid = (l + r) >> 1; LL ans = 0; pushdown(x); if(ql <= mid) ans += query_sum(ls(x), l, mid, ql, qr); if(qr > mid) ans += query_sum(rs(x), mid + 1, r, ql, qr); return ans % mod; } int main() { int n; while(~scanf("%d", &n)) { for (int i = 1; i <= n; i++) { scanf("%d%d", &a[i].x, &a[i].y); } sort(a + 1, a + 1 + n, cmp1); for (int i = 1; i <= n; i++) { a[i].rank = i; } sort(a + 1, a + 1 + n, cmp2); build(1, 1, n); LL ans = 0; for (int i = 1; i <= n; i++) { LL tmp = query_cnt(1, 1, n, 1, a[i].rank); update_cnt(1, 1, n, a[i].rank, tmp, a[i].x); ans = (ans + query_sum(1, 1, n, a[i].rank, n) * qpow(3, a[i].y) % mod) % mod; if(a[i].rank != n) update_sum(1, 1, n, a[i].rank + 1, n); } printf("%lld\n", ans); } }