问题描述
给定一个整数数组,返回range sum 落在给定区间[lower, upper] (包含lower和upper)的个数。range sum S(i, j) 表示数组中第i 个元素到j 个元素之和。
Note:
A naive algorithm of O(n2) is trivial. You MUST do better than that.
Example:
Input: nums = [-2,5,-1], lower = -2, upper = 2,
Output: 3
Explanation: The three ranges are : [0,0], [2,2], [0,2] and their respective sums are: -2, -1, 2.
分析
这个题目比较难,楼主第一次面对这种题型,直接缴械投降。参考了各位大神的解题思路,总结了两种解法。一种是TreeMap思路,另外一种是使用segment tree (or binary index tree)。题目寻找需要range sum 在[lower, upper] 之间的个数,满足条件的case用数学公式表达为:
lower <= sum[i] - sum[j] <= upper, i > j, sum[i] is prefix sum of nums at index of i.
也就是
sum[i] - high <= sum[j] <= sum[i] - lower, i > j, sum[i] is prefix sum of nums at index of i.
(or
lower + sum[j] <= sum[i] <= sum[j] + higher, i > j, sum[i] is prefix sum of nums at index of i.)
那么我们的问题可以转化为求落在[sum[i] - high,sum[i] - lower] 区间sum[j]的个数, i = 0....n, j < i。
无论是TreeMap还是Segment Tree,总体的时间复杂度都为nlogn。
实现
TreeMap
TreeMap 的key 是prefixsum, value 是相对应的个数。主要使用TreeMap的subMap的方法,求得落在区间内[sum[i] - high, sum[i] - lower]的sum[j]的个数。
public int countRangeSum(int[] nums, int lower, int upper) {
if(nums == null || nums.length == 0){
return 0;
}
//key is the sum[i], value is the corresponding count
// sum[i] - sum[j] in [lower, upper], transform to find how many sum[j] 在区间[sum[i] - high, sum[i] - lower]。
TreeMap<Long, Integer> map = new TreeMap();
long sum = 0;
int cnt = 0;
for(int i = 0; i < nums.length; i++){
sum += nums[i];
//sum[0, i]满足case
if(sum >= lower && sum <= upper){
cnt++;
}
//find sum[j] 的个数that lies in [sum[i] - high, sum[i] - lower]之间
cnt += map.subMap(sum - upper, true, sum - lower, true).values().stream().mapToInt(Integer::valueOf).sum();
map.put(sum, map.getOrDefault(sum, 0) + 1);
}
return cnt;
}
Segment Tree
Segment Tree每个节点保存区间段的范围和落在这个区间内prefix sum的个数。
class Node {
Node left;
Node right;
//落在区间内的个数
int count;
long min;
long max;
public Node(long min, long max) {
this.min = min;
this.max = max;
}
}
//构建segement tree
private Node buildTree(Long[] valArr, int low, int high) {
if(low > high) return null;
Node root = new Node(valArr[low], valArr[high]);
if(low == high) return root;
int mid = low + (high - low)/2;
root.left = buildTree(valArr, low, mid);
root.right = buildTree(valArr, mid+1, high);
return root;
}
private void update(Node root, Long val) {
if(root == null) return;
if(val >= root.min && val <= root.max) {
root.count++;
update(root.left, val);
update(root.right, val);
}
}
private int query(Node root, long min, long max) {
if(root == null) return 0;
if(min > root.max || max < root.min) return 0;
if(min <= root.min && max >= root.max) return root.count;
return query(root.left, min, max) + query(root.right, min, max);
}
public int countRangeSum(int[] nums, int lower, int upper) {
if(nums == null || nums.length == 0) return 0;
int ans = 0;
Set<Long> valSet = new HashSet<Long>();
long sum = 0;
for(int i = 0; i < nums.length; i++) {
sum += (long) nums[i];
valSet.add(sum);
}
Long[] valArr = valSet.toArray(new Long[0]);
Arrays.sort(valArr);
Node root = buildTree(valArr, 0, valArr.length-1);
sum = nums[0];
ans += (sum >= lower && sum <= upper) ? 1:0;
for(int i = 1; i < nums.length; i++) {
//sum[i]
update(root, sum);
//sum[j]
sum += (long) nums[i];
ans += (sum >= lower && sum <= upper) ? 1:0;
ans += query(root, (long)sum - upper, (long)sum - lower);
}
return ans;
}