题目大意
现在给出一个栈,然后给出三种操作
- 将数x入栈
- 将栈顶出栈
- 求栈的中位数
输入
每组包含一个测试用例
第一行是一个正整数N≤105,接下来有N行命令,格式如下
Push key
Pop
PeekMedian
Push key
将key
入栈,Pop
将栈顶元素出栈,PeekMedian
求栈中所有元素的中位数。key
是一个正整数且不大于105
输出
对每个Pop
指令输出出栈的元素,对每个PeekMedian
指令输出栈中所有元素的中位数。如果栈中没有元素可以出栈或输出就输出Invalid
样例输入
17
Pop
PeekMedian
Push 3
PeekMedian
Push 2
PeekMedian
Push 1
PeekMedian
Pop
Pop
Push 5
Push 4
PeekMedian
Pop
Pop
Pop
Pop
样例输出
Invalid
Invalid
3
2
2
1
2
4
4
5
3
Invalid
解析
这题可以用树状数组来解。
只是树状数组中下标为i的存放的是i之前的数有多少个,这样用树状数组求前缀和再和栈的大小进行比较就可以找到栈的中位数
可是这题用python会超时,c++可以AC
超时的python:
# -*- coding: utf-8 -*-
# @Time : 2019/6/13 11:10
# @Author : ValarMorghulis
# @File : 1057.py
s = list()
c = [0 for i in range(100861)]
def lowBit(x):
return x & (-x)
def update(x, v):
i = x
while i < 100861:
c[i] += v
i += lowBit(i)
def getSum(x):
sum, i = 0, x
while i >= 1:
sum += c[i]
i -= lowBit(i)
return sum
def find():
left, right, k = 1, 100861, (len(s) + 1) // 2
while left < right:
mid = (left + right) // 2
if getSum(mid) >= k:
right = mid
else:
left = mid + 1
print(left)
def solve():
global c, s
n = int(input())
for i in range(n):
command = input().split()
if command[0][1] == 'u':
s.append(int(command[1]))
update(int(command[1]), 1)
elif command[0][1] == 'o':
if len(s) == 0:
print("Invalid")
else:
update(s[-1], -1)
print(s[-1])
s.pop()
elif command[0][1] == 'e':
if len(s) == 0:
print("Invalid")
else:
find()
if __name__ == "__main__":
solve()
AC的c++
#include<stdio.h>
#include<stdlib.h>
#include<iostream>
#include<algorithm>
#include<vector>
#include<queue>
#include<stack>
#include<cstring>
#include<string>
#include<cmath>
#define inf 0xffffffff
#define lowbit(i) ((i)&(-(i)))
using namespace std;
const int maxn=100861;
stack<int> s;
int c[maxn];
void update(int x, int v)
{
for(int i=x; i<maxn; i+=lowbit(i))
c[i]+=v;
}
int sum(int x)
{
int t=0;
for(int i=x; i>=1; i-=lowbit(i))
t+=c[i];
return t;
}
void find()
{
int left=0;
int right=maxn;
int mid;
int k=(s.size()+1)>>1;
while(left<right)
{
mid=(left+right)>>1;
if(sum(mid)>=k)
right=mid;
else
left=mid+1;
}
printf("%d\n", left);
}
int main()
{
int n;
scanf("%d", &n);
char command[15];
for(int i=0; i<n; i++)
{
scanf("%s", command);
if(command[1]=='u')
{
int t;
scanf("%d", &t);
update(t, 1);
s.push(t);
}
else
if(command[1]=='o')
if(s.size()==0)
printf("Invalid\n");
else
{
update(s.top(), -1);
printf("%d\n", s.top());
s.pop();
}
else
if(command[1]=='e')
if(s.size()==0)
printf("Invalid\n");
else
find();
}
return 0;
}