Description
给定一个正整数序列 \(a_1,a_2,\cdots,a_n\),求
\[\sum_{i=1}^n\sum_{j=i}^n(j-i+1)\min(a_i,a_{i+1},\cdots,a_j)\max(a_i,a_{i+1},\cdots,a_j)
\]
Input
第 \(1\) 行,一个整数 \(n\);
第 \(2\dots n+1\) 行,每行一个整数表示序列 \(a\)。
Output
输出答案对 \(10^9\) 取模后的结果。
Sample Input
4
2
4
1
4
Sample Output
109
HINT
\(n \le 500000,1 \le a_i \le 10^8\)
Solution
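CDQ分治:对区间 \([l,r]\),取 \(mid=\lfloor (l+r)/2\rfloor\),完全落在 \([l,mid]\) 或 \([mid+1,r]\) 内的子区间递归处理,这里只统计跨过 \(mid\) 的子区间 \([i,j]\)(\(l\le i\le mid<j\le r\))。记 \(\min[x,y]\)、\(\max[x,y]\) 为 \(a_x,\cdots,a_y\) 的最小值、最大值。从右往左枚举 \([l,mid]\) 的每个点 \(i\),设 \([i,mid]\) 的最小值为 \(mn\),最大值为 \(mx\),同时在 \([mid+1,r]\) 中维护两个指针 \(a,b\),分别取满足 \(\min[mid+1,a]\ge mn\)、\(\max[mid+1,b]\le mx\) 的最大位置(\(a=mid\) 或 \(b=mid\) 表示对应的段为空)。由于 \(i\) 左移时 \(mn\) 不增、\(mx\) 不减,两个指针只会向右移动。假设 \(a<b\)(\(a>b\) 的情况对称),那么 \([mid+1,r]\) 就被分成了三块,我们分别考虑 \(j\) 在每个块内的答案: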
- 若 \(mid<j\le a\),
\[\begin{eqnarray}
ans&=&mn\cdot mx\sum_{j=mid+1}^{a}(j-i+1)\\
&=&mn\cdot mx\cdot\frac{(a+mid-2i+3)(a-mid)}{2}
\end{eqnarray}
\]
- 若 \(a< j \le b\),
\[\begin{eqnarray}
ans&=&mx\cdot\sum_{j=a+1}^b(j-i+1)\min[a+1,j]\\
&=&mx\left(\sum_{j=a+1}^bj\cdot\min[a+1,j]-(i-1)\sum_{j=a+1}^b\min[a+1,j]\right)\\
&=&mx\left(\sum_{j=a+1}^bj\cdot\min[mid+1,j]-(i-1)\sum_{j=a+1}^b\min[mid+1,j]\right)
\end{eqnarray}
\]
(最后一步成立是因为 \(a_{a+1}<mn\le\min[mid+1,a]\),所以 \(j>a\) 时 \(\min[mid+1,j]=\min[a+1,j]\)。)
- 若 \(b< j\le r\),
\[\begin{eqnarray}
ans&=&\sum_{j=b+1}^r(j-i+1)\min[i,j]\max[i,j]\\
&=&\sum_{j=b+1}^r(j-i+1)\min[mid+1,j]\max[mid+1,j]\\
&=&\sum_{j=b+1}^rj\cdot\min[mid+1,j]\max[mid+1,j]-(i-1)\sum_{j=b+1}^r\min[mid+1,j]\max[mid+1,j]
\end{eqnarray}
\]
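其中 \(\sum_{j=a+1}^bj\cdot\min[mid+1,j]\),\(\sum_{j=a+1}^b\min[mid+1,j]\),\(\sum_{j=b+1}^rj\cdot\min[mid+1,j]\max[mid+1,j]\),\(\sum_{j=b+1}^r\min[mid+1,j]\max[mid+1,j]\) 都可以在 \([mid+1,r]\) 上预处理前缀和(\(\min[mid+1,j]\)、\(\max[mid+1,j]\) 即从 \(mid+1\) 开始的前缀最小值、最大值),再做差得到。对称地,\(a>b\) 时中间一块需要 \(\sum j\cdot\max[mid+1,j]\) 与 \(\sum\max[mid+1,j]\) 的前缀和。由于每层分治中指针只向右移动,每层的工作量是线性的,总复杂度为 \(O(n\log n)\)。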
Code
#include <cstdio>
#include <algorithm>
using std::max; using std::min;
// Array capacity (1-based indices up to n <= 500000, plus slack) and the answer modulus.
const int N = 500005;
const int mod = 1000000000;

// The input sequence a[1..n] and its length.
int a[N];
int n;

// Prefix-sum scratch tables filled over the right half of the current solve() range.
int s1[N], s2[N], s3[N], s4[N], s5[N], s6[N];

// Running answer modulo `mod`; intermediate values may be negative.
int ans;
// Reads the next non-negative decimal integer from stdin.
// Any non-digit characters before it are skipped (so a leading '-' is ignored).
// Returns 0 if end of input is reached before any digit is seen.
// NOTE(review): there is no overflow check; values are assumed to fit in int.
// Per the statement, n <= 5e5 and a_i <= 1e8.
int read() {
    int x = 0;
    // Keep the result of getchar() in an int, not a char, so that EOF can be told
    // apart from a real character. Otherwise the skip loop below never terminates
    // on truncated or empty input.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) c = getchar();
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x;
}
// Adds to the global `ans` the contribution of every subarray [i, t] with
// l <= i <= t <= r:
//     (t - i + 1) * min(a[i..t]) * max(a[i..t])     (mod `mod`).
// Divide and conquer: subarrays entirely inside [l, mid] or [mid+1, r] are
// handled by the recursive calls. Subarrays crossing mid are counted here with
// two linear sweeps.
// Requires l <= r (solve(1, 0) would recurse forever).
// Leaves `ans` in (-mod, mod); the caller is responsible for normalizing its sign.
void solve(int l, int r) {
    if (l == r) {
        // A single element: length 1 and min = max = a[l].
        ans = (ans + 1LL * a[l] * a[l]) % mod;
        return;
    }
    int mid = (l + r) >> 1, mn = mod, mx = 0, x;
    solve(l, mid), solve(mid + 1, r);

    // Prefix tables over the right half. mn / mx are the running min / max of a[mid+1..t]:
    //   s1 = sum t*mn,   s2 = sum mn,   s3 = sum t*mx,   s4 = sum mx,
    //   s5 = sum t*mn*mx,              s6 = sum mn*mx     (all mod `mod`).
    // s*[mid] is not written here and holds whatever an earlier call (or zero
    // initialization) left. It is only ever used as the base of a difference
    // s*[p] - s*[q] with p, q >= mid, so it cancels out.
    for (int t = mid + 1; t <= r; ++t) {
        mn = min(mn, a[t]), mx = max(mx, a[t]);
        s1[t] = (s1[t - 1] + 1LL * t * mn) % mod;
        // mn, mx < mod and s2/s4 stay below mod, so one conditional subtraction suffices.
        if ((s2[t] = s2[t - 1] + mn) >= mod) s2[t] -= mod;
        s3[t] = (s3[t - 1] + 1LL * t * mx) % mod;
        if ((s4[t] = s4[t - 1] + mx) >= mod) s4[t] -= mod;
        s5[t] = (s5[t - 1] + 1LL * t * mn % mod * mx) % mod;
        s6[t] = (s6[t - 1] + 1LL * mn * mx) % mod;
    }

    // Sweep the left endpoint i from mid down to l; mn / mx = min / max of a[i..mid].
    //   j = last index in [mid, r] with a[mid+1..j] >= mn, so min(a[i..t]) = mn for t <= j.
    //   k = last index in [mid, r] with a[mid+1..k] <= mx, so max(a[i..t]) = mx for t <= k.
    // Both pointers only move right, because mn never grows and mx never shrinks.
    mn = mod, mx = 0;
    for (int i = mid, j = mid, k = mid; i >= l; --i) {
        mn = min(mn, a[i]), mx = max(mx, a[i]);
        while (j < r && a[j] >= mn && a[j + 1] >= mn) ++j;
        while (k < r && a[k] <= mx && a[k + 1] <= mx) ++k;

        // t in (mid, min(j, k)]: both min and max come from the left part.
        // sum_{t=mid+1}^{x} (t - i + 1) = (x + mid - 2i + 3)(x - mid) / 2.
        // The product is always even, so the shift is exact.
        x = min(j, k);
        if (x > mid) ans = (ans + 1LL * mn * mx % mod * (((x + mid - i - i + 3LL) * (x - mid) >> 1) % mod)) % mod;

        // t in (j, k]: min is the right-half prefix minimum, max = mx.
        if (j < k) ans = (ans + ((s1[k] - s1[j]) - 1LL * (i - 1) * (s2[k] - s2[j])) % mod * mx) % mod;

        // t in (k, j]: max is the right-half prefix maximum, min = mn.
        if (j > k) ans = (ans + ((s3[j] - s3[k]) - 1LL * (i - 1) * (s4[j] - s4[k])) % mod * mn) % mod;

        // t in (max(j, k), r]: both min and max come from the right-half prefixes.
        x = max(j, k);
        if (x < r) ans = (ans + s5[r] - s5[x] - 1LL * (i - 1) * (s6[r] - s6[x])) % mod;
    }
}
// Reads n and the sequence a[1..n] from stdin and accumulates the answer via solve().
// Prints the answer normalized to [0, mod).
int main() {
    n = read();
    for (int i = 1; i <= n; ++i) a[i] = read();
    // solve() requires l <= r. For n == 0 (e.g. empty input), solve(1, 0) computes
    // mid = 0 and calls solve(1, 0) again forever, so only solve a non-empty sequence.
    if (n >= 1) solve(1, n);
    // solve() leaves ans in (-mod, mod); shift a negative residue into [0, mod).
    if (ans < 0) ans += mod;
    printf("%d\n", ans);
    return 0;
}