[LeetCode 891.] 子序列宽度之和【hard】

LeetCode 891. 子序列宽度之和

一道hard题目,经历了多轮优化。

题目描述

给定一个整数数组 A ,考虑 A 的所有非空子序列。

对于任意序列 S ,设 S 的宽度是 S 的最大元素和最小元素的差。

返回 A 的所有子序列的宽度之和。

由于答案可能非常大,请返回答案模 10^9+7。

示例:

输入:[2,1,3]
输出:6
解释:
子序列为 [1],[2],[3],[2,1],[2,3],[1,3],[2,1,3] 。
相应的宽度是 0,0,0,1,1,2,2 。
这些宽度之和是 6 。

提示:

  • 1 <= A.length <= 20000
  • 1 <= A[i] <= 20000

解题思路

第一次尝试,DFS 超时。

思路为:枚举所有可能的组合,,并求其组合中最大最小数的差值,自然得出计算结果。也就是使用普通的 dfs 来解,结果自然是超时了。

class Solution {
    constexpr static int64_t MOD = 1e9+7;
    int64_t max_min(vector<int>& sub) {
        if (sub.empty()) return 0;
        int maxVal = sub[0];
        int minVal = sub[0];
        for (int x : sub) {
            maxVal = max(maxVal, x);
            minVal = min(minVal, x);
        }
        return (int64_t)maxVal - minVal;
    }
    void dfs(int64_t& res, vector<int>& nums, vector<int>& sub, int start) {
        if (start >= nums.size()) {
            res += max_min(sub);
            res %= MOD;
            return;
        }
        {
            sub.push_back(nums[start]);
            dfs(res, nums, sub, start + 1);
            sub.pop_back();
        }
        {
            dfs(res, nums, sub, start + 1);
        }
    }
public:
    int sumSubseqWidths(vector<int>& nums) {
        int64_t res = 0;
        vector<int> sub;
        dfs(res, nums, sub, 0);
        return res % MOD;
    }
}; // TLE

第二次尝试,乘方计算,溢出。

既然直接枚举所有组合的办法会超时,那么能不能不给出组合,直接获得各种组合的最值之差呢?可以,只需要枚举最大值和最小值就可以了。
这里我们首先对数组进行排序,然后枚举每种可能的最值组合,每种组合对应的子序列个数就是 [i,j] 区间中去掉 i,j 之后的元素的2的幂次方个。
然而这份代码会报错,原因是整数计算溢出。元素最大有 2w 个,所以组合的个数可以达到 2^(2w),这显然不是普通的 int64 类型能够存下的,自热会溢出。

class Solution {
    constexpr static int64_t MOD = 1e9+7;
public:
    int sumSubseqWidths(vector<int>& nums) {
        int64_t res = 0;
        size_t n = nums.size();
        sort(nums.begin(), nums.end());
        for (int i=0; i<n; i++) {
            for (int j=i+1; j<n; j++) {
                res += ((nums[j] - nums[i]) * (1LL << (j-i-1)));
                res %= MOD;
            }
        }
        return res;
    }
};

第三次尝试,避免溢出的乘方计算,超时。

仔细计算,每次计算指数的时候,中间插入取模操作以避免溢出,结果倒是不溢出了,可是还是超时。

class Solution {
    constexpr static int64_t MOD = 1e9+7;
public:
    int sumSubseqWidths(vector<int>& nums) {
        int64_t res = 0;
        size_t n = nums.size();
        sort(nums.begin(), nums.end());
        for (int i=0; i<n; i++) {
            for (int j=i+1; j<n; j++) {
                // res += ((nums[j] - nums[i]) * (1LL << (j-i-1)));
                // res %= MOD;

                // 避免溢出 // 1 <= n <= 20000, 1 <= nums[i] <= 20000
                int64_t exp2 = 1;
                int k = j-i-1;
                while (k > 30) {
                    exp2 *= (1LL << 30);
                    exp2 %= MOD;
                    k -= 30;
                }
                exp2 *= (1LL << k);
                exp2 %= MOD;
                res += ((nums[j] - nums[i]) * exp2);
                res %= MOD;
                // 重复计算太多次,还是TLE
            }
        }
        return res;
    }
};

第四次尝试,带记忆的、避免溢出的乘方计算,超时。

通过记忆化,查表的方式,来避免重复计算高次幂,能够避免超时吗?还不够,还是会超时。

int64_t exp2s[20003];
bool inited = false;

class Solution {
    constexpr static int64_t MOD = 1e9+7;
public:
    int sumSubseqWidths(vector<int>& nums) {
        if (!inited) {
            memset(exp2s, -1, sizeof(exp2s));
            inited = true;
        }

        int64_t res = 0;
        size_t n = nums.size();
        sort(nums.begin(), nums.end());
        for (int i=0; i<n; i++) {
            for (int j=i+1; j<n; j++) {
                // res += ((nums[j] - nums[i]) * (1LL << (j-i-1)));
                // res %= MOD;

                if (exp2s[j-i-1] < 0) {
                    int64_t exp2 = 1;
                    int k = j-i-1;
                    while (k > 30) {
                        exp2 *= (1LL << 30);
                        exp2 %= MOD;
                        k -= 30;
                    }
                    exp2 *= (1LL << k);
                    exp2 %= MOD;

                    exp2s[j-i-1] = exp2;
                }
                res += ((nums[j] - nums[i]) * exp2s[j-i-1]);
                res %= MOD;
            }
        }
        return res;
    }
};

进一步减少计算,公式变形。

这里我们看到上面的方法中,累加计算的次数是 N^2 次,能不能降低到 N 次?可以!
我们知道这道题的答案 res = sum((nums[j] - nums[i]) * 2 ^ (j-i-1)),如何把这个公式里的 i 和 j 两个变量变成1个呢?
我们换一种视角,不去看每个组合对结果的贡献值,而是去考察每个元素的贡献值。显然贡献有两种,作为最大值的时候是正贡献,作为最小值的时候是负贡献。对于第i个元素,其正贡献的组合个数是 2 ^ (i-1) 个,负贡献的组合个数是 2 ^ (n-i-1) 个。从而一边循环就可以计算出结果。

constexpr static int64_t MOD = 1e9+7;

int64_t exp2s[20003];
bool inited = false;
int64_t exp2(int k) {
    if (!inited) {
        memset(exp2s, -1, sizeof(exp2s));
        inited = true;
    }

    if (exp2s[k] >= 0) {
        return exp2s[k];
    }
    int64_t exp2 = 1;
    int t = k;
    while (t > 20) {
        exp2 *= (1LL << 20);
        exp2 %= MOD;
        t -= 20;
    }
    exp2 *= (1LL << t);
    exp2 %= MOD;

    return exp2s[k] = exp2;
}

class Solution {
public:
    int sumSubseqWidths(vector<int>& nums) {
        int64_t res = 0;
        size_t n = nums.size();
        sort(nums.begin(), nums.end());
        for (int i=0; i<n; i++) {
            res += nums[i] * (exp2(i) - exp2(n-i-1));
            res %= MOD;
        }
        return res;
    }
};

然而这里还是不对,会报错说 Line 38: Char 17: runtime error: -3.16913e+30 is outside the range of representable values of type ‘long long‘ (solution.cpp) 也就是 res += nums[i] * (exp2(i) - exp2(n-i-1)); 这一句溢出了。
奇怪的是 我们的 exp2 返回结果是对 1e9+7 取模的,nums[i] 也是 1到2w之间的一个整数,怎么都不应该得到 -3e30 这个量级的数。
如果我们对上面的代码做一点小改动,就能让这一个case通过,那就是把 size_t 改为 int。
经过排查发现,原来是STL 里有一个 std::exp2() 函数,这个函数支持的参数类型为 double,我们调用的 exp2 被转发到 std::exp2 中了……
给函数改个名字就好了。

然而还是不对,还会core,因为leetcode 不讲武德,说好了数组大小不超过2w,结果有的case元素有3w+,还有一个有10w个元素?!!!

之后仍然 TLE,因为我们计算exp2函数每次只存了最终结果,没有缓存中间结果,所以基本还是每次从头算了一遍exp2 …… (orz 最近写题状态有问题啊)

以下是最终版代码

参考代码

constexpr static int64_t MOD = 1e9+7;

int64_t exp2s[100003];
bool inited = false;
int64_t exp2mod1e9_7(int k) {
    if (!inited) {
        inited = true;
        int64_t exp2 = 1;
        for (int i=0; i<100003; i++) {
            exp2s[i] = exp2;
            exp2 = (exp2 << 1) % MOD;
        }
    }

    if (k < 100003) {
        return exp2s[k];
    }

    int t = 100001;
    int64_t exp2 = exp2s[t];
    for (; t + 30 < k; t+=30) {
        exp2 *= (1LL << 30) % MOD;
        exp2 %= MOD;
    }
    exp2 *= (1LL << (k-t)) % MOD;
    exp2 %= MOD;
    return exp2;
}

class Solution {
public:
    int sumSubseqWidths(vector<int>& nums) {
        int64_t res = 0;
        size_t n = nums.size();
        sort(nums.begin(), nums.end());
        for (int i=0; i<n; i++) {
            res += nums[i] * (exp2mod1e9_7(i) - exp2mod1e9_7(n-i-1));
            res %= MOD;
        }
        return res;
    }
}; // AC

[LeetCode 891.] 子序列宽度之和【hard】

上一篇:JS处理Java的Long类型数据精度丢失问题


下一篇:Mybatis-基本学习(下)