Loading...
最长上升子序列长度的O(N*logN)算法推导
原文链接 http://huiming.io/2017/05/21/lis.html
注:以下为加速网络访问所做的原文缓存,经过重新格式化,可能存在格式方面的问题,或偶有遗漏信息,请以原文为准。
- 目录 {:toc}
前言
关于最长上升子序列的O(N*logN)算法已经有不少文章表述,可惜大都不够“好”(甚至语焉不详):我认为“好”的算法描述不但应该清晰地说明计算步骤,更应该讲清思路——即,这个算法是怎样思考得出的。这种思考的过程和方式才是精华之处。我试图用我的理解对这个算法给出一个尽量“好”的推导和表述,使你我一样的普通人都可以理解它的思路。
定义
一个序列(比如数组或字符串)的 子序列 是指从这个序列中选出的若干元素组成的新序列,并且新序列中元素的顺序与原序列中这些元素的顺序相同。比如,[1, 3, 2, 5, 4]
的一个子序列是[1, 3, 5]
,但不是[3, 1, 5]
。
一个序列的 上升子序列 是指对于它的一个子序列中任意两个元素a[i]
和a[j]
,若i < j
则a[i] < a[j]
。
一个序列的 最长上升子序列(Longest Increasing Subsequence,LIS) 是它的上升子序列中长度最长的(可能不止一个)。
O(N^2)算法
O(N^2)算法是一种相对容易得出的算法,以此为基础,我们可以改进它,进而得到O(N*logN)算法。所以即使你已经了解了O(N^2)算法,不妨再浏览一下,从这里开始整理一下思路。
首先,我们可以枚举一个序列的所有上升子序列,然后从中找出一个最长的。枚举/穷举法当然不是我们的最终追求,但枚举是重要的:计算机科学就是计数的科学,要做到既无重复又无遗漏地对一个集合进行计数并不总是十分容易;适当的枚举方法对于解决问题十分重要。
我们可以按照子序列的末尾元素(最后一个元素)对所有子序列做划分:把末尾元素相同的子序列归为一组(也可以按照首元素做划分,思路相同,解法相似)。这样我们就能用类似如下代码枚举:
vector<int> a; //原始序列
vector<int> l; //意义见下注释
//...
for (int i = 0; i < a.size(); i++) {
//对每个a[i], 枚举以a[i]结尾的所有上升子序列,得到最长的子序列,记其长度为l[i]
//...
}
//遍历l求最大值
但我们不必真的如上枚举(这会导致一个大于O(N^2)的算法),因为l[i]
可以通过{l[j] | 0 <= j < i}
得出:
对集合{ a[j] | 0 <= j < i && a[j] < a[i] }
中的任一a[j]
,把a[i]
加到l[j]
对应的最长子序列末尾就会的到一个新的上升子序列,并且l[i]
对应的最长子序列一定是这些新的子序列中的一个。也就是说,有如下 递推公式 :
l[i] = max({l[j] | 0 <= j < i && a[j] < a[i]}) + 1
若max的输入集合为空,则l[i] = 1
。完整代码如下:
#include <vector>
#include <algorithm>
using namespace std;
int lis1(const vector<int> & a) {
vector<int> l(a.size());
l[0] = 1;
for (int i = 1; i < a.size(); i++) {
int max = 0;
for (int j = 0; j < i; j++) {
if (a[j] < a[i] && l[j] > max) {
max = l[j];
}
}
l[i] = max + 1;
}
return *max_element(l.begin(), l.end());
}
这是一个O(N^2)算法。现在我们要把它提升为O(N*logN),关键在于“优化”第二重对j
的循环——显然,必须找到一种O(logN)的方式来计算l[i]
——只能在有序集合上进行二分查找,不能遍历。
O(N*logN)算法
回顾一下上面的代码,在计算l[i]
时如果我们能更“便捷”地找到a[i]
应该加入的子序列j
就好了。有没有可能呢?在开始计算l[i]
时,如果已知可能的最长上升子序列长度是i
,并且如果它的末尾元素比a[i]
小,则把a[i]
加入它的末尾就得到l[i] = i + 1
;如果它的末尾元素不比a[i]
小,或者不存在长度为i
的上升子序列,则考虑长度为i - 1
的上升子序列,如此重复,直至长度为1的上升子序列,如果它的末尾元素还是不小于a[i]
,则l[i]
= 1。这个查找过程似乎在暗示着某种有序序列。
如果我们定义m[i]
为长度为i + 1
的上升子序列中末尾元素的最小值(这样在计算l[i]
时,就依次检查m[i - 1]
直至m[0]
),用反证法易证m
为上升序列(这里请稍稍思考一下),因此可用二分查找来“优化”以上第二重循环。另外需要注意的是,m[i]
的计算不是(也不需要)“一步到位”而是“反复更新”的,但这并不影响计算l[i]
时利用m[i - 1]
及至m[0]
。这一点请从下面代码中仔细体会。完整的代码如下:
#include <vector>
#include <climits>
using namespace std;
//在升序序列a[begin, end)中查找最后一个小于v的元素的位置
//如果没有这样的元素,返回-1
int bsearch(const vector<int> & a, int v, int begin, int end) {
if (begin >= end)
return -1;
int mid = (begin + end) / 2;
if (a[mid] < v) {
if (mid + 1 >= end || a[mid + 1] >= v)
return mid;
else
return bsearch(a, v, mid + 1, end);
}
else {
return bsearch(a, v, begin, mid);
}
}
int lis2(const vector<int> & a) {
vector<int> m(a.size(), INT_MAX); //定义如上
m[0] = a[0];
int max = 1; //最长上升子序列长度
for (int i = 1; i < a.size(); i++) {
int j = bsearch(m, a[i], 0, i);
int l = j >= 0 ? j + 1 : 0;
if (m[l] > a[i])
m[l] = a[i];
l++; //以a[i]结尾的最长上升子序列的长度
if (l > max)
max = l;
}
return max;
}
更简洁的O(N*logN)算法
《挑战程序设计竞赛》(人民邮电出版社)里录有一种更简洁的算法实现,如下(纯C++实现,在原版上稍加修改):
#include <vector>
#include <climits>
#include <algorithm>
using namespace std;
int lis3(const vector<int> & a) {
vector<int> m(a.size(), INT_MAX);
for (int i = 0; i < a.size(); i++) {
*lower_bound(m.begin(), m.end(), a[i]) = a[i];
}
return lower_bound(m.begin(), m.end(), INT_MAX) - m.begin();
}
其中lower_bound(begin, end, val)是STL函数,返回有序集合[begin, end)上第一个不小于val
的元素的位置,或者end
如果没有的话。
此算法原理同上,但实现十分简洁。可惜原文并未多做解释。还请读者细细品味。