最长上升子序列长度的O(N*logN)算法推导

2017-05-21 Robert Zhang 更多博文 » 博客 » GitHub »

algorithm 挑战程序设计竞赛

原文链接 http://huiming.io/2017/05/21/lis.html
注:以下为加速网络访问所做的原文缓存,经过重新格式化,可能存在格式方面的问题,或偶有遗漏信息,请以原文为准。


  • 目录 {:toc}

前言

关于最长上升子序列的O(N*logN)算法已经有不少文章表述,可惜大都不够“好”(甚至语焉不详):我认为“好”的算法描述不但应该清晰地说明计算步骤,更应该讲清思路——即,这个算法是怎样思考得出的。这种思考的过程和方式才是精华之处。我试图用我的理解对这个算法给出一个尽量“好”的推导和表述,使你我一样的普通人都可以理解它的思路。

定义

一个序列(比如数组或字符串)的 子序列 是指从这个序列中选出的若干元素组成的新序列,并且新序列中元素的顺序与原序列中这些元素的顺序相同。比如,[1, 3, 2, 5, 4]的一个子序列是[1, 3, 5],但不是[3, 1, 5]

一个序列的 上升子序列 是指对于它的一个子序列中任意两个元素a[i]a[j],若i < ja[i] < a[j]

一个序列的 最长上升子序列(Longest Increasing Subsequence,LIS) 是它的上升子序列中长度最长的(可能不止一个)。

O(N^2)算法

O(N^2)算法是一种相对容易得出的算法,以此为基础,我们可以改进它,进而得到O(N*logN)算法。所以即使你已经了解了O(N^2)算法,不妨再浏览一下,从这里开始整理一下思路。

首先,我们可以枚举一个序列的所有上升子序列,然后从中找出一个最长的。枚举/穷举法当然不是我们的最终追求,但枚举是重要的:计算机科学就是计数的科学,要做到既无重复又无遗漏地对一个集合进行计数并不总是十分容易;适当的枚举方法对于解决问题十分重要。

我们可以按照子序列的末尾元素(最后一个元素)对所有子序列做划分:把末尾元素相同的子序列归为一组(也可以按照首元素做划分,思路相同,解法相似)。这样我们就能用类似如下代码枚举:

vector<int> a; //原始序列
vector<int> l; //意义见下注释
//...
for (int i = 0; i < a.size(); i++) {
  //对每个a[i], 枚举以a[i]结尾的所有上升子序列,得到最长的子序列,记其长度为l[i]
  //...
}
//遍历l求最大值

但我们不必真的如上枚举(这会导致一个大于O(N^2)的算法),因为l[i]可以通过{l[j] | 0 <= j < i}得出:

对集合{ a[j] | 0 <= j < i && a[j] < a[i] }中的任一a[j],把a[i]加到l[j]对应的最长子序列末尾就会的到一个新的上升子序列,并且l[i]对应的最长子序列一定是这些新的子序列中的一个。也就是说,有如下 递推公式

l[i] = max({l[j] | 0 <= j < i && a[j] < a[i]}) + 1

若max的输入集合为空,则l[i] = 1。完整代码如下:

#include <vector>
#include <algorithm>

using namespace std;

int lis1(const vector<int> & a) {
  vector<int> l(a.size());
  l[0] = 1;
  for (int i = 1; i < a.size(); i++) {
    int max = 0;
    for (int j = 0; j < i; j++) {
      if (a[j] < a[i] && l[j] > max) {
        max = l[j];
      }
    }
    l[i] = max + 1;
  }
  return *max_element(l.begin(), l.end());
}

这是一个O(N^2)算法。现在我们要把它提升为O(N*logN),关键在于“优化”第二重对j的循环——显然,必须找到一种O(logN)的方式来计算l[i]——只能在有序集合上进行二分查找,不能遍历。

O(N*logN)算法

回顾一下上面的代码,在计算l[i]时如果我们能更“便捷”地找到a[i]应该加入的子序列j就好了。有没有可能呢?在开始计算l[i]时,如果已知可能的最长上升子序列长度是i,并且如果它的末尾元素比a[i]小,则把a[i]加入它的末尾就得到l[i] = i + 1;如果它的末尾元素不比a[i]小,或者不存在长度为i的上升子序列,则考虑长度为i - 1的上升子序列,如此重复,直至长度为1的上升子序列,如果它的末尾元素还是不小于a[i],则l[i] = 1。这个查找过程似乎在暗示着某种有序序列。

如果我们定义m[i]为长度为i + 1的上升子序列中末尾元素的最小值(这样在计算l[i]时,就依次检查m[i - 1]直至m[0]),用反证法易证m为上升序列(这里请稍稍思考一下),因此可用二分查找来“优化”以上第二重循环。另外需要注意的是,m[i]的计算不是(也不需要)“一步到位”而是“反复更新”的,但这并不影响计算l[i]时利用m[i - 1]及至m[0]。这一点请从下面代码中仔细体会。完整的代码如下:

#include <vector>
#include <climits>

using namespace std;

//在升序序列a[begin, end)中查找最后一个小于v的元素的位置
//如果没有这样的元素,返回-1
int bsearch(const vector<int> & a, int v, int begin, int end) {
  if (begin >= end)
    return -1;

  int mid = (begin + end) / 2;
  if (a[mid] < v) {
    if (mid + 1 >= end || a[mid + 1] >= v)
      return mid;
    else
      return bsearch(a, v, mid + 1, end);
  }
  else {
    return bsearch(a, v, begin, mid);
  }
}

int lis2(const vector<int> & a) {
  vector<int> m(a.size(), INT_MAX); //定义如上
  m[0] = a[0];
  int max = 1;            //最长上升子序列长度
  for (int i = 1; i < a.size(); i++) {
    int j = bsearch(m, a[i], 0, i);
    int l = j >= 0 ? j + 1 : 0;
    if (m[l] > a[i])
      m[l] = a[i];
    l++;    //以a[i]结尾的最长上升子序列的长度
    if (l > max)
      max = l;
  }
  return max;
}

更简洁的O(N*logN)算法

《挑战程序设计竞赛》(人民邮电出版社)里录有一种更简洁的算法实现,如下(纯C++实现,在原版上稍加修改):

#include <vector>
#include <climits>
#include <algorithm>

using namespace std;

int lis3(const vector<int> & a) {
  vector<int> m(a.size(), INT_MAX);
  for (int i = 0; i < a.size(); i++) {
    *lower_bound(m.begin(), m.end(), a[i]) = a[i];
  }
  return lower_bound(m.begin(), m.end(), INT_MAX) - m.begin();
}

其中lower_bound(begin, end, val)是STL函数,返回有序集合[begin, end)上第一个不小于val的元素的位置,或者end如果没有的话。

此算法原理同上,但实现十分简洁。可惜原文并未多做解释。还请读者细细品味。