Select 问题: 在一个无序的数组中 找到第 n 大的元素。
思路 1: 排序,O(NlgN)
思路 2: 利用快排的 RandomizedPartition(), 平均复杂度是 O(N)
思路 3: 同样是利用快排的 Partition(), 但是选择 pivot 的时候不是采用随机,而是通过一种特殊的方法。从而使复杂度最坏情况下是 O(N)。
本文介绍 STL 算法库中 nth_elemnt 的实现代码。
STL 采用的算法是: 当数组长度 <= 3时, 采用插入排序。
当长度 > 3时, 采用快排 Partition 的思想;
一、使用说明
void
nth_element (RandomAccessIteratorbeg,
RandomAccessIterator
nth,
RandomAccessIterator
end)
void
nth_element (RandomAccessIterator beg,
RandomAccessIterator nth,
RandomAccessIterator end,
BinaryPredicate op)
1. 两个函数都是让 第 n 个位置上的元素就位,
所有在位置 n 之前的元素都小于或等于它,
所有在位置 n 之后的元素都大于或等于它。
2. 复杂度: 平均复杂度是 O(N)
以下例子是使用范例:
// copyright @ L.J.SHOU Feb.23, 2014
#include <iostream>
#include <algorithm>
#include <iterator>
using namespace std; int main(void)
{
int a[]={3,5,2,6,1,4}; nth_element(a, a+3, a+sizeof(a)/sizeof(int));
cout << "The fourth element is: " << a[3] << endl; // output array a[]
copy(a, a+sizeof(a)/sizeof(int),
ostream_iterator<int>(cout, " "));
return 0;
}
程序输出结果:
The fourth element is: 4
2 1 3 4 6 5
二、源码分析
// nth_element() and its auxiliary functions. template <class _RandomAccessIter, class _Tp>
void __nth_element(_RandomAccessIter __first, _RandomAccessIter __nth,
_RandomAccessIter __last, _Tp*) {
while (__last - __first > 3) {
_RandomAccessIter __cut =
__unguarded_partition(__first, __last,
_Tp(__median(*__first,
*(__first + (__last - __first)/2),
*(__last - 1))));
if (__cut <= __nth)
__first = __cut;
else
__last = __cut;
}
__insertion_sort(__first, __last);
} template <class _RandomAccessIter>
inline void nth_element(_RandomAccessIter __first, _RandomAccessIter __nth,
_RandomAccessIter __last) {
__STL_REQUIRES(_RandomAccessIter, _Mutable_RandomAccessIterator);
__STL_REQUIRES(typename iterator_traits<_RandomAccessIter>::value_type,
_LessThanComparable);
__nth_element(__first, __nth, __last, __VALUE_TYPE(__first));
}
template <class _RandomAccessIter, class _Tp>
_RandomAccessIter __unguarded_partition(_RandomAccessIter __first,
_RandomAccessIter __last,
_Tp __pivot)
{
while (true) {
while (*__first < __pivot)
++__first;
--__last;
while (__pivot < *__last)
--__last;
if (!(__first < __last))
return __first;
iter_swap(__first, __last);
++__first;
}
}
_unguarded_partition 就是快排的 partition, 将数组分成两部分,左边的元素都小于或者等于 pivot, 右边的元素都大于或者等于 pivot.
从上述代码可以看出, nth_element 采用的 pivot 是 首元素,尾元素,中间元素,三个数的median.
通过_unguarded_partition 将数组分成两部分,
如果 nth 这个迭代器在左半边,则继续在左半边搜索;
若 nth 在右半边, 则在右半边搜索;
直到数组的长度 <= 3,时, 采用插入排序。这时 nth 迭代器所指向的数就归位了,而且它的左边元素都小于或者等于它, 右边元素都大于或者等于它。