写写堆排序

2016-07-31 jude 更多博文 » 博客 » GitHub »

algorithm heap sort ruby

原文链接 http://judes.me/tech/2016/07/31/heap-sort.html
注:以下为加速网络访问所做的原文缓存,经过重新格式化,可能存在格式方面的问题,或偶有遗漏信息,请以原文为准。


原理

堆排序中的“堆”,它是:

  • 一棵完全二叉树
  • 树的每个节点都不比它的两个子节点小(有序)

由此得到最有用的信息:根节点是二叉树里面最大的元素

堆排序的过程是:

  • 构造有序的堆
  • 输出并删除最大的元素
  • 重复前面两个步骤

联想

  • 用数组表示二叉树,数组从索引 1 开始记录元素。这样做可以很方便地找到完全二叉树中某个节点的及其两个子节点。
  • 从下往上地把二叉树变成堆
  • 交换最大的元素与二叉树的最后一个元素的位置,把二叉树的节点数减 1 ,再把二叉树变成堆
  • 重复上一步骤,直到二叉树只剩下 1 个节点

把二叉树变成堆涉及一个重要的基本操作:下沉。就是通过不断比较父节点与子节点,如果父节点比子节点小,就交换两者的位置。如果从根节点开始下沉,会得到整棵树中最大的节点。

def sink(arr, index, len)
  while index * 2 <= len
    child_index = index * 2 
    if child_index < len && arr[child_index] < arr[child_index + 1]
      child_index += 1
    end
    break nil if arr[child_index] <= arr[index]
    exchange(arr, index, child_index)
    index = child_index
  end
end

def exchange(arr, i, j)
  arr[i], arr[j] = arr[j], arr[i]
end

用法

#!/usr/bin/env ruby
sorted_array = HS.sort(array)

大体架构

module HS
  extend self
  def sort(array)
    heap_sort(array)
  end
  private
  def heap_sort(array)
    return array if array.length < 2
    # 构建符合格式的数组
    # 构建堆
    # 排序

    array[1..-1]
  end
end

构建符合格式的数组

这个很简单,把数组第 1 个元素插入到数组末尾,然后把第 1 个元素置为 nil

array[0], array[array.length] = nil, array[0]

构建堆

把符合格式的数组看成是一棵完全二叉树,要构建堆的话,当然可以从数组的第 2 个元素开始下沉,一直到最后一个元素。

但这样的效率很低,单说最大可能交换的次数,假设 N 是树的高度,最大可能的交换次数一共是:

2^0*(N-1)+2^1*(N-2)+...+2^(N-2)*1

参考黑帮发展的形式可以得到比较好的思路:不用管一众小黑帮的喽啰,只要收拾了它们的头目,就可以掌管这些黑帮。一开始就要找到小黑帮的头目,把他们打沉了,再往上找他们的头目。这样最大可能交换的次数一共是:

2^0+2^1+...+2^(N-2)

实现这个技巧的代码也很简单:

def make_heap(array)
  index = (array.length - 1) / 2
  while index >= 1
    sink(array, index, array.length - 1)
    index -= 1
  end
  array
end

排序

我们最终要得到一个从小到大的数组,得不断地把堆中最大的元素移到数组的末尾,缩小堆的大小之后,再使堆有序。

def _sort(array)
  len = array.length - 1
  while len > 1
    exchange(array, 1, len)
    len -= 1
    sink(array, 1, len)
  end
end

完整的代码,加测试用例如下:

#!/usr/bin/env ruby

module HS
  extend self
  def sort(array)
    heap_sort(array)
  end
  private
  def heap_sort(array)
    return array if array.length < 2
    # 构建符合格式的数组
    array[0], array[array.length] = nil, array[0]
    # 构建堆
    make_heap(array)
    # 排序
    _sort(array)

    array[1..-1]
  end

  def make_heap(array)
    index = (array.length - 1) / 2
    while index >= 1
      sink(array, index, array.length - 1)
      index -= 1
    end
    array
  end

  def _sort(array)
    len = array.length - 1
    while len > 1
      exchange(array, 1, len)
      len -= 1
      sink(array, 1, len)
    end
  end

  def sink(arr, index, len)
    while index * 2 <= len
      child_index = index * 2
      if child_index < len && arr[child_index] < arr[child_index + 1]
        child_index += 1
      end
      break nil if arr[child_index] <= arr[index]
      exchange(arr, index, child_index)
      index = child_index
    end
  end

  def exchange(arr, i, j)
    arr[i], arr[j] = arr[j], arr[i]
  end

end

if __FILE__ == $0
  require 'test/unit'
  class TestHS < Test::Unit::TestCase
    def test_0
      input = []
      expected = []
      assert_equal HS.sort(input), expected, 'empty array not equal'
    end
    def test_1_0
      input = [0]
      expected = [0]
      assert_equal HS.sort(input), expected, 'one item array not equal'
    end
    def test_1_1
      input = [1]
      expected = [1]
      assert_equal HS.sort(input), expected, 'one item array not equal'
    end
    def test_2_0
      input = [0,1]
      expected = [0,1]
      assert_equal HS.sort(input), expected, 'two iteHS array not equal'
    end
    def test_2_1
      input = [1,0]
      expected = [0,1]
      assert_equal HS.sort(input), expected, 'two iteHS array not equal'
    end
    def test_2_2
      input = [1,1]
      expected = [1,1]
      assert_equal HS.sort(input), expected, 'two iteHS array not equal'
    end
    def test_3_0
      input = [0,1,2]
      expected = [0,1,2]
      assert_equal HS.sort(input), expected, 'three iteHS array not equal'
    end
    def test_3_1
      input = [0,2,1]
      expected = [0,1,2]
      assert_equal HS.sort(input), expected, 'three iteHS array not equal'
    end
    def test_3_2
      input = [2,1,0]
      expected = [0,1,2]
      assert_equal HS.sort(input), expected, 'three iteHS array not equal'
    end
    def test_3_3
      input = [2,1,1]
      expected = [1,1,2]
      assert_equal HS.sort(input), expected, 'three iteHS array not equal'
    end
    def test_3_4
      input = [1,1,1]
      expected = [1,1,1]
      assert_equal HS.sort(input), expected, 'three iteHS array not equal'
    end
    def test_4_0
      input = [0,1,2,3]
      expected = [0,1,2,3]
      assert_equal HS.sort(input), expected, 'four iteHS array not equal'
    end
    def test_4_1
      input = [3,1,2,0]
      expected = [0,1,2,3]
      assert_equal HS.sort(input), expected, 'four iteHS array not equal'
    end
    def test_4_2
      input = [3,1,2,1]
      expected = [1,1,2,3]
      assert_equal HS.sort(input), expected, 'four iteHS array not equal'
    end
    def test_10_0
      input = [9,8,7,6,5,4,3,2,1,0]
      expected = [0,1,2,3,4,5,6,7,8,9]
      assert_equal HS.sort(input), expected, 'ten iteHS array not equal'
    end
    def test_10_1
      input = [9,5,7,3,8,4,6,1,2,0]
      expected = [0,1,2,3,4,5,6,7,8,9]
      assert_equal HS.sort(input), expected, 'ten iteHS array not equal'
    end
  end
end