之前是讲过斜率优化DP的,只不过那时候还没有学过计算几何,所以说讲的不是很深,所以说这里来介绍一下如何从几何意义上也就凸包的角度理解斜率优化DP。

友情链接:

【算法介绍】斜率优化优化动态规划 | 祝馀宫

从例题出发

题目传送门

题目大意

给出 $N$ 个单词,每个单词有个非负权值 $a_i$,现在要将它们分成连续的若干段,每段的代价为此段单词的权值和的平方,还要加一个常数 $M$,即 $(\sum a_i)^2+M$。现在想求出一种最优方案,使得总费用之和最小。

解题思路

首先我们进行一个暴力的推导,也就是最原始的转移方程式。

这个不难,我们对于 $i$ 只需要枚举一个断点即可,规定 $j$ 属于上一段的部分:

这个的转移显然是 $O(n^2)$ 的,接下来我们要对其进行优化,那么这就是斜率优化了。首先我们对于这个式子进行一些简单的移项,可以得到下面这个:

这时候我们令 $y=dp[j]+sum[j]^2,x=sum[j]$,就可以得到这样一个式子:

那么我们可以这样理解它,我们已经求出来了若干个点 $(x,y)$(因为 $j$ 是从计算完的 DP 值里面枚举出来的)然后我们要以这些点求出一个最小的 $dp[i]-sum[i]^2-M$,如果把它设置为 $b$ 的话,简单移项就可以发现这其实是一个一次函数,而其中某一个点 $(x,y)$ 就是这条直线上的一个节点。

那么显然如果我们要求出最小的一个 $b$,我们只需要把这个直线从无限小的地方不断向上平移,碰到的第一个点就是我们想要的使 $b$ 最小的答案,我们将其称为决策点。

由于我们的斜率 $k=2\times sum[i]$,而每一段的价值肯定不会是负数,也就是说我们的斜率是单调递增的。这就只能说明,我们能取到答案的点一定是这些决策点所组成的凸包的下凸壳。

友情链接:

【算法介绍】计算几何入门 | 祝馀宫

那么我们不难画出一个图,可以发现在这个下凸壳上我们的决策点的价值是从劣到优再到劣的,这就让我们有可以二分/三分的选择。这里开始不同的斜率优化就有不同的处理方式了:

  1. 斜率单调,$x$ 单调
  2. 斜率不单调或者 $x$ 不单调
  3. 都不单调

这三个的处理方式各不相同,但是有一个通解的方式是李超线段树或者平衡树。第一点就是很正常的我们这道题的配置,第二点可以用单调队列维护。除了第一点的复杂度可以优化到 $O(n)$ 以外另外两个都是 $O(n\log n)$ 的。

最后我们来看一下这道题的代码。

正确代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include<bits/stdc++.h>
using namespace std;

const int N = 500010;

struct point {
long long x, y;
point(long long _x = 0, long long _y = 0) : x(_x), y(_y) {}
}p[N];

long long cross(point o, point a, point b) {
return (a.x - o.x) * (b.y - o.y) - (b.x - o.x) * (a.y - o.y);
}

long long sum[N];
int Q[N], a[N];

int n, m;

int dp[N];

long long calc(point a, int i) {
return -2LL*sum[i] * a.x + a.y;
}

bool worse(point a, point b, int i) {
return calc(a, i) >= calc(b, i);
}

int main() {
while(scanf("%d%d", &n, &m) == 2) {
sum[0] = 0;
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
sum[i] = sum[i - 1] + a[i];
}
int head = 0, tail = 0;
p[0] = point(0, 0);
Q[tail++] = 0;
dp[0] = 0;

for(int i = 1; i <= n; i++) {
while(head < tail - 1 && worse(p[Q[head]], p[Q[head + 1]], i)) head++;
dp[i] = calc(p[Q[head]], i) + sum[i] * sum[i] + m;
p[i] = point(sum[i], dp[i] + sum[i] * sum[i]);
while(head < tail - 1 && cross(p[Q[tail - 2]], p[Q[tail - 1]], p[i]) <= 0) {
tail--;
}
Q[tail++] = i;
}
printf("%lld\n", dp[n]);
}
return 0;
}

总结

实际上这样一个推导过程是通用的,接下来多做一些题,用这个思路从头多推几遍,就可以掌握了。这里也算是补齐了之前不会凸包的漏洞(你才为什么那一段凸包的那么短一笔带过)