Link.
Description.
给定一个序列 \(\{a_i\}\),定义 \(f(l,r)\) 表示 \([l,r]\) 中取出一些不相邻数的最大和。
求 \(\sum_{l=1}^n\sum_{r=l}^nf(l,r)\)。
Solution.
分治,拆贡献,然后接下来需要考虑跨过 \(l\le md,r>md\) 的 \(f(l,r)\) 贡献。
左右两边除了选中点之外互不干扰,我们考虑直接 dp
的合并。
先左边做一遍 dp
,右边做一遍 dp
,设 \(f_l(i),g_l(i),f_r(i),g_r(i)\) 表示从 \(md/md+1\) dp
到 \(i\),\(md/md+1\) 选或不选的最大和。
答案 \(f(a,b)\) 就是 \(\max(f_l(a)+g_r(b),g_l(a)+f_r(b))\)。
考虑决策,若 \(f_l(a)+g_r(b)<g_l(a)+f_r(b)\),则有 \(f_r(b)-g_r(b)>f_l(a)-g_l(a)\)
然后直接左边按照 \(f_l(a)-g_l(b)\) 排序,就能找到决策点。
然后维护一下 \(f_l(x)\) 的前缀和和 \(g_l(x)\) 的后缀和即可。
Coding.
点击查看代码
//是啊,你就是那只鬼了,所以被你碰到以后,就轮到我变成鬼了{{{
#include<bits/stdc++.h>
using namespace std;typedef long long ll;
template<typename T>inline void read(T &x)
{
x=0;char c=getchar(),f=0;
for(;c<48||c>57;c=getchar()) if(!(c^45)) f=1;
for(;c>=48&&c<=57;c=getchar()) x=(x<<1)+(x<<3)+(c^48);
f?x=-x:x;
}
template<typename T,typename...L>inline void read(T &x,L&...l) {read(x),read(l...);}/*}}}*/
const int N=100005,P=1e9+7;int n,a[N],id[N],rs;
ll fl[N],gl[N],fr[N],gr[N],sf[N],sg[N];
struct ${ll f,g;char operator<($ b) const {return f-g<b.f-b.g;}}F[N];
inline void solve(int l,int r)
{
int md=(l+r)>>1;if(l==r) return (rs+=a[l])%=P,void();else solve(l,md),solve(md+1,r);
fl[md+1]=0,fl[md]=a[md];for(int i=md-1;i>=l;i--) fl[i]=max(fl[i+1],fl[i+2]+a[i]);
gl[md]=0,gl[md-1]=a[md-1];for(int i=md-2;i>=l;i--) gl[i]=max(gl[i+1],gl[i+2]+a[i]);
fr[md]=0,fr[md+1]=a[md+1];for(int i=md+2;i<=r;i++) fr[i]=max(fr[i-1],fr[i-2]+a[i]);
gr[md+1]=0,gr[md+2]=a[md+2];for(int i=md+3;i<=r;i++) gr[i]=max(gr[i-1],gr[i-2]+a[i]);
int tt=0;for(int i=l;i<=md;i++) F[++tt]=($){fl[i],gl[i]};
sort(F+1,F+tt+1);for(int i=1;i<=tt;i++) sf[i]=sf[i-1]+F[i].f,sg[i]=sg[i-1]+F[i].g;
//printf("%d %d %d : %d\n",l,md,r,rs);
//if(l==1&&md==2&&r==3) for(int i=1;i<=tt;i++) printf("ll %lld %lld\n",F[i].f,F[i].g);
for(int i=md+1;i<=r;i++)
{
int wh=lower_bound(F+1,F+tt+1,($){fr[i],gr[i]})-F-1;
//if(l==1&&md==2&&r==3) printf("rr %lld %lld\n",fr[i],gr[i]);
rs=(rs+(sf[tt]-sf[wh])+gr[i]%P*(tt-wh))%P;
rs=(rs+sg[wh]+fr[i]%P*wh)%P;
}
//printf("%d %d %d : %d\n",l,md,r,rs);
}
int main()
{
read(n);for(int i=1;i<=n;i++) read(a[i]);
return solve(1,n),printf("%d\n",rs),0;
}