LeetCode 891. 子序列宽度之和
一道hard题目,经历了多轮优化。
题目描述
给定一个整数数组 A ,考虑 A 的所有非空子序列。
对于任意序列 S ,设 S 的宽度是 S 的最大元素和最小元素的差。
返回 A 的所有子序列的宽度之和。
由于答案可能非常大,请返回答案模 10^9+7。
示例:
输入:[2,1,3]
输出:6
解释:
子序列为 [1],[2],[3],[2,1],[2,3],[1,3],[2,1,3] 。
相应的宽度是 0,0,0,1,1,2,2 。
这些宽度之和是 6 。
提示:
- 1 <= A.length <= 20000
- 1 <= A[i] <= 20000
解题思路
第一次尝试,DFS 超时。
思路为:枚举所有可能的组合,,并求其组合中最大最小数的差值,自然得出计算结果。也就是使用普通的 dfs 来解,结果自然是超时了。
class Solution {
constexpr static int64_t MOD = 1e9+7;
int64_t max_min(vector<int>& sub) {
if (sub.empty()) return 0;
int maxVal = sub[0];
int minVal = sub[0];
for (int x : sub) {
maxVal = max(maxVal, x);
minVal = min(minVal, x);
}
return (int64_t)maxVal - minVal;
}
void dfs(int64_t& res, vector<int>& nums, vector<int>& sub, int start) {
if (start >= nums.size()) {
res += max_min(sub);
res %= MOD;
return;
}
{
sub.push_back(nums[start]);
dfs(res, nums, sub, start + 1);
sub.pop_back();
}
{
dfs(res, nums, sub, start + 1);
}
}
public:
int sumSubseqWidths(vector<int>& nums) {
int64_t res = 0;
vector<int> sub;
dfs(res, nums, sub, 0);
return res % MOD;
}
}; // TLE
第二次尝试,乘方计算,溢出。
既然直接枚举所有组合的办法会超时,那么能不能不给出组合,直接获得各种组合的最值之差呢?可以,只需要枚举最大值和最小值就可以了。
这里我们首先对数组进行排序,然后枚举每种可能的最值组合,每种组合对应的子序列个数就是 [i,j] 区间中去掉 i,j 之后的元素的2的幂次方个。
然而这份代码会报错,原因是整数计算溢出。元素最大有 2w 个,所以组合的个数可以达到 2^(2w),这显然不是普通的 int64 类型能够存下的,自热会溢出。
class Solution {
constexpr static int64_t MOD = 1e9+7;
public:
int sumSubseqWidths(vector<int>& nums) {
int64_t res = 0;
size_t n = nums.size();
sort(nums.begin(), nums.end());
for (int i=0; i<n; i++) {
for (int j=i+1; j<n; j++) {
res += ((nums[j] - nums[i]) * (1LL << (j-i-1)));
res %= MOD;
}
}
return res;
}
};
第三次尝试,避免溢出的乘方计算,超时。
仔细计算,每次计算指数的时候,中间插入取模操作以避免溢出,结果倒是不溢出了,可是还是超时。
class Solution {
constexpr static int64_t MOD = 1e9+7;
public:
int sumSubseqWidths(vector<int>& nums) {
int64_t res = 0;
size_t n = nums.size();
sort(nums.begin(), nums.end());
for (int i=0; i<n; i++) {
for (int j=i+1; j<n; j++) {
// res += ((nums[j] - nums[i]) * (1LL << (j-i-1)));
// res %= MOD;
// 避免溢出 // 1 <= n <= 20000, 1 <= nums[i] <= 20000
int64_t exp2 = 1;
int k = j-i-1;
while (k > 30) {
exp2 *= (1LL << 30);
exp2 %= MOD;
k -= 30;
}
exp2 *= (1LL << k);
exp2 %= MOD;
res += ((nums[j] - nums[i]) * exp2);
res %= MOD;
// 重复计算太多次,还是TLE
}
}
return res;
}
};
第四次尝试,带记忆的、避免溢出的乘方计算,超时。
通过记忆化,查表的方式,来避免重复计算高次幂,能够避免超时吗?还不够,还是会超时。
int64_t exp2s[20003];
bool inited = false;
class Solution {
constexpr static int64_t MOD = 1e9+7;
public:
int sumSubseqWidths(vector<int>& nums) {
if (!inited) {
memset(exp2s, -1, sizeof(exp2s));
inited = true;
}
int64_t res = 0;
size_t n = nums.size();
sort(nums.begin(), nums.end());
for (int i=0; i<n; i++) {
for (int j=i+1; j<n; j++) {
// res += ((nums[j] - nums[i]) * (1LL << (j-i-1)));
// res %= MOD;
if (exp2s[j-i-1] < 0) {
int64_t exp2 = 1;
int k = j-i-1;
while (k > 30) {
exp2 *= (1LL << 30);
exp2 %= MOD;
k -= 30;
}
exp2 *= (1LL << k);
exp2 %= MOD;
exp2s[j-i-1] = exp2;
}
res += ((nums[j] - nums[i]) * exp2s[j-i-1]);
res %= MOD;
}
}
return res;
}
};
进一步减少计算,公式变形。
这里我们看到上面的方法中,累加计算的次数是 N^2 次,能不能降低到 N 次?可以!
我们知道这道题的答案 res = sum((nums[j] - nums[i]) * 2 ^ (j-i-1))
,如何把这个公式里的 i 和 j 两个变量变成1个呢?
我们换一种视角,不去看每个组合对结果的贡献值,而是去考察每个元素的贡献值。显然贡献有两种,作为最大值的时候是正贡献,作为最小值的时候是负贡献。对于第i个元素,其正贡献的组合个数是 2 ^ (i-1)
个,负贡献的组合个数是 2 ^ (n-i-1)
个。从而一边循环就可以计算出结果。
constexpr static int64_t MOD = 1e9+7;
int64_t exp2s[20003];
bool inited = false;
int64_t exp2(int k) {
if (!inited) {
memset(exp2s, -1, sizeof(exp2s));
inited = true;
}
if (exp2s[k] >= 0) {
return exp2s[k];
}
int64_t exp2 = 1;
int t = k;
while (t > 20) {
exp2 *= (1LL << 20);
exp2 %= MOD;
t -= 20;
}
exp2 *= (1LL << t);
exp2 %= MOD;
return exp2s[k] = exp2;
}
class Solution {
public:
int sumSubseqWidths(vector<int>& nums) {
int64_t res = 0;
size_t n = nums.size();
sort(nums.begin(), nums.end());
for (int i=0; i<n; i++) {
res += nums[i] * (exp2(i) - exp2(n-i-1));
res %= MOD;
}
return res;
}
};
然而这里还是不对,会报错说 Line 38: Char 17: runtime error: -3.16913e+30 is outside the range of representable values of type ‘long long‘ (solution.cpp)
也就是 res += nums[i] * (exp2(i) - exp2(n-i-1));
这一句溢出了。
奇怪的是 我们的 exp2 返回结果是对 1e9+7 取模的,nums[i] 也是 1到2w之间的一个整数,怎么都不应该得到 -3e30 这个量级的数。
如果我们对上面的代码做一点小改动,就能让这一个case通过,那就是把 size_t 改为 int。
经过排查发现,原来是STL 里有一个 std::exp2() 函数,这个函数支持的参数类型为 double,我们调用的 exp2 被转发到 std::exp2 中了……
给函数改个名字就好了。
然而还是不对,还会core,因为leetcode 不讲武德,说好了数组大小不超过2w,结果有的case元素有3w+,还有一个有10w个元素?!!!
之后仍然 TLE,因为我们计算exp2函数每次只存了最终结果,没有缓存中间结果,所以基本还是每次从头算了一遍exp2 …… (orz 最近写题状态有问题啊)
以下是最终版代码
参考代码
constexpr static int64_t MOD = 1e9+7;
int64_t exp2s[100003];
bool inited = false;
int64_t exp2mod1e9_7(int k) {
if (!inited) {
inited = true;
int64_t exp2 = 1;
for (int i=0; i<100003; i++) {
exp2s[i] = exp2;
exp2 = (exp2 << 1) % MOD;
}
}
if (k < 100003) {
return exp2s[k];
}
int t = 100001;
int64_t exp2 = exp2s[t];
for (; t + 30 < k; t+=30) {
exp2 *= (1LL << 30) % MOD;
exp2 %= MOD;
}
exp2 *= (1LL << (k-t)) % MOD;
exp2 %= MOD;
return exp2;
}
class Solution {
public:
int sumSubseqWidths(vector<int>& nums) {
int64_t res = 0;
size_t n = nums.size();
sort(nums.begin(), nums.end());
for (int i=0; i<n; i++) {
res += nums[i] * (exp2mod1e9_7(i) - exp2mod1e9_7(n-i-1));
res %= MOD;
}
return res;
}
}; // AC