题目:https://loj.ac/problem/3059
一段 A 选一个 B 的话, B 是这段 A 的平均值。因为 \( \sum (A_i-B)^2 = \sum A_i^2 - 2*B \sum A_i + len*B^2 \) ,这是关于 B 的二次方程,对称轴是 \( B = - \frac{-2*\sum A_i}{2*len} \) ,恰是 A 的平均值。
所以自己前 10 分写了 “ dp[ i ][ j ] 表示前 i 个 A 、最后一段的 B = j ” 的 DP , n,m <= 100 的写了 “ dp[ i ] 表示前 i 个 A 的答案、转移枚举 i 所在的段到哪为止 ” 的 DP 。
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
#define db double
using namespace std;
int rdn()
{
int ret=;bool fx=;char ch=getchar();
while(ch>''||ch<''){if(ch=='-')fx=;ch=getchar();}
while(ch>=''&&ch<='')ret=ret*+ch-'',ch=getchar();
return fx?ret:-ret;
}
const int mod=;
int upt(int x){while(x>=mod)x-=mod;while(x<)x+=mod;return x;}
int pw(int x,int k)
{int ret=;while(k){if(k&)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=;}return ret;} int n,m;
namespace S1{
const int N=,M=; const ll INF=1e16;
ll Mn(ll a,ll b){return a<b?a:b;}
ll Mx(ll a,ll b){return a>b?a:b;}
int a[N]; ll dp[N][M],f[N][M];
ll Sqr(int x){return (ll)x*x;}
int calc()
{
int mx=;
for(int i=;i<=n;i++)mx=Mx(mx,a[i]);
dp[][]=;
for(int i=;i<=n;i++)
{
f[i][]=INF;
for(int j=;j<=mx;j++)
{
ll sm=Sqr(j-a[i]); dp[i][j]=INF;
for(int k=i-;k>=;k--)
{
dp[i][j]=Mn(dp[i][j],f[k][j]+sm);
if(k)sm+=Sqr(j-a[k]);
}
f[i][j]=Mn(f[i][j-],dp[i][j]);
}
}
ll ret=INF;
for(int j=;j<=mx;j++) ret=Mn(ret,dp[n][j]);
return ret%mod;
}
void solve()
{
for(int i=;i<=n;i++) a[i]=rdn();
printf("%d\n",calc());
for(int i=,u,k,d;i<=m;i++)
{
u=rdn();k=rdn(); d=a[u];a[u]=k;
printf("%d\n",calc());
a[u]=d;
}
}
}
namespace S2{
const int N=; const db INF=1e16;
int a[N];db dp[N],f[N];int ans[N];
db cal(int l,int r,db d)
{
db ret=;
for(int i=l;i<=r;i++)
ret+=(a[i]-d)*(a[i]-d);
return ret;
}
int cal2(int l,int r,ll sm)
{
sm=(ll)sm*pw(r-l+,mod-)%mod;
int ret=;
for(int i=l;i<=r;i++)
ret=(ret+(ll)(a[i]-sm)*(a[i]-sm))%mod;
return ret;
}
int calc()
{
for(int i=;i<=n;i++)
{
db sm=a[i]; dp[i]=INF;
for(int j=i-;j>=;j--)
{
db d=sm/(i-j);
if(d>=f[j])
{
db k=cal(j+,i,d);
if((dp[j]+k<dp[i])||(dp[j]+k==dp[i]&&d<f[i]))
{
dp[i]=dp[j]+k; f[i]=d;
ans[i]=upt(ans[j]+cal2(j+,i,sm));
}
}
sm+=a[j];
}
}
return ans[n];
}
void solve()
{
for(int i=;i<=n;i++) a[i]=rdn();
printf("%d\n",calc());
for(int i=,u,k,d;i<=m;i++)
{
u=rdn();k=rdn(); d=a[u];a[u]=k;
printf("%d\n",calc());
a[u]=d;
}
}
}
int main()
{
n=rdn();m=rdn();
if(n<=){S1::solve();return ;}
if(n<=){S2::solve();return ;}
return ;
}
应该更大胆一点。结论是可以贪心做那个 DP 的过程,用栈维护现有的 A 的段,如果往后添一个 A 会使得栈顶段平均值 > 栈顶前一个段平均值,就把栈顶和它前面那个段合并起来;则合并后的段平均值比原来 “栈顶前面那个段” 的平均值大,不会使更前面不合法。这样就有 50 分了。
考虑每次有修改一个位置该怎么做。
一个很好的思路是预处理每个前缀、后缀的栈的样子(用主席树存各时刻的栈),询问的时候拼一下即可。
刚才那个贪心的过程,不是从前往后做而是从后往前做,做出来的栈的形态还是一样的。因为在一个 A 的最优划分中,任意一个 A 换一下所属的段都不会变优;如果是等价的话,会分成尽量多的段,所以从前往后还是从后往前与最后的形态无关。
预处理的东西用主席树存起来。自己的写法是线段树第 i 个位置存了第 i 个段的右/左端点和平均值,区间存的是区间里最后一个/第一个段的信息。
如果知道修改的这个位置所属的段是 [ L , R ] ,那么 [ 1 , L-1 ] 部分的划分就是预处理出的那个,[ R+1 , n ] 的划分也是预处理出的,所以找一下 [ L , R ] 是哪即可。
找 [ L , R ] 可以在线段树上二分。设修改位置是 qi ,先找 R ,在表示 [ qi+1 , n ] 这个后缀的线段树上二分(就是每次看一下 mid+1 是否可行),如果 mid+1 可行,就进左孩子里找,因为段数越多越优;可行的意思是 mid+1 这个段作为 R 后面的第一个段是否满足 “不降” 的要求。
固定一个 R ,可以一样地在线段树上找到它对应的 L ,就是在表示 [ 1 , qi-1 ] 的线段树上二分,看 mid 作为 L 前面的第一个段是否可行;mid 是 L 前面的第一个段的话, mid 这段的右端点就是 L ,又知 R ,就可以求出 qi 所在段的平均值,看看是不是比 mid 这段大于等于即可。如果 mid 可行,就尝试去右孩子找(如果找不到还是要返回 mid )。
所以判断 “ mid+1 这个段作为 R 后面的第一个段是否可行 ” 的流程就是先找出 R (此时的 R 就是 mid+1 这段左端点的前一个位置)对应的 L ,则已知 [ L , R ] 的平均值,看看是不是比 mid+1 这一段的平均值小。
算答案的时候别直接用原来的式子枚举 [ L , R ] 地算,用那个 \( \sum A_i^2 - 2*B\sum A_i + len*B^2 \) ,预处理 A 的前缀和、 A2 的前缀和即可。
#include<cstdio>
#include<cstring>
#include<algorithm>
#define db double
#define ll long long
#define ls Ls[cr]
#define rs Rs[cr]
using namespace std;
int rdn()
{
int ret=;bool fx=;char ch=getchar();
while(ch>''||ch<''){if(ch=='-')fx=;ch=getchar();}
while(ch>=''&&ch<='')ret=ret*+ch-'',ch=getchar();
return fx?ret:-ret;
}
const int N=1e5+,M=5e6+,mod=;
int upt(int x){while(x>=mod)x-=mod;while(x<)x+=mod;return x;}
int pw(int x,int k)
{int ret=;while(k){if(k&)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=;}return ret;}
int Sqr(int x){return (ll)x*x%mod;} int n,a[N],a2[N],qi,qk,qk2,s2[N];ll s[N];
struct Node{ int p;db v; Node(int p=,db v=):p(p),v(v) {} }tI;
struct Dt{ int x,y; Dt(int x=,int y=):x(x),y(y) {} }dI;
db cal(int l,int r)
{
ll ret=s[r]-s[l-]; if(qi>=l&&qi<=r)ret+=qk-a[qi];
return (db)ret/(r-l+);
}
int cal2(int l,int r)
{
ll d=s[r]-s[l-]; if(qi>=l&&qi<=r)d+=qk-a[qi];
ll x=d%mod*pw(r-l+,mod-)%mod; d%=mod;//d%=mod
int ret=upt(s2[r]-s2[l-]);
if(qi>=l&&qi<=r)ret=upt(ret+qk2-a2[qi]);
ret=upt(ret-*x*d%mod);
ret=(ret+(ll)(r-l+)*x%mod*x)%mod;
return ret;
}
namespace P{
int ct[N],ans[N];
int tot,rt[N],Ls[M],Rs[M],dfn[M],tim; bool tg[M];
Node vl[M],I;
int nwnd(int pr)
{
if(pr&&dfn[pr]==tim)return pr;
int cr=++tot; dfn[cr]=tim; if(!pr)return cr;
ls=Ls[pr]; rs=Rs[pr]; vl[cr]=vl[pr]; tg[cr]=tg[pr];
return cr;
}
void pshd(int cr)
{
if(!tg[cr])return; tg[cr]=;
ls=nwnd(ls); rs=nwnd(rs); vl[ls]=vl[rs]=; tg[ls]=tg[rs]=;
}
void pshp(int cr){if(vl[rs].p)vl[cr]=vl[rs]; else vl[cr]=vl[ls];}
void ins(int l,int r,int &cr,int p,Node k)
{
cr=nwnd(cr); if(l==r){vl[cr]=k;return;}
int mid=l+r>>; pshd(cr);
if(p<=mid)ins(l,mid,ls,p,k); else ins(mid+,r,rs,p,k);
pshp(cr);
}
void mdfy(int l,int r,int &cr,int L,int R)
{
cr=nwnd(cr); if(l>=L&&r<=R){vl[cr]=I;tg[cr]=;return;}
int mid=l+r>>; pshd(cr);
if(L<=mid)mdfy(l,mid,ls,L,R); if(mid<R)mdfy(mid+,r,rs,L,R);
pshp(cr);
}
Dt qry(int l,int r,int cr,int R,int p)//no !cr appear
{
if(l==r)
{
if(cal(vl[cr].p+,p)>=vl[cr].v)return Dt(l,vl[cr].p);
else return dI;
}
int mid=l+r>>; pshd(cr);
if(mid>=R)return qry(l,mid,ls,R,p);
if(cal(vl[ls].p+,p)>=vl[ls].v)//mid is ok
{
Dt d=qry(mid+,r,rs,R,p);
if(d.y)return d; else return Dt(mid,vl[ls].p);
}
else return qry(l,mid,ls,R,p);
}
int qryx(int l,int r,int cr,int R,int p)
{
if(l==r)
{
if(cal(vl[cr].p+,p)>=vl[cr].v)return vl[cr].p;
else return ;
}
int mid=l+r>>; pshd(cr);
if(mid>=R)return qryx(l,mid,ls,R,p);
if(cal(vl[ls].p+,p)>=vl[ls].v)//mid is ok
{
int d=qryx(mid+,r,rs,R,p);
if(d)return d; else return vl[ls].p;
}
else return qryx(l,mid,ls,R,p);
}
void solve()
{
ins(,n,rt[],,I); Dt d;
for(int i=;i<=n;i++)
{
tim++; rt[i]=nwnd(rt[i-]);
d=qry(,n,rt[i],ct[i-],i);
if(d.x<ct[i-])mdfy(,n,rt[i],d.x+,ct[i-]);
ct[i]=d.x+;
ins(,n,rt[i],ct[i],Node(i,cal(d.y+,i)));
ans[i]=upt(ans[d.y]+cal2(d.y+,i));
}
}
int qryx(int p){ return qryx(,n,rt[qi-],ct[qi-],p);}
};
namespace S{
const db INF=1e9+;
int ct[N],ans[N],lm;
int tot,rt[N],Ls[M],Rs[M],dfn[M],tim; bool tg[M];
Node vl[M],I;
int nwnd(int pr)
{
if(pr&&dfn[pr]==tim)return pr;
int cr=++tot; dfn[cr]=tim; if(!pr)return cr;
ls=Ls[pr]; rs=Rs[pr]; vl[cr]=vl[pr]; tg[cr]=tg[pr];
return cr;
}
void pshd(int cr)
{
if(!tg[cr])return; tg[cr]=;
ls=nwnd(ls); rs=nwnd(rs); vl[ls]=vl[rs]=; tg[ls]=tg[rs]=;
}
void pshp(int cr){if(vl[ls].p)vl[cr]=vl[ls]; else vl[cr]=vl[rs];}
void ins(int l,int r,int &cr,int p,Node k)
{
cr=nwnd(cr); if(l==r){vl[cr]=k;return;}
int mid=l+r>>; pshd(cr);
if(p<=mid)ins(l,mid,ls,p,k); else ins(mid+,r,rs,p,k);
pshp(cr);
}
void mdfy(int l,int r,int &cr,int L,int R)
{
cr=nwnd(cr); if(l>=L&&r<=R){vl[cr]=I;tg[cr]=;return;}
int mid=l+r>>; pshd(cr);
if(L<=mid)mdfy(l,mid,ls,L,R); if(mid<R)mdfy(mid+,r,rs,L,R);
pshp(cr);
}
Dt qry(int l,int r,int cr,int L,int p)
{
if(l==r)
{
if(cal(p,vl[cr].p-)<=vl[cr].v) return Dt(l,vl[cr].p);
else return dI;
}
int mid=l+r>>; pshd(cr);
if(mid<L)return qry(mid+,r,rs,L,p);
if(cal(p,vl[rs].p-)<=vl[rs].v)//mid+1 is ok
{
Dt d=qry(l,mid,ls,L,p);
if(d.y)return d; else return Dt(mid+,vl[rs].p);
}
else return qry(mid+,r,rs,L,p);
}
Dt qryx(int l,int r,int cr,int L)
{
if(l==r)
{
int d=P::qryx(vl[cr].p-);
if(cal(d+,vl[cr].p-)<=vl[cr].v)return Dt(d,vl[cr].p);
else return dI;
}
int mid=l+r>>; pshd(cr);
if(mid<L)return qryx(mid+,r,rs,L);
int d=P::qryx(vl[rs].p-);
if(cal(d+,vl[rs].p-)<=vl[rs].v)//mid+1 is ok
{
Dt ret=qryx(l,mid,ls,L);
if(ret.y)return ret; else return Dt(d,vl[rs].p);//.y for .x can be 0
}
else return qryx(mid+,r,rs,L);
}
void solve()
{
I=Node(n+,INF); lm=n+;
ins(,lm,rt[lm],lm,I); ct[lm]=lm; Dt d;
for(int i=n;i;i--)
{
tim++; rt[i]=nwnd(rt[i+]);
d=qry(,lm,rt[i],ct[i+],i);
if(d.x>ct[i+])mdfy(,lm,rt[i],ct[i+],d.x-);
ct[i]=d.x-;
ins(,lm,rt[i],ct[i],Node(i,cal(i,d.y-)));
ans[i]=upt(ans[d.y]+cal2(i,d.y-));
}
}
Dt qryx(){ return qryx(,lm,rt[qi+],ct[qi+]);}
}
int main()
{
n=rdn();int m=rdn();
for(int i=;i<=n;i++)
{ a[i]=rdn(); a2[i]=(ll)a[i]*a[i]%mod;
s[i]=s[i-]+a[i]; s2[i]=(s2[i-]+(ll)a[i]*a[i])%mod;}
P::solve(); S::solve();
printf("%d\n",P::ans[n]);
while(m--)
{
qi=rdn(); qk=rdn(); qk2=(ll)qk*qk%mod;
Dt d=S::qryx();
int ans=upt(P::ans[d.x]+S::ans[d.y]);
ans=upt(ans+cal2(d.x+,d.y-));
printf("%d\n",ans);
}
return ;
}