HZOJ Function

比较神仙的一道dp,考试的时候还以为是打表找规律啥的。

我们重新描述一下这道题:一个10 9 × n的网格,每个格子有一个权值,每一列格子的权值都是相同的。从一个起点开始,每次可以向上走一格或者向左上角走一格,直到走到最上面一行为止,你需要最小化经过的格子的总权值。

然而我并没有看出来。

首先我们可以发现一些显然的性质,最优的路径之一一定形如:先往左上走若干步(可能不走),到达权值较小的一列后,一直往上走到顶。对于每个询问,枚举从起点出发最终会到达哪一列,就可以得到一个O(nq)的做法。

然而我也没有想到……

算了直接把作者的题解全放出来吧:

对于任意1 ≤ i ≤ j ≤ n, 从(x, j)出发最终到达第i列然后走到顶的代价,可以表示为一个关于x的一次函数,我们只关心这些一次函数的最小值,也就是这些直线形成的下凸壳。我们得到一个思路:将询问离线,按y从小到大排序,从最左边开始每次加入一条直线,维护下凸壳,然后在凸壳上二分即可得到答案。怎么维护下凸壳呢?对于一个点(x, y),它要么继承上一列x − 1的决策,要么就直接往上走到顶。并且我们发现,第二种情况只会出现在从顶端开始连续的一段中。于是我们只需要用栈维护凸壳即可。O((n + q) log n).

刚开始没怎么看懂,好像我的做法和题解也不是很一样,其实现在还有一些细节没有搞明白……

首先看暴力的式子:$ans=min(ans,sum[y(i)]-sum[j]+(x(i)-y(i)+j)*A[j]);$在y固定时,他是一个关于x的一次函数,即$y=kx+b$的形式,设走到j时停止然后向上走,那么$k=A[j],b=sum[y]-sum[j]+(j-y)*A[j];$

对于每一个j都是一条直线,那么这些直线构成了一个上凸壳。

我们可以用On的复杂度枚举y,用栈维护凸壳(添加直线是加在了坐标系的最左边),考虑y增加会给直线造成什么影响,只会使直线的截距发生改变而斜率不变,所以原来的凸壳仍然是对的。

那么考虑如何吧j=y的这条之间加入凸壳,首先将斜率大于这条直线的栈顶直线弹掉,然后交点也得是单调的,继续弹掉不合法的,(自己yy一下坐标系,横轴是询问的x,纵轴为最优解),

然后处理当前y的询问,直接二分栈找到当前x在坐标系中对应的直线就可以了(一定注意栈顶其实是坐标轴最左边的直线)。

放下代码(稍恶心):

 

HZOJ Function
 1 #include<algorithm>
 2 #include<iostream>
 3 #include<cstring>
 4 #include<cstdio>
 5 #define st sta[top]
 6 #define sm sta[mid]
 7 #define sm1 sta[mid+1]
 8 #define st1 sta[top-1]
 9 #define int LL
10 #define LL long long
11 using namespace std;
12 struct ques
13 {
14     int x,y,id;
15     #define x(i) que[i].x
16     #define y(i) que[i].y
17     #define id(i) que[i].id
18     friend bool operator < (ques a,ques b)
19     {return a.y<b.y;}
20 }que[500010];
21 int n,A[500010],q,maxx;
22 LL sum[500010],al[500010];
23 LL sta[500010],top;
24 double getx(int k1,int k2,int j1,int j2){return (double)(j2-j1)/(double)(k1-k2);}
25 inline int read();
26 signed main()
27 {    
28 //    freopen("function2.in","r",stdin);
29 //    freopen("out.out","w",stdout);
30     
31     n=read();
32     for(int i=1;i<=n;i++)A[i]=read(),sum[i]=sum[i-1]+A[i];
33     q=read();
34     for(int i=1;i<=q;i++)x(i)=read(),y(i)=read(),id(i)=i;
35     sort(que+1,que+q+1);
36     
37     int now=1;
38     for(int y=1;y<=n;y++)
39     {    
40         while(top&&A[sta[top]]>=A[y])top--;
41         while(top>1&&
42                 getx(A[y],A[st],0,sum[y]-sum[st]+A[st]*(st-y))
43               >=getx(A[st1],A[st],sum[y]-sum[st1]+A[st1]*(st1-y),sum[y]-sum[st]+A[st]*(st-y))
44                 )top--;
45         sta[++top]=y;
46         for(;y(now)==y&&now<=q;now++)
47         {
48             int l=1,r=top,mid;
49             while(l<r)
50             {
51                 mid=(l+r)>>1;
52                 double tx=getx(A[sm],A[sm1],sum[y]-sum[sm]+A[sm]*(sm-y),sum[y]-sum[sm1]+A[sm1]*(sm1-y));
53                 if(x(now)<=tx)l=mid+1;
54                 else r=mid;
55             }
56             mid=l;
57             al[id(now)]=sum[y]-sum[sm]+A[sm]*(x(now)-y+sm);
58         }
59         
60     }
61     for(int i=1;i<=q;i++)printf("%lld\n",al[i]);
62 }
63 inline int read()
64 {
65     int s=0,f=1;char a=getchar();
66     while(a<'0'||a>'9'){if(a=='-')f=-1;a=getchar();}
67     while(a>='0'&&a<='9'){s=s*10+a-'0';a=getchar();}
68     return s*f;
69 }
View Code

 

上一篇:HZOJ Dash Speed


下一篇:一个Java对象到底有多大?