这题有人说部分分O(n3)暴力,然而我暴力都没写过,调了半天也没用……还是看题解吧
首先,咱把A * ( h – minH ) + B * ( s – minS ) <= C 变个型,得到 A * h + B * s - C <= A * minH + B * minS. 令 sum = A * h + B * s - C,如果我们把所有球员按sum排序,就能保证取球员的时候是单调的,如果 i 能取,则 j (j < i) 一定能取。
然后我们第一层循环枚举minS,第二层循环枚举minH,然后设两个指针L, R,表示当前符合sum <= A * minH + B * minS的区间,但同时我们还要保证h >= minH && s >= minS,而且根据h >= minH还要保证,s <= minS + C / B.于是就有一个很【强】的做法,我们先把符合的条件的s添加进去,再把队列中不符合条件的h踢出。
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<cstdlib>
#include<vector>
#include<queue>
#include<stack>
#include<cctype>
using namespace std;
#define enter puts("")
#define space putchar(' ')
#define Mem(a) memset(a, 0, sizeof(a))
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-;
const int maxn = 5e3 + ;
inline ll read()
{
ll ans = ;
char ch = getchar(), last = ' ';
while(!isdigit(ch)) last = ch, ch = getchar();
while(isdigit(ch)) ans = (ans << ) + (ans << ) + ch - '', ch = getchar();
if(last == '-') ans = -ans;
return ans;
}
inline void write(ll x)
{
if(x < ) putchar('-'), x = -x;
if(x >= ) write(x / );
putchar(x % + '');
} int n;
ll A, B, C;
struct Node
{
int h, s; ll sum;
}Sum[maxn], H[maxn], S[maxn];
bool cmp1(Node a, Node b) {return a.sum < b.sum;}
bool cmp2(Node a, Node b) {return a.h < b.h;}
bool cmp3(Node a, Node b) {return a.s < b.s;} int ans = ; int main()
{
n = read(); A = read(); B = read(); C = read();
for(int i = ; i <= n; ++i)
{
Sum[i].h = read(), Sum[i].s = read(); Sum[i].sum = (ll)A * Sum[i].h + (ll)B * Sum[i].s - C;
H[i].h = Sum[i].h; H[i].s = Sum[i].s; H[i].sum = Sum[i].sum;
S[i].h = Sum[i].h; S[i].s = Sum[i].s; S[i].sum = Sum[i].sum;
}
sort(Sum + , Sum + n + , cmp1);
sort(H + , H + n + , cmp2);
sort(S + , S + n + , cmp3);
for(int i = ; i <= n; ++i) //枚举minS
{
int L = , R = , cnt = ;
ll Mins = S[i].s;
ll Lims = Mins + C / B;
for(int j = ; j <= n; ++j) //枚举minH
{
ll Limsum = A * H[j].h + B * Mins;
while(R < n && Sum[R + ].sum <= Limsum) //合法区间,但不能保证s,h符合
{
R++;
if(Mins <= Sum[R].s && Sum[R].s <= Lims) cnt++; //符合条件的s
}
while(L < n && H[L + ].h < H[j].h) //维护合法区间左端点
{
L++;
if(Mins <= H[L].s && H[L].s <= Lims) cnt--; //踢出不符合的h
//因为[L, R]中可能有cnt之外的,所以要判断哪些属于cnt,踢出时再cnt--
}
ans = max(ans, cnt);
}
}
write(ans); enter;
return ;
}