Ocaml Programming: Correct + Efficient + Beautiful

记忆化

记忆化是一种强大的技术,可以加速简单的递归算法,而无需改变算法的工作方式。这是通过“记住”计算结果来实现的,这样以前计算的结果就不需要重新计算。我们将使用命令式数据结构(例如数组和哈希表)来说明这个原理。

以下是书籍 OCaml 编程:正确 + 高效 + 美观 中页面 “记忆化” 的摘录,经许可在此转载。

斐波那契数列

让我们再次考虑计算第 n 个斐波那契数的问题。朴素的递归实现需要指数时间,因为对相同的斐波那契数进行了反复计算。

let rec fib n = if n < 2 then 1 else fib (n - 1) + fib (n - 2)

注意:准确地说,它的运行时间是 O(pⁿ),其中 p 是黄金分割率,(1 + √5)/2。

如果我们在计算斐波那契数时记录它们,我们可以避免这种冗余工作。其思想是,每当我们计算 f n 时,我们都将其存储在一个由 n 索引的表中。在这种情况下,索引键是整数,因此我们可以使用数组来实现此表。

let fibm n =
  let memo : int option array = Array.make (n + 1) None in
  let rec f_mem n =
    match memo.(n) with
    | Some result -> (* computed already *) result
    | None ->
        let result =
          if n < 2 then 1 else f_mem (n - 1) + f_mem (n - 2)
        in
        (* record in table *)
        memo.(n) <- Some result;
        result
  in
  f_mem n

fibm 内部定义的函数 f_mem 包含原始的递归算法,但在进行该计算之前,它首先检查结果是否已计算并存储在表中,如果是,则直接返回结果。

我们如何分析此函数的运行时间?如果我们不考虑它碰巧进行的任何递归调用,则在对 f_mem 的单个调用中花费的时间为 O(1)。现在我们寻找一种方法,通过找到正在取得的进展的某种度量来限制递归调用的总数。

一个好的进度度量选择,不仅在这里,而且在许多使用记忆化的场景中,是表中非空条目的数量(即包含 Some n 而不是 None 的条目)。每次 f_mem 进行两次递归调用时,它也会将非空条目的数量增加一个(用新值填充表中以前为空的条目)。由于表只有 n 个条目,因此对 f_mem 的总调用次数最多只能为 O(n),总运行时间为 O(n)(因为我们在上面确定了每个调用都需要 O(1) 时间)。因此,记忆化带来的这种加速将运行时间从指数降低到线性,这是一个巨大的变化——例如,对于 n=4,记忆化带来的加速超过了百万倍!

能够应用记忆化的关键在于存在反复解决的公共子问题。因此,我们可以使用一些额外的存储来节省重复计算。

虽然此代码使用了命令式结构(特别是数组更新),但副作用在 fibm 函数之外不可见。因此,从客户端的角度来看,fibm 是函数式的。无需提及内部使用的命令式实现(即良性的副作用)。

使用高阶函数进行记忆化

现在我们已经看到了一个记忆化一个函数的示例,让我们使用高阶函数来记忆化任何函数。首先,考虑记忆化非递归函数 f 的情况。在这种情况下,我们只需要创建一个哈希表,用于存储 f 被调用的每个参数对应的值(并且要记忆化多参数函数,我们可以使用柯里化和反柯里化将其转换为单参数函数)。

let memo f =
  let h = Hashtbl.create 11 in
  fun x ->
    try Hashtbl.find h x
    with Not_found ->
      let y = f x in
      Hashtbl.add h x y;
      y

然而,对于递归函数,需要修改递归调用结构。这可以独立于正在记忆化的函数进行抽象。

let memo_rec f =
  let h = Hashtbl.create 16 in
  let rec g x =
    try Hashtbl.find h x
    with Not_found ->
      let y = f g x in
      Hashtbl.add h x y;
      y
  in
  g

现在我们可以使用这种通用的记忆化技术稍微重写上面原始的 fib 函数。

let fib_memo =
  let rec fib self n =
    if n < 2 then 1 else self (n - 1) + self (n - 2)
  in
  memo_rec fib

仅仅为了好玩:派对优化

假设我们要为一家组织结构图是二叉树的公司举办派对。每个员工都有一个相关的“乐趣值”,我们希望受邀员工的集合具有最大的总乐趣值。但是,如果某个员工的上级被邀请,那么该员工就不会觉得有趣,因此我们永远不会邀请组织结构图中连接的两个员工。(此问题的不太有趣的名称是树中的最大权重独立集。)对于一个有 n 个员工的组织结构图,有 2ⁿ 个可能的邀请列表,因此比较每个有效邀请列表的乐趣值的朴素算法需要指数时间。

我们可以使用记忆化将此算法转换为线性时间算法。我们首先定义一个变体类型来表示员工。每个节点中的 int 是乐趣值。

type tree = Empty | Node of int * tree * tree

现在,我们如何递归地解决这个问题?一个重要的观察结果是,在任何树中,不包含根节点的最优邀请列表将是左子树和右子树的最优邀请列表的并集。而包含根节点的最优邀请列表将是左孩子和右孩子的最优邀请列表的并集,这些列表不包含它们各自的根节点。因此,拥有函数来优化邀请列表(在需要邀请根节点的情况下和在排除根节点的情况下)似乎很有用。我们将这两个函数分别称为 `party_in` 和 `party_out`。那么 `party` 函数的结果就是这两个函数中的最大值。

module Unmemoized = struct
  type tree =
    | Empty
    | Node of int * tree * tree

  (* Returns optimum fun for t. *)
  let rec party t = max (party_in t) (party_out t)

  (* Returns optimum fun for t assuming the root node of t
   * is included. *)
  and party_in t =
    match t with
    | Empty -> 0
    | Node (v, left, right) -> v + party_out left + party_out right

  (* Returns optimum fun for t assuming the root node of t
   * is excluded. *)
  and party_out t =
    match t with
    | Empty -> 0
    | Node (v, left, right) -> party left + party right
end

这段代码具有指数级的运行时间。但请注意,`party` 函数只有 `n` 种可能的不同的调用。如果我们修改代码以记忆这些调用的结果,则性能将与 `n` 成线性关系。这是一个记忆 `party` 结果并计算实际邀请列表的版本。请注意,此代码直接在树中记忆结果。

module Memoized = struct
  (* This version memoizes the optimal fun value for each tree node. It
     also remembers the best invite list. Each tree node has the name of
     the employee as a string. *)
  type tree =
    | Empty
    | Node of
        int * string * tree * tree * (int * string list) option ref

  let rec party t : int * string list =
    match t with
    | Empty -> (0, [])
    | Node (_, name, left, right, memo) -> (
        match !memo with
        | Some result -> result
        | None ->
            let infun, innames = party_in t in
            let outfun, outnames = party_out t in
            let result =
              if infun > outfun then (infun, innames)
              else (outfun, outnames)
            in
            memo := Some result;
            result)

  and party_in t =
    match t with
    | Empty -> (0, [])
    | Node (v, name, l, r, _) ->
        let lfun, lnames = party_out l and rfun, rnames = party_out r in
        (v + lfun + rfun, name :: lnames @ rnames)

  and party_out t =
    match t with
    | Empty -> (0, [])
    | Node (_, _, l, r, _) ->
        let lfun, lnames = party l and rfun, rnames = party r in
        (lfun + rfun, lnames @ rnames)
end

为什么记忆化对于解决此问题如此有效?与斐波那契算法一样,我们具有重叠子问题属性,其中朴素的递归实现多次使用相同的参数调用 `party` 函数。记忆化保存了所有这些调用。此外,`party` 优化问题具有最优子结构属性,这意味着问题的最优答案是从子问题的最优答案计算出来的。并非所有优化问题都具有此属性。有效地将记忆化用于优化问题的关键是找出如何编写实现该算法并具有这两个属性的递归函数。有时这需要仔细思考。

帮助改进我们的文档

鼓励您为 CS3110 GitHub 存储库中此页面的原始来源做出贡献。

OCaml

创新。社区。安全。