题目链接:http://codeforces.com/problemset/problem/1029/B
题目大意:从数组a中选出一些数组成数组b,要求 b[i+1]<=b[i]*2
。
一开始想到的是O(n^2)的动态规划,但是超时了,下面是超时的代码。
#include <iostream>
using namespace std;
const int maxn = 200020;
int n, a[maxn], f[maxn], res = 0;
int main() {
cin >> n;
for (int i = 0; i < n; i ++) cin >> a[i];
for (int i = 0; i < n; i ++) {
f[i] = 1;
for (int j = i-1; j >= 0 && a[i] <= 2*a[j]; j --) {
f[i] = max(f[i], f[j]+1);
}
if (f[i] > res) res = f[i];
}
cout << res << endl;
return 0;
}
然后想到的是:
- 二分找最小的满足a[i]<=a[j]*2的那个j ,O(logn)的时间复杂度
- 线段树求区间[j,i-1]的最大值,O(logn)的时间复杂度
再加上外层的循环,时间复杂度会降到O(n * logn)。
代码:
#include <iostream>
using namespace std;
#define lson l, m, rt << 1
#define rson m+1, r, rt << 1 | 1
const int maxn = 200020;
int n, a[maxn], MAX[maxn<<2];
void pushUp(int rt) {
MAX[rt] = max(MAX[rt<<1] , MAX[rt<<1|1]);
}
void build(int l, int r, int rt) {
if (l == r) {
MAX[rt] = 0;
return;
}
int m = (l+r) >> 1;
build(lson);
build(rson);
pushUp(rt);
}
void update(int p, int val, int l, int r, int rt) {
if (l == r) {
MAX[rt] = val;
return;
}
int m = (l + r) >> 1;
if (p <= m) update(p, val, lson);
else update(p, val, rson);
pushUp(rt);
}
int query(int L, int R, int l, int r, int rt) {
if (L <= l && r <= R) {
return MAX[rt];
}
int m = (l + r) >> 1;
int tmp = 0;
if (L <= m) tmp = query(L, R, lson);
if (R >= m+1) tmp = max(tmp, query(L, R, rson));
return tmp;
}
int findL(int i) {
return lower_bound(a, a+n+1, (a[i]+1)/2) - a;
}
int main() {
cin >> n;
build(1, n, 1);
for (int i = 1; i <= n; i ++) cin >> a[i];
for (int i = 1; i <= n; i ++) {
int L = findL(i);
int val;
if (L >= i) val = 1;
else {
val = query(L, i-1, 1, n, 1) + 1;
}
update(i, val, 1, n, 1);
}
cout << query(1, n, 1, n, 1) << endl;
return 0;
}
然后是O(n)的单调队列解法。
代码:
#include <iostream>
using namespace std;
const int maxn = 200020;
int n, a[maxn], MAX[maxn];
int st = 0, ed = 0, sum = 0;
int que[maxn]; // que存放id,id对应的最大长度是MAX[id]
void test() {
cout << "[test]" << endl;
for (int i = 1; i <= n; i ++) {
cout << i << ": " << MAX[i] << endl;
}
}
int main() {
cin >> n;
for (int i = 1; i <= n; i ++) cin >> a[i];
int j = 1;
for (int i = 1; i <= n; i ++) {
while (j < i && !( a[i] <= 2 * a[j] )) {
j ++;
}
while (st < ed && que[st] < j) st ++;
if (st == ed) {
MAX[i] = 1;
} else {
MAX[i] = MAX[ que[st] ] + 1;
}
sum = max(sum , MAX[i]);
while (st < ed && MAX[ que[st] ] <= MAX[i]) {
st ++;
}
que[ed++] = i;
}
// test();
cout << sum << endl;
return 0;
}