(* ========================================================================= *)
(* IMPERATIVE PROBABILITY DISTRIBUTIONS OVER INTEGERS                        *)
(* Copyright (c) 2005 Joe Leslie-Hurd, distributed under the MIT license     *)
(* ========================================================================= *)

structure IIntDist :> IIntDist =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* A type of imperative probability distributions over integers.             *)
(* ------------------------------------------------------------------------- *)

(* A distribution over the outcomes 0 .. size-1, stored as a sum tree:
   - weights holds the weight of each outcome, indexed 0 .. size-1.
   - Array.sub (tree,0) holds the total weight of all outcomes.
   - For 1 <= i < size, Array.sub (tree,i) holds the total weight of the
     right subtree of node i, i.e. the subtree rooted at node 2i+1; the
     left subtree is rooted at node 2i.
   - Outcome k is the leaf at virtual node position size+k, so the leaves
     occupy positions size .. 2*size-1 and are not stored in tree. *)
datatype distribution =
    Distribution of
      {size : int,
       tree : real Array.array,
       weights : real Array.array};

(* A distribution over n outcomes in which every outcome has weight 0.0,
   so the total weight is also 0.0. *)
fun zero n =
    let
(*GomiDebug
      val _ = n > 0 orelse raise Bug "IIntDist.zero: nonpositive size"
*)

      fun fresh () = Array.array (n,0.0)
    in
      Distribution {size = n, tree = fresh (), weights = fresh ()}
    end;

local
  (* Fill in the internal sum-tree nodes below node i for a distribution
     over n outcomes that all have weight 1, and return the number of
     leaves (= total weight) in the subtree rooted at node i.

     Leaves sit at virtual positions n .. 2n-1, and each internal node
     i (1 <= i < n) must store the weight of its right child 2i+1.  The
     counts are computed bottom-up from the actual layout: the layout's
     right subtree does not in general hold (w div 2) of a node's w leaves,
     so a top-down split can store wrong weights or index past the end of
     the tree array (e.g. n = 6). *)
  fun uniformTree n tree i =
      if i >= n then 1
      else
        let
          val l = uniformTree n tree (2 * i)
          val r = uniformTree n tree (2 * i + 1)
          val () = Array.update (tree, i, Real.fromInt r)
        in
          l + r
        end;
in
  (* A distribution over n outcomes in which every outcome has weight 1.0,
     so the total weight is Real.fromInt n. *)
  fun uniform n =
    let
(*GomiDebug
      val _ = n > 0 orelse raise Bug "IIntDist.uniform: nonpositive size"
*)

      val tree = Array.array (n,1.0)
      and weights = Array.array (n,1.0)

      (* The root slot holds the total weight. *)
      val () = Array.update (tree, 0, Real.fromInt n)

      (* Internal nodes 1 .. n-1 get their right-subtree weights; the
         returned leaf count is not needed here. *)
      val _ = uniformTree n tree 1
    in
      Distribution {size = n, tree = tree, weights = weights}
    end;
end;

(* The number of outcomes of the distribution. *)
fun size (Distribution r) = #size r;

(* The sum of all outcome weights, kept in the root slot of the tree. *)
fun totalWeight (Distribution r) = Array.sub (#tree r, 0);

(* The weight of outcome i; raises Subscript if i is out of range. *)
fun weight (Distribution r) i = Array.sub (#weights r, i);

(* The probability of outcome i: its weight divided by the total weight.
   A zero total weight is only rejected when debugging is enabled. *)
fun probability distribution i =
    let
      val numerator = weight distribution i
      and w = totalWeight distribution
(*GomiDebug
      val _ = w > 0.0 orelse
              raise Bug "IIntDist.probability: nonpositive totalWeight"
*)
    in
      numerator / w
    end;

local
  (* Walk from the given tree position up to the root, adding delta to the
     parent slot whenever the current position is a right (odd) child,
     since each internal slot records the weight of its right subtree.
     Position 1 is a right child of the root slot 0. *)
  fun propagate tree delta node =
      if node = 0 then ()
      else
        let
          val parent = node div 2
        in
          (if node mod 2 = 1 then
             Array.update (tree, parent, Array.sub (tree,parent) + delta)
           else ());
          propagate tree delta parent
        end;
in
  (* Set the weight of outcome i to w, updating the O(log size) tree slots
     on the path from leaf position i + size to the root. *)
  fun setWeight (Distribution {size,tree,weights}) i w =
      let
(*GomiDebug
        val _ = i >= 0 orelse raise Bug "IIntDist.setWeight: negative index"
        val _ = i < size orelse raise Bug "IIntDist.setWeight: large index"
        val _ = w >= 0.0 orelse raise Bug "IIntDist.setWeight: negative weight"
*)
        val delta = w - Array.sub (weights,i)
      in
        Array.update (weights,i,w);
        propagate tree delta (i + size)
      end;
end;

(* Set the weight of every outcome i to (weight i).  Outcomes are visited
   from size-1 down to 0, and weight i is evaluated just before outcome i
   is updated. *)
fun setWeights dist weight =
    let
      fun loop i =
          if i < 0 then ()
          else (setWeight dist i (weight i); loop (i - 1))
    in
      loop (size dist - 1)
    end;

(* A fresh distribution over n outcomes in which outcome i has weight
   (weight i). *)
fun tabulate n weight =
    let
      val dist = zero n
    in
      setWeights dist weight;
      dist
    end;

local
  (* Descend from the given node towards a leaf.  At an internal node, go
     right if the remaining budget is below the right-subtree weight,
     otherwise subtract that weight and go left.  On reaching a leaf
     position (>= n), return its outcome index. *)
  fun descend n tree budget node =
      if node < n then
        let
          val rightWeight = Array.sub (tree,node)
        in
          if budget < rightWeight then descend n tree budget (2 * node + 1)
          else descend n tree (budget - rightWeight) (2 * node)
        end
      else node - n;
in
  (* Draw an outcome index at random, scaling Portable.randomReal () by the
     total weight.  Assumes Portable.randomReal is uniform on [0,1), in
     which case outcomes are chosen in proportion to their weights. *)
  fun sample (Distribution {size,tree,...}) =
      let
        val budget = Array.sub (tree,0) * Portable.randomReal ()
      in
        descend size tree budget 1
      end;
end;

(* An independent copy of the distribution: later updates to either one do
   not affect the other. *)
fun clone (Distribution r) =
    Distribution
      {size = #size r,
       tree = cloneArray (#tree r),
       weights = cloneArray (#weights r)};

(* Overwrite the tree and weights of dst with those of src.  The two
   distributions are expected to have the same size; this is only checked
   when debugging is enabled, otherwise Array.copy raises Subscript if dst
   is too small. *)
fun copy (Distribution {tree = srcTree, weights = srcWeights, ...})
         (Distribution {tree = dstTree, weights = dstWeights, ...}) =
    let
(*GomiDebug
      val _ = srcTree <> dstTree orelse
              raise Bug "IIntDist.copy: same array"
      val _ = Array.length srcTree = Array.length dstTree orelse
              raise Bug "IIntDist.copy: different sizes"
*)
    in
      Array.copy {src = srcTree, dst = dstTree, di = 0};
      Array.copy {src = srcWeights, dst = dstWeights, di = 0}
    end;

(* Fold f over (index, weight, accumulator) in increasing index order. *)
fun foldl f b (Distribution r) = Array.foldli f b (#weights r);

(* Fold f over (index, weight, accumulator) in decreasing index order. *)
fun foldr f b (Distribution r) = Array.foldri f b (#weights r);

(* The list of (outcome, weight) pairs in increasing outcome order. *)
fun toList distribution =
    List.tabulate (size distribution, fn i => (i, weight distribution i));

(* ------------------------------------------------------------------------- *)
(* Pretty printing.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Print a distribution as a list of "outcome |-> weight" entries, using
   ppA to print each outcome index. *)
fun pp ppA =
    let
      val ppEntry = Print.ppOp2 " |->" ppA Print.ppReal
    in
      Print.ppMap toList (Print.ppList ppEntry)
    end;

(* Render a distribution as a string, printing outcome indices as ints. *)
val toString =
    let
      val ppDist = pp Print.ppInt
    in
      Print.toString ppDist
    end;

end
