▶ 书中第二章部分程序,加上自己补充的代码,包括各种优化的快排
package package01; import edu.princeton.cs.algs4.In;
import edu.princeton.cs.algs4.StdOut;
import edu.princeton.cs.algs4.StdRandom; public class class01
{
private class01() {} public static void sort1(Comparable[] a) // 基本的快排
{
StdRandom.shuffle(a); // 数组 a 随机化
sortKernel1(a, 0, a.length - 1);
} private static void sortKernel1(Comparable[] a, int lo, int hi) // 快排递归内核
{
if (hi <= lo)
return;
int j = partition1(a, lo, hi);
sortKernel1(a, lo, j - 1); // 用不到 sortKernel1 的返回值
sortKernel1(a, j + 1, hi);
} private static int partition1(Comparable[] a, int lo, int hi) // 快排分组
{
int i = lo, j = hi + 1;
for (Comparable v = a[lo];; exch(a, i, j)) // a[lo] 作为分界值,每次循环前要调整 i 和 j,所以初值要过量
{ // 没有 break 退出时调整两个错位的值
for (i++; i < hi && less(a[i], v); i++); // 寻找大于 a[lo] 的元素
for (j--; j > lo && less(v, a[j]); j--); // 寻找小于 a[lo] 的元素
if (i >= j) // 指针交叉,说明已经排好了
break;
}
exch(a, lo, j); // 安置哨兵
return j; // 返回哨兵的位置
} public static Comparable select(Comparable[] a, int k) // 从数组中选出第 k 小的数字
{
StdRandom.shuffle(a);
int lo = 0, hi = a.length - 1;
for (; lo < hi;)
{
int i = partition1(a, lo, hi); // 用 a[0] 分段,这是函数 partition 返回安置位置的原因
if (i > k) // 左段过长
hi = i - 1;
else if (i < k) // 右段过长
lo = i + 1;
else
return a[i]; // 找到了指定的分位数
}
return a[lo]; // 指针交叉或 k 不合法?
} public static void sort2(Comparable[] a) // 优化版本,三路快排
{
StdRandom.shuffle(a);
sortKernel2(a, 0, a.length - 1);
} private static void sortKernel2(Comparable[] a, int lo, int hi) // 三路路快排的递归内核,没有 partition 程序
{
if (hi <= lo)
return;
int pLess = lo, pEqual = lo + 1, pGreater = hi; // 三个指针分别指向小于部分、等于部分、大于部分,且 pEqual 指向当前检查元素
for (Comparable v = a[lo]; pEqual <= pGreater;)
{
int cmp = a[pEqual].compareTo(v);
if (cmp < 0)
exch(a, pLess++, pEqual++); // 当前元素较小,放到左边
else if (cmp > 0)
exch(a, pEqual, pGreater--); // 当前元素较大,放到右边
else
pEqual++; // 当前元素等于哨兵,仅自增相等部分的指针
}
sortKernel2(a, lo, pLess - 1);
sortKernel2(a, pGreater + 1, hi);
} private static final int INSERTION_SORT_CUTOFF = 8; // 数组规模不大于 8 时使用插入排序 public static void sort3(Comparable[] a) // 优化版本,三采样 - 插入排序
{
sortKernel3(a, 0, a.length - 1); // 没有数组随机化
} private static void sortKernel3(Comparable[] a, int lo, int hi) // 递归内核
{
if (hi <= lo)
return;
if (hi - lo + 1 <= INSERTION_SORT_CUTOFF) // 规模较小,使用插入排序
{
insertionSort(a, lo, hi + 1); //Insertion.sort(a, lo, hi + 1);,教材源码使用 edu.princeton.cs.algs4.Insertion
return;
}
int j = partition3(a, lo, hi);
sortKernel3(a, lo, j - 1);
sortKernel3(a, j + 1, hi);
} private static int partition3(Comparable[] a, int lo, int hi) // 分组
{
exch(a, median3(a, lo, lo + (hi - lo + 1) / 2, hi), lo); // 开头、中间和结尾各找 1 元求中位数作为哨兵 int i = lo, j = hi + 1;
Comparable v = a[lo]; // 注意相比函数 partition1 少了一层循环 for (i++; less(a[i], v); i++)
{
if (i == hi) // 低指针到达右端,所有元素都比哨兵小
{
exch(a, lo, hi); // 哨兵放到最后,返回哨兵位置
return hi;
}
}
for (j--; less(v, a[j]); j--)
{
if (j == lo + 1) // 高指针到达左端,所有元素都比哨兵大
return lo; // 直接返回原哨兵位置
}
for (; i < j;) // 在低指针和高指针之间分组
{
exch(a, i, j);
for (i++; less(a[i], v); i++);
for (j--; less(v, a[j]); j--);
}
exch(a, lo, j);
return j;
} private static final int MEDIAN_OF_3_CUTOFF = 40; // 数组规模较小时不采用三采样 public static void sort4(Comparable[] a) // 优化版本,使用
{
sortKernel4(a, 0, a.length - 1);
} private static void sortKernel4(Comparable[] a, int lo, int hi) // 递归内核
{
int n = hi - lo + 1;
if (n <= INSERTION_SORT_CUTOFF) // 小规模数组用插入排序
{
insertionSort(a, lo, hi + 1);
return;
}
else if (n <= MEDIAN_OF_3_CUTOFF) // 中规模数组用三采样
exch(a, median3(a, lo, lo + n / 2, hi), lo);
else // 大规模数组使用 Tukey ninther 采样,相当于 9 点采样
{
int eps = n / 8, mid = lo + n / 2;
int ninther = median3(a,
median3(a, lo, lo + eps, lo + eps + eps),
median3(a, mid - eps, mid, mid + eps),
median3(a, hi - eps - eps, hi - eps, hi)
);
exch(a, ninther, lo);
} int i = lo, p = lo, j = hi+1, q = hi + 1; // Bentley-McIlroy 三向分组
for (Comparable v = a[lo]; ;)
{
for (i++; i < hi && less(a[i], v); i++);
for (j--; j > lo && less(v, a[j]); j--); if (i == j && eq(a[i], v)) // 指针交叉,单独处理相等的情况
exch(a, ++p, i);
if (i >= j)
break; exch(a, i, j);
if (eq(a[i], v))
exch(a, ++p, i);
if (eq(a[j], v))
exch(a, --q, j);
}
i = j + 1;
for (int k = lo; k <= p; k++)
exch(a, k, j--);
for (int k = hi; k >= q; k--)
exch(a, k, i++);
sortKernel4(a, lo, j);
sortKernel4(a, i, hi);
} private static void insertionSort(Comparable[] a, int lo, int hi) // 公用插入排序
{
for (int i = lo; i < hi; i++)
{
for (int j = i; j > lo && less(a[j], a[j - 1]); j--)
exch(a, j, j - 1);
}
} private static int median3(Comparable[] a, int i, int j, int k) // 计算 3 元素的中位数
{
return less(a[i], a[j]) ?
(less(a[j], a[k]) ? j : less(a[i], a[k]) ? k : i) :
(less(a[k], a[j]) ? j : less(a[k], a[i]) ? k : i);
} private static boolean less(Comparable v, Comparable w)
{
if (v == w) // 相同引用时避免内存访问
return false;
return v.compareTo(w) < 0;
} private static boolean eq(Comparable v, Comparable w)
{
if (v == w) // 相同引用时避免内存访问
return true;
return v.compareTo(w) == 0;
} private static void exch(Object[] a, int i, int j)
{
Object swap = a[i];
a[i] = a[j];
a[j] = swap;
} private static void show(Comparable[] a)
{
for (int i = 0; i < a.length; i++)
StdOut.println(a[i]);
} public static void main(String[] args)
{
In in = new In(args[0]);
String[] a = in.readAllStrings(); class01.sort1(a);
//class01.sort2(a);
//class01.sort3(a);
//class01.sort4(a);
show(a);
/*
for (int i = 0; i < a.length; i++) // 使用函数 select 进行排序,实际上是逐步挑出 a 中排第 1、第 2、第 3 …… 的元素
{
String ith = (String)class01.select(a, i);
StdOut.println(ith);
}
*/
}
}