题意是不同颜色区间首尾相接,询问一个区间内同色区间的最长长度。
网上流行的做法,包括翻出来之前POJ的代码也是RMQ做法,对于序列上的每个数,记录该数向左和向右延续的最远位置,那么对于一个查询Q(L, R),它的答案就分成了三种情况right(L) - L,R - left(R)以及Q(L+right(L),R-left(R))。
这里给出一个线段树做法,在线段树的节点上维护3个量:l_value, r_value, value分别表示以左端点为起始点,以右端点为起始点以及该区间内的最大的连续长度,更新时通过两个子区间相接的地方是否相同分不同的情况进行讨论。
#include <cstdio>
#include <algorithm>
using namespace std; const int MAXN = ; class SegNode {
public:
int L, R;
int l_value, r_value, value;
int is_same;
} node[ * MAXN]; int num[MAXN]; class SegTree {
public:
void log(int idx) {
printf("%d: ", idx);
for (int i = node[idx].L; i <= node[idx].R; i++)
printf("%d ", num[i]);
printf("(%d %d %d %d)\n", node[idx].l_value, node[idx].r_value, node[idx].value, node[idx].is_same);
}
void build(int root, int L, int R) { node[root].L = L;
node[root].R = R; if (L == R) {
// leaf
node[root].l_value = ;
node[root].r_value = ;
node[root].value = ;
node[root].is_same = ;
} else {
// non leaf
int M = (L + R) / ;
if (L <= M) {
build( * root, L, M);
}
if (M + <= R) {
build( * root + , M + , R);
}
if (num[node[ * root].R] == num[node[ * root + ].L]) {
node[root].l_value = node[ * root].l_value + node[ * root].is_same * node[ * root + ].l_value;
node[root].r_value = node[ * root + ].r_value + node[ * root + ].is_same * node[ * root].r_value;
node[root].value = max(max(node[ * root].value, node[ * root + ].value), node[ * root].r_value + node[ * root + ].l_value);
node[root].is_same = node[ * root].is_same & node[ * root + ].is_same;
} else {
node[root].l_value = node[ * root].l_value;
node[root].r_value = node[ * root + ].r_value;
node[root].value = max(node[ * root].value, node[ * root + ].value);
node[root].is_same = ;
}
//log(root);
}
}
int query(int root, int L, int R, int k) {
if (L <= node[root].L && R >= node[root].R) {
if (k == ) return node[root].value;
else if (k == ) return node[root].l_value;
else return node[root].r_value;
} if (L > node[root].R || R < node[root].L) {
return ;
} int M = (node[root].L + node[root].R) / ;
if (R <= M) {
return query( * root, L, R, k);
} else if (L > M) {
return query( * root + , L, R, k);
} else {
if (num[node[ * root].R] == num[node[ * root + ].L]) {
if (k == ) {
int res = ;
res = max(query( * root, L, R, ), query( * root + , L, R, ));
res = max(res, query( * root, L, R, ) + query( * root + , L, R, ));
return res;
} else if (k == ) {
int res = query( * root, L, R, );
if (node[ * root].is_same) res += query( * root + , L, R, );
return res;
} else {
int res = query( * root + , L, R, );
if (node[ * root + ].is_same) res += query( * root, L, R, );
return res;
}
} else {
if (k == ) {
return max(query( * root, L, R, ), query( * root + , L, R, ));
} else if (k == ) {
return query( * root, L, R, );
} else {
return query( * root + , L, R, );
}
}
}
}
} tree; int main() {
int n, q;
while (scanf("%d%d", &n, &q) && n) {
for (int i = ; i <= n; i++)
scanf("%d", &num[i]);
tree.build(, , n);
while (q--) {
int l, r;
scanf("%d%d", &l, &r);
printf("%d\n", tree.query(, l, r, ));
}
}
}
由这题想到了最大连续子段和这个问题,常见的解法是动态规划解法,算法课上讲了一个分治的解法,将整段分成左右两半,然后在中间点处向左和向右遍历,寻找最长的连续段,算法复杂度分析T(n)=2T(n/2)+O(n),因此复杂度是O(nlgn)。然而可以使用上面的思路进行维护,维护一个以左端点为起点的最长连续子段,该子段记作left(root),讲root分成L,R,那么left(root)=max{left(L), sum(L)+left(R)},这样查询中点mid的最优值就可以用right(L)+left(R)来替代了,复杂度为O(1),最后也就可以做到复杂度为O(nlgn)的最大子段和的分治算法了。
题意是各个序列,统计这样的三元组(a,b,c),满足条件idx(a)<idx(b)<idx(c),且a<b<c或者a>b>c的数量。
做法是用树状数组统计出给定一个索引i,i左侧比a[i]小的数量,以及右侧比a[i]小的数量,用左侧比a[i]小的数量乘上右侧比a[i]大的数量,以及左侧比a[i]大的数量乘上右侧比a[i]小的数量。
#include <cstdio>
#include <cstring>
using namespace std; const int MAXA = ;
const int MAXN = ; int c[MAXA];
int a[MAXN];
int left[MAXN], right[MAXN]; int lowbit(int x) {
return x & (-x);
} void insert(int i, int x) {
while (i < MAXA) {
c[i] += x;
i += lowbit(i);
}
} int query(int i) {
int res = ;
while (i > ) {
res += c[i];
i -= lowbit(i);
}
return res;
} int main() {
int T;
scanf("%d", &T);
while (T--) {
int n;
scanf("%d", &n);
for (int i = ; i < n; i++)
scanf("%d", &a[i]);
memset(c, , sizeof(c));
for (int i = ; i < n; i++) {
left[i] = query(a[i]);
insert(a[i], );
}
memset(c, , sizeof(c));
for (int i = n - ; i >= ; i--) {
right[i] = query(a[i]);
insert(a[i], );
}
long long ans = ;
for (int i = ; i < n; i++) {
ans += (long long)left[i] * (n - - i - right[i]);
ans += (i - left[i]) * (long long)right[i];
}
printf("%lld\n", ans);
}
}