AcWing 1236.递增三元组 二分 前缀和

给定三个整数数组

A=[A1,A2,…AN]
B=[B1,B2,…BN]
C=[C1,C2,…CN]

请你统计有多少个三元组 (i,j,k)满足:

  1. 1≤i,j,k≤N1
  2. Ai<Bj<Ck
输入格式

第一行包含一个整数 N

第二行包含 N个整数 A1,A2,…AN

第三行包含 N个整数 B1,B2,…BN

第四行包含 N个整数 C1,C2,…CN

输出格式

一个整数表示答案。

数据范围

1≤N≤105
0≤Ai,Bi,Ci≤105

输入样例:
3
1 1 1
2 2 2
3 3 3
输出样例:
27
 思路:

首先我们可以看数据量,10^5说明我们最多使用O(nlogn)的算法。首先我们清楚,在最中间的数组是决定性因素。因为我们要统计的是比中间的数组元素大和小的元素数量。对于中间数组的元素,我们要去第一和第三个数组寻找大与和小与的元素数量。一般的查找O(n)会使得整体变成了O(n^2)。因此我们需要使用二分法。

第二种思路,开一个大数组,将所有元素记录次数记录在其中。然后求前缀和,我们就能得到对应元素在该处比它大或小的元素数量,相乘相加即可。

代码:
import java.util.*;
public class Main{
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int[][] arr = new int[3][n];
        for (int i = 0; i < 3; i++) {
            for (int j = 0; j < n; j++) {
                arr[i][j] = sc.nextInt();
            }
            Arrays.sort(arr[i]);
        }
        int[] a = new int[n];
        int[] c = new int[n];
        for (int i = 0; i < n; i++) {
            int tmp = arr[1][i];
            a[i] = div(arr[0],tmp)+1;
        }
        for (int i = 0; i < n; i++) {
            int tmp = arr[1][i];
            c[i] = n - big(arr[2],tmp);
        }
        long ans = 0;
        for (int i = 0; i < n; i++) {
            ans += (long)a[i]*c[i];
        }
        System.out.println(ans);
    }
    public static int div(int[] tar,int val){
        int l = 0;
        int r = tar.length-1;
        int ans = -1;
        while(l <= r){
            int mid = l + (r - l) / 2;
            if(val > tar[mid]){
                l = mid + 1; 
                ans = mid;
            }else{
                r = mid - 1;
            }
        }
        return ans;
    }
    public static int big(int[] tar,int val){
        int l = 0;
        int r = tar.length-1;
        int ans = tar.length;
        while(l <= r){
            int mid = l + (r - l) / 2;
            if(tar[mid] > val){
                r = mid - 1;
                ans = mid;
            }else{
                l = mid + 1;
            }
        }
        return ans;
    }
}
import java.util.*;
public class Main{
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int[] sm = new int[100010];
        int[] big = new int[100010];
        for (int i = 0; i < n; i++) {
            int t = sc.nextInt();
            sm[t]++;
        }
        int[] b= new int[n];
        for (int i = 0; i < n; i++) {
            b[i] = sc.nextInt();
        }
        for (int i = 0; i < n; i++) {
            int t = sc.nextInt();
            big[t]++;
        }
        for (int i = 1; i < 100010; i++) {
            sm[i] += sm[i-1];
        }
        for(int i = 100008; i >= 0;i--){
            big[i] += big[i+1];
        }
        long ans = 0;
        for (int i = 0; i < n; i++) {
            int t = b[i];
            if(t == 0) continue;
            ans += (long)sm[t-1] * big[t + 1];
        }
        System.out.println(ans);
    }
}

上一篇:Radio Silence for mac 好用的防火墙软件


下一篇:JVM堆栈详解