(* ========================================================================= *)
(* ESTIMATING PROBABILITIES                                                  *)
(* Copyright (c) 2005 Joe Leslie-Hurd, distributed under the MIT license     *)
(* ========================================================================= *)

structure Probability :> Probability =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* Constants.                                                                *)
(* ------------------------------------------------------------------------- *)

(* Number of equal-width intervals that [0,1] is divided into. *)
val INTERVALS : int = 10;

(* Threshold below which a real quantity is treated as zero. *)
val EPSILON : real = 1e~6;

(* ------------------------------------------------------------------------- *)
(* Helper functions.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Returns x multiplied by itself. *)
fun square (x : real) = x * x;

(* Builds a vector of the given length while threading a state value.       *)
(* mks (i,state) returns the element for index i together with the next      *)
(* state; the result is the vector paired with the final state.              *)
fun tabulatesVector length mks z =
    let
      val state = ref z

      fun mk i =
          let
            val (elt,next) = mks (i, !state)
          in
            state := next;
            elt
          end

      val v = Vector.tabulate (length,mk)
    in
      (v, !state)
    end
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.tabulatesVector: " ^ bug);
*)

(* Maps over a vector while threading a state value.                         *)
(* fs (x,state) returns the new element together with the next state; the    *)
(* result is the mapped vector paired with the final state.                  *)
fun mapsVector fs z v =
    let
      val state = ref z

      fun step x =
          let
            val (y,next) = fs (x, !state)
          in
            state := next;
            y
          end

      val mapped = Vector.map step v
    in
      (mapped, !state)
    end
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.mapsVector: " ^ bug);
*)

(* ------------------------------------------------------------------------- *)
(* Intervals of [0,1].                                                       *)
(* ------------------------------------------------------------------------- *)

(* Index of the last interval. *)
val maxInterval = INTERVALS - 1;

(* The list of interval indices, built by Useful.interval from 0 with *)
(* INTERVALS as its second argument.                                  *)
val allIntervals = interval 0 INTERVALS;

(* Left fold of step over allIntervals, starting from init. *)
fun foldIntervals step init = List.foldl step init allIntervals;

(* INTERVALS as a real, and the derived interval widths. *)
val realIntervals = Real.fromInt INTERVALS;

val intervalWidth = 1.0 / realIntervals;

val halfIntervalWidth = intervalWidth / 2.0;

(* Splits x into (i,e), where i is the index of the interval containing x   *)
(* and e is the fractional offset of x within that interval, measured in     *)
(* interval widths from its lower end. This is the inverse of xFromInt.      *)
(* The floor must be taken of the scaled value y, not of x itself:           *)
(* flooring x would give index 0 for every x in [0,1).                       *)
fun xToInt x =
    let
      val y = x * realIntervals
      val i = Real.floor y
    in
      (i, y - Real.fromInt i)
    end;

(* Maps an interval index i and an offset e (in interval widths) to the *)
(* real value (i + e) / INTERVALS.                                      *)
fun xFromInt (i,e) =
    let
      val scaled = Real.fromInt i + e
    in
      scaled / realIntervals
    end;

(* ------------------------------------------------------------------------- *)
(* Midpoints (i.e., boundaries between intervals).                           *)
(* ------------------------------------------------------------------------- *)

(* Number of midpoints: one fewer than the number of intervals. *)
val midpoints = INTERVALS - 1;

(* Index of the last midpoint. *)
val maxMidpoint = midpoints - 1;

(* Pretty-prints a midpoint index shifted up by one. *)
val ppMidpoint = Print.ppMap (fn i => i + 1) Print.ppInt;

(* Mirrors a midpoint index: 0 maps to maxMidpoint and vice versa. *)
fun complementMid i = maxMidpoint - i;

(* The real value of midpoint i, i.e. the upper end of interval i. *)
fun xFromMid i = xFromInt (i,1.0);

(* Natural logarithm of the value of midpoint i, looked up in a table that *)
(* is computed once when the structure is built.                           *)
local
  val logs = Vector.tabulate (midpoints, fn i => Math.ln (xFromMid i))
in
  fun logMid i = Vector.sub (logs,i);
end;

(* ------------------------------------------------------------------------- *)
(* pdf operations.                                                           *)
(* ------------------------------------------------------------------------- *)

(* A pdf is a vector of reals; tabulatePdf builds one with INTERVALS *)
(* entries, one per interval.                                        *)
type pdf = real Vector.vector;

(* Returns the pdf entry for interval i.                                 *)
(* Vector.sub raises Subscript for an out-of-range index; debug builds   *)
(* check the index first and raise Bug instead.                          *)
fun intPdf (p : pdf) i =
    let
(*GomiDebug
      val _ = 0 <= i orelse raise Bug "negative index"
      val _ = i < INTERVALS orelse raise Bug "large index"
*)
    in
      Vector.sub (p,i)
    end
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.intPdf: " ^ bug);
*)

(* Builds a pdf by applying mk to each interval index 0 .. INTERVALS - 1. *)
fun tabulatePdf mk : pdf =
    Vector.tabulate (INTERVALS,mk)
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.tabulatePdf: " ^ bug);
*)

(* State-threading variant of tabulatePdf: mks (i,state) yields the entry *)
(* for interval i and the next state; returns (pdf, final state).         *)
fun tabulatesPdf mks z =
    tabulatesVector INTERVALS mks z
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.tabulatesPdf: " ^ bug);
*)

(* Applies f to every entry of a pdf. *)
fun mapPdf f (p : pdf) : pdf =
    Vector.map f p
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.mapPdf: " ^ bug);
*)

(* State-threading variant of mapPdf; returns (mapped vector, final state). *)
fun mapsPdf fs z (p : pdf) =
    mapsVector fs z p
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.mapsPdf: " ^ bug);
*)

(* Largest entry of a pdf. The entry at index 0 replaces the initial *)
(* accumulator, so the 0.0 seed only matters for an empty vector.    *)
local
  fun pick (0,p,_) = p
    | pick (_,p,best) = Real.max (p,best);
in
  fun maxPdf (p : pdf) = Vector.foldli pick 0.0 p;
end;

(* Sum of all entries of a pdf. *)
fun sumPdf (p : pdf) = Vector.foldl (fn (x,total) => x + total) 0.0 p;

(* Exponentiates a vector of log values after subtracting the largest one, *)
(* so the largest entry of the result is exp 0 = 1.                        *)
fun unlogPdf (p : pdf) : pdf =
    let
      val peak = maxPdf p
    in
      mapPdf (fn l => Math.exp (l - peak)) p
    end
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.unlogPdf: " ^ bug);
*)

(* ------------------------------------------------------------------------- *)
(* cdf operations.                                                           *)
(* ------------------------------------------------------------------------- *)

(* A cdf is a vector of reals; tabulateCdf builds one with midpoints *)
(* entries, one per midpoint.                                        *)
type cdf = real Vector.vector;

(* Returns the cdf value stored for midpoint i.                          *)
(* Vector.sub raises Subscript for an out-of-range index; debug builds   *)
(* check the index first and raise Bug instead.                          *)
fun midCdf (c : cdf) i =
    let
(*GomiDebug
      val _ = 0 <= i orelse raise Bug "negative index"
      val _ = i < midpoints orelse raise Bug "large index"
*)
    in
      Vector.sub (c,i)
    end
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.midCdf: " ^ bug);
*)

(* Returns the cdf value at the lower boundary of interval i, for          *)
(* 0 <= i <= INTERVALS. The two outer boundaries are fixed at 0.0 and 1.0; *)
(* boundary i in between is midpoint i - 1.                                *)
fun intCdf c i =
    let
(*GomiDebug
      val _ = 0 <= i orelse raise Bug "negative index"
      val _ = i <= INTERVALS orelse raise Bug "large index"
*)
    in
      if i = 0 then 0.0
      else if i = INTERVALS then 1.0
      else midCdf c (i - 1)
    end
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.intCdf: " ^ bug);
*)

(* Increase of the cdf between midpoint i - 1 and midpoint i; for i = 0 *)
(* this is just the value at midpoint 0.                                *)
fun diffMidCdf (c : cdf) i =
    if i = 0 then midCdf c 0
    else midCdf c i - midCdf c (i - 1)
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.diffMidCdf: " ^ bug);
*)

(* Builds a cdf by applying mk to each midpoint index 0 .. midpoints - 1. *)
fun tabulateCdf mk : cdf =
    Vector.tabulate (midpoints,mk)
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.tabulateCdf: " ^ bug);
*)

(* State-threading variant of tabulateCdf: mks (i,state) yields the value *)
(* for midpoint i and the next state; returns (cdf, final state).         *)
fun tabulatesCdf mks z =
    tabulatesVector midpoints mks z
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.tabulatesCdf: " ^ bug);
*)

(* Applies f to every midpoint value of a cdf. *)
fun mapCdf f (c : cdf) : cdf =
    Vector.map f c
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.mapCdf: " ^ bug);
*)

(* State-threading variant of mapCdf; returns (mapped vector, final state). *)
fun mapsCdf fs z (c : cdf) =
    mapsVector fs z c
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.mapsCdf: " ^ bug);
*)

(* Folds f over the midpoints of a cdf in increasing order.                 *)
(* f receives (x, y, acc), where y is the cdf value at midpoint i and x is  *)
(* the running sum of intervalWidth up to and including that midpoint.      *)
fun foldCdf f z (c : cdf) =
    let
      fun loop (i,x,acc) =
          if i = midpoints then acc
          else
            let
              val x' = x + intervalWidth
            in
              loop (i + 1, x', f (x', midCdf c i, acc))
            end
    in
      loop (0,0.0,z)
    end
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.foldCdf: " ^ bug);
*)

(* Folds f over the midpoints of two cdfs in lock step.                     *)
(* f receives (x, y1, y2, acc), where y1 and y2 are the two cdf values at   *)
(* midpoint i and x is the running sum of intervalWidth up to that point.   *)
fun fold2Cdf f z (c1 : cdf) (c2 : cdf) =
    let
      fun loop (i,x,acc) =
          if i = midpoints then acc
          else
            let
              val x' = x + intervalWidth
            in
              loop (i + 1, x', f (x', midCdf c1 i, midCdf c2 i, acc))
            end
    in
      loop (0,0.0,z)
    end
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.fold2Cdf: " ^ bug);
*)

(* Folds f over the intervals of a cdf, left to right.                      *)
(* f receives (x1, x2, d, acc), where [x1,x2] is the interval and d is the  *)
(* cdf increase across it multiplied by INTERVALS. The last interval is     *)
(* closed off with the fixed end point x2 = 1.0, cdf value 1.0.             *)
fun foldPdfCdf f z c =
    let
      fun step (x2,y2,(x1,y1,acc)) =
          (x2, y2, f (x1, x2, (y2 - y1) * realIntervals, acc))

      val (_,_,result) = step (1.0, 1.0, foldCdf step (0.0,0.0,z) c)
    in
      result
    end
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.foldPdfCdf: " ^ bug);
*)

(* Expected value of f under the distribution described by cdf c.          *)
(* On each interval the average of f at the two end points is multiplied   *)
(* by the density there; the sum is scaled by intervalWidth at the end.    *)
fun expectationCdf c f =
    let
      fun trapezium (_,right,density,(total,fLeft)) =
          let
            val fRight = f right
          in
            (total + 0.5 * (fLeft + fRight) * density, fRight)
          end

      val (total,_) = foldPdfCdf trapezium (0.0, f 0.0) c
    in
      total * intervalWidth
    end
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.expectationCdf: " ^ bug);
*)

local
  (* Table used by conjCdf, computed once when the structure is built.       *)
  (* It has one entry per midpoint k; each entry is a list of                *)
  (* (i, [(j, w), ...]) groups, where w is the weight applied to the product *)
  (* of the first cdf at index i and the second cdf at index j.              *)
  val inverseConj : (int * (int * real) list) list Vector.vector =
      let
        (* Prepends the weighted (k,i,j,w) terms for one pair of intervals *)
        (* (i,j) contributing to interval k, following the expansion in    *)
        (* the comment below. Nothing is added for the last interval or    *)
        (* for a non-positive weight.                                      *)
        fun add i j k w l =
            if k = maxInterval orelse w <= 0.0 then l
            else
              (* P (x_{i - 1} < A < x_i) * P (x_{j - 1} < B < x_j) *)
              (* = (P (A < x_i) - P (A < x_{i - 1})) * *)
              (*   (P (B < x_j) - P (B < x_{j - 1})) *)
              (* = P (A < x_i) * P (B < x_j) - *)
              (*   P (A < x_{i - 1}) * P (B < x_j) - *)
              (*   P (A < x_i) * P (B < x_{j - 1}) + *)
              (*   P (A < x_{i - 1}) * P (B < x_{j - 1}) *)
              let
                val l = (k,i,j,w) :: l
                val l = if i = 0 then l else (k, i - 1, j, ~w) :: l
                val l = if j = 0 then l else (k, i, j - 1, ~w) :: l
              in
                if i = 0 orelse j = 0 then l else (k, i - 1, j - 1, w) :: l
              end

        (* For intervals i and j the product ranges over [x,y]; for each   *)
        (* interval k, w is the length of the overlap of interval k with   *)
        (* [x,y] divided by the length of [x,y].                           *)
        fun sq i j =
            let
              val xi = xFromInt (i,0.0)
              and yi = xFromInt (i,1.0)
              and xj = xFromInt (j,0.0)
              and yj = xFromInt (j,1.0)

              val x = xi * xj
              and y = yi * yj

              fun f (k,l) =
                  let
                    val a = Real.max (xFromInt (k,0.0), x)
                    and b = Real.min (xFromInt (k,1.0), y)
                    val w = (b - a) / (y - x)
                  in
                    add i j k w l
                  end
            in
              foldIntervals f
            end

        (* Runs sq i j for every interval j. *)
        fun rect (i,l) = foldIntervals (uncurry (sq i)) l

        (* Sums the weights collected for each j and drops sums whose *)
        (* absolute value is below EPSILON.                           *)
        val groupSum =
            let
              fun f (j,l) =
                  let
                    val w = List.foldl op+ 0.0 l
                  in
                    if Real.abs w < EPSILON then NONE else SOME (j,w)
                  end
            in
              fn (i,l) => (i, List.mapPartial f (groupsByFst l))
            end

        (* Splits off the leading entries whose first component is k and  *)
        (* groups them; the remaining entries are passed on as the state. *)
        (* NOTE(review): relies on Useful.divideWhile and                 *)
        (* Useful.groupsByFst, whose definitions are not in this file.    *)
        fun finalizeMid (k,l) =
            let
              val (l1,l2) = divideWhile (equal k o fst) l
              val l1 = List.map snd l1
              val l1 = groupsByFst l1
(*GomiTrace5
              val () =
                  let
                    val funOp = " ->"

                    val pp =
                        Print.ppList
                          (Print.ppOp2 funOp ppMidpoint
                             (Print.ppList
                                (Print.ppOp2 funOp ppMidpoint Print.ppReal)))
                  in
                    Print.trace pp "Probability.inverseConj: l1" l1
                  end
*)
              val l1 = List.map groupSum l1
            in
              (l1,l2)
            end

        (* Sorts the collected terms by (k, i, j, w) and then builds one *)
        (* table entry per midpoint with finalizeMid.                    *)
        fun finalize l =
            let
              val l = List.map (fn (k,i,j,w) => (k,(i,(j,w)))) l
              val l = sort (prodCompare Int.compare
                             (prodCompare Int.compare
                               (prodCompare Int.compare Real.compare))) l
(*GomiTrace5
              val () =
                  let
                    val assign = " <-"
                    and sep1 = ","
                    and sep2 = ":"

                    val pp =
                        Print.ppList
                          (Print.ppOp2 assign ppMidpoint
                            (Print.ppOp2 sep1 ppMidpoint
                              (Print.ppOp2 sep2 ppMidpoint Print.ppReal)))
                  in
                    trace ("Probability.inverseConj: l =\n" ^
                           Print.toString pp l ^ "\n")
                  end
*)

              val (v,x) = tabulatesVector midpoints finalizeMid l

(*GomiDebug
              val _ = List.all (equal maxInterval o fst) x orelse
                      raise Bug "bad final x"
*)
            in
              v
            end
      in
        finalize (foldIntervals rect [])
      end
(*GomiDebug
      handle Bug bug => raise Bug ("Probability.inverseConj: " ^ bug);
*)

  (* Like midCdf, but also accepts i = midpoints, where the value is 1.0. *)
  fun midPlusCdf c i =
      if i <> midpoints then midCdf c i else 1.0;
in
  (* A printable rendering of the inverseConj table: the vector is turned *)
  (* into an enumerated list and the weights are printed as percentages.  *)
  val inverseConjInfo =
      let
        val ppInverseConj =
            Print.ppMap (enumerate o Vector.foldr op:: [])
              (Print.ppList
                 (Print.ppOp2 " <-" ppMidpoint
                    (Print.ppList
                       (Print.ppOp2 " ->" ppMidpoint
                          (Print.ppList
                             (Print.ppOp2 " ->" ppMidpoint
                                Print.ppPercent))))));
      in
        Print.toString ppInverseConj inverseConj
      end;

  (* Combines two cdfs through the inverseConj table.                       *)
  (* The value at midpoint i is the previous value plus the weighted sum of *)
  (* products of c1 and c2 values listed in table entry i, clamped so the   *)
  (* result never decreases and never exceeds 1.0.                          *)
  fun conjCdf c1 c2 =
      let
        fun inner p1 ((i2,w),acc) = acc + p1 * (midPlusCdf c2 i2) * w

        fun outer ((i1,l),acc) = List.foldl (inner (midPlusCdf c1 i1)) acc l

        fun clamp lo y = if y < lo then lo else if y > 1.0 then 1.0 else y

        fun mks (i,s) =
            let
              val s' = clamp s (List.foldl outer s (Vector.sub (inverseConj,i)))
            in
              (s',s')
            end

        val (c,_) = tabulatesCdf mks 0.0
      in
        c
      end
(*GomiDebug
      handle Bug bug => raise Bug ("Probability.conjCdf: " ^ bug);
*)
end

(* Combines two cdfs by multiplying their per-interval masses.              *)
(* For each interval the masses p1 and p2 of the two inputs are multiplied  *)
(* and accumulated into s; the running value of s is the unnormalised cdf   *)
(* at each midpoint. After adding the last interval, s is the total mass of *)
(* the product, so the midpoint values are divided by s to make the final   *)
(* cdf reach 1.0.                                                           *)
fun combineCdf c1 c2 =
    let
      fun mks (i,(y1,y2,s)) =
          let
            val y1' = midCdf c1 i
            and y2' = midCdf c2 i
            val p1 = y1' - y1
            and p2 = y2' - y2
            val s = s + p1 * p2
          in
            (s,(y1',y2',s))
          end

      val (c,(y1,y2,s)) = tabulatesCdf mks (0.0,0.0,0.0)

      (* Mass of the last interval, which has no midpoint of its own. *)
      val s = s + (1.0 - y1) * (1.0 - y2)

      val c =
          if s >= 1.0 then c
(*GomiDebug
          else if s <= 0.0 then raise Bug "zero probability"
*)
          else
            let
              (* Normalise by the total mass s, not by its complement. *)
              val n = 1.0 / s
              fun f y = y * n
            in
              mapCdf f c
            end
    in
      c
    end
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.combineCdf: " ^ bug);
*)

(* Returns the point x at which the cdf c reaches the value q.              *)
(* A binary search over the interval boundaries narrows [a,b] down to one   *)
(* interval with ax < q <= bx, where ax and bx are the cdf values at its    *)
(* two ends; the result is then found by linear interpolation inside it.    *)
fun quartileCdf c q =
    let
      fun f a ax b bx =
          if b - a <= 1 then
            let
(*GomiDebug
              val _ = b - a = 1 orelse raise Bug "shrunk"
*)
              (* e is the offset from the lower end a, as in xFromInt: *)
              (* q = ax * (1 - e) + bx * e = ax + e * (bx - ax) *)
              (* ==> *)
              (* e = (q - ax) / (bx - ax) *)
              val w = bx - ax

              (* A flat segment would give 0 / 0; use its lower end. *)
              val e = if w > 0.0 then (q - ax) / w else 0.0
            in
              xFromInt (a,e)
            end
          else
            let
              val i = (a + b) div 2
              val ix = intCdf c i
            in
              (* If the cdf at i is still below q, the answer lies to the *)
              (* right of i; otherwise it lies at or to the left of i.    *)
              if ix < q then f i ix b bx else f a ax i ix
            end
    in
      f 0 0.0 INTERVALS 1.0
    end
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.quartileCdf: " ^ bug);
*)

(* Converts a pdf into a cdf: each entry is scaled by the reciprocal of the *)
(* pdf's sum and the running total is recorded at every midpoint.           *)
fun fromPdfCdf (p : pdf) : cdf =
    let
      val n = 1.0 / sumPdf p

      fun mks (i,acc) =
          let
            val acc = acc + intPdf p i * n
          in
            (acc,acc)
          end

      val (c,s) = tabulatesCdf mks 0.0

(*GomiDebug
      val _ = Real.abs (1.0 - (s + intPdf p maxInterval * n)) < EPSILON orelse
              raise Bug "bad final s"
*)
    in
      c
    end
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.fromPdfCdf: " ^ bug);
*)

(* ------------------------------------------------------------------------- *)
(* A type of probability estimates.                                          *)
(* ------------------------------------------------------------------------- *)

(* A probability estimate is a cdf wrapped in a constructor. *)
datatype probability = Probability of cdf;

(* Builds an estimate from a function giving the cdf at each midpoint. *)
fun tabulate mk = Probability (tabulateCdf mk);

(* The estimate for 1 - A given the estimate for A: the value at midpoint  *)
(* i is one minus the original value at the mirrored midpoint.             *)
fun complement (Probability c) =
    tabulate (fn i => 1.0 - midCdf c (complementMid i))
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.complement: " ^ bug);
*)

(* Conjunction of two estimates: unwraps both cdfs and applies conjCdf. *)
fun conj (Probability c1) (Probability c2) =
    Probability (conjCdf c1 c2)
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.conj: " ^ bug);
*)

(* Disjunction via De Morgan: the complement of the conjunction of the *)
(* two complements.                                                    *)
fun disj p1 p2 =
    let
      val q1 = complement p1
      val q2 = complement p2
    in
      complement (conj q1 q2)
    end;

(* P (max(A,B) < x) = P (A < x /\ B < x) = P (A < x) * P (B < x) *)

(* Estimate for the maximum of two quantities: at every midpoint the two *)
(* cdf values are multiplied.                                            *)
fun max (Probability c1) (Probability c2) =
    tabulate (fn i => midCdf c1 i * midCdf c2 i);

(* Estimate for the minimum, obtained as the complement of the maximum of *)
(* the two complements.                                                   *)
fun min p1 p2 =
    let
      val q1 = complement p1
      val q2 = complement p2
    in
      complement (max q1 q2)
    end;

(* Combines two estimates: unwraps both cdfs and applies combineCdf. *)
fun combine (Probability c1) (Probability c2) =
    Probability (combineCdf c1 c2)
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.combine: " ^ bug);
*)

(* ------------------------------------------------------------------------- *)
(* Statistics.                                                               *)
(* ------------------------------------------------------------------------- *)

(* Evaluates the cdf of estimate p at the point x.                          *)
(* xToInt gives the interval index i and the offset e of x from the lower   *)
(* end of that interval; the result interpolates linearly between the cdf   *)
(* values a (lower end) and b (upper end). Since e is measured from the     *)
(* lower end, e = 0 must yield a and e = 1 must yield b.                    *)
(* An index equal to INTERVALS yields 1.0 directly.                         *)
fun cdf p x =
    let
      val (i,e) = xToInt x
    in
      if i = INTERVALS then 1.0
      else
        let
          val Probability c = p
          val a = intCdf c i
          and b = intCdf c (i + 1)
        in
          (1.0 - e) * a + e * b
        end
    end
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.cdf: " ^ bug);
*)

(* Density of the estimate at x: the cdf increase across the interval      *)
(* holding x, multiplied by INTERVALS. An index equal to INTERVALS is      *)
(* mapped onto the last interval.                                          *)
fun pdf (Probability c) x =
    let
      val (k,_) = xToInt x
      val k = if k = INTERVALS then maxInterval else k
      val lower = intCdf c k
      val upper = intCdf c (k + 1)
    in
      (upper - lower) * realIntervals
    end
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.pdf: " ^ bug);
*)

(* Expected value of f under the estimate; see expectationCdf. *)
fun expectation (Probability c) f = expectationCdf c f;

(* Mean of the estimate: the expectation of the function I from Useful, *)
(* presumably the identity.                                             *)
fun mean p = expectation p I;

(* Variance of the estimate: the expected squared distance from the mean. *)
fun variance p =
    let
      val mu = mean p
    in
      expectation p (fn x => square (x - mu))
    end;

(* Square root of the variance. *)
fun standardDeviation p = Math.sqrt (variance p);

(* The point at which the estimate's cdf reaches q; see quartileCdf. *)
fun quartile (Probability c) q = quartileCdf c q;

(* The points at which the cdf reaches 0.5, 0.25 and 0.75 respectively. *)
fun median p = quartile p 0.5;

fun lowerQuartile p = quartile p 0.25;

fun upperQuartile p = quartile p 0.75;

(* P (A <= B) = Integrate P (A <= x /\ B = x) dx *)

(* Estimates P (A <= B) from the cdfs of A and B.                           *)
(* For each interval, the average of A's cdf at the two ends is multiplied  *)
(* by the mass that B places on the interval (the area of a trapezium);     *)
(* the last interval is closed off with cdf value 1.0 for both.             *)
fun lessThan (Probability c1) (Probability c2) =
    let
      (* 0.5 * (hi + lo) * width, with lo and hi the cdf of A at the ends. *)
      fun area lo hi width = 0.5 * (hi + lo) * width

      fun less (_,p1',p2',(s,p1,p2)) =
          (s + area p1 p1' (p2' - p2), p1', p2')

      val (s,p1,p2) = fold2Cdf less (0.0,0.0,0.0) c1 c2
    in
      s + area p1 1.0 (1.0 - p2)
    end;

(* P (max(A,B) <= B) = P (A <= B) *)
(* The trick is calculating A as we go. *)

(* Variant of lessThan in which the first cdf value is divided by the      *)
(* second at each midpoint before it is used. Midpoints where the second   *)
(* cdf is below EPSILON add nothing to the sum; their raw values are still *)
(* carried forward as the previous point.                                  *)
fun maxLessThan (Probability c1) (Probability c2) =
    let
      fun area lo hi width = 0.5 * (hi + lo) * width

      fun less (_,p1',p2',(s,p1,p2)) =
          if p2' < EPSILON then (s,p1',p2')
          else
            let
              val ratio = p1' / p2'
            in
              (s + area p1 ratio (p2' - p2), ratio, p2')
            end

      val (s,p1,p2) = fold2Cdf less (0.0,0.0,0.0) c1 c2
    in
      s + area p1 1.0 (1.0 - p2)
    end;

(* ------------------------------------------------------------------------- *)
(* Distances between probability estimate cdfs.                              *)
(* ------------------------------------------------------------------------- *)

(* Lp distance between two estimates: the p-th root of the sum, over all *)
(* midpoints, of the absolute cdf difference raised to the power p.      *)
fun distanceLp p (Probability c1) (Probability c2) =
    let
      fun accumulate (_,y1,y2,total) =
          total + Math.pow (Real.abs (y2 - y1), p)

      val total = fold2Cdf accumulate 0.0 c1 c2
    in
      Math.pow (total, 1.0 / p)
    end;

(* L1 distance: the sum of absolute cdf differences over all midpoints. *)
local
  fun accumulate (_,y1,y2,total) = total + Real.abs (y2 - y1)
in
  fun distanceL1 (Probability c1) (Probability c2) =
      fold2Cdf accumulate 0.0 c1 c2;
end;

(* L2 distance: the square root of the sum of squared cdf differences *)
(* over all midpoints.                                                *)
local
  fun accumulate (_,y1,y2,total) = total + square (y2 - y1)
in
  fun distanceL2 (Probability c1) (Probability c2) =
      Math.sqrt (fold2Cdf accumulate 0.0 c1 c2);
end;

(* L-infinity distance: the largest absolute cdf difference over all *)
(* midpoints.                                                        *)
local
  fun accumulate (_,y1,y2,best) = Real.max (best, Real.abs (y2 - y1))
in
  fun distanceLinf (Probability c1) (Probability c2) =
      fold2Cdf accumulate 0.0 c1 c2;
end;

(* ------------------------------------------------------------------------- *)
(* Distributions over probabilities.                                         *)
(* ------------------------------------------------------------------------- *)

(* The estimate whose cdf at each midpoint equals the midpoint's value. *)
val uniformDistribution = tabulate xFromMid;

(* Builds an estimate from the beta density x ^ (alpha - 1) *               *)
(* (1 - x) ^ (beta - 1).                                                    *)
(* The log of the unnormalised density is sampled at one point per          *)
(* interval, the samples are exponentiated relative to the largest          *)
(* (unlogPdf), and the result is normalised and accumulated into a cdf      *)
(* (fromPdfCdf).                                                            *)
fun betaDistribution {alpha : real, beta : real} =
    let
(*GomiDebug
      val _ = alpha > 0.0 orelse raise Bug "nonpositive alpha"
      and _ = beta > 0.0 orelse raise Bug "nonpositive beta"
*)

      val a = alpha - 1.0
      and b = beta - 1.0

      (* beta x = x ^ a * (1 - x) ^ b *)

      fun logBeta x = a * Math.ln x + b * Math.ln (1.0 - x)

      (* d (beta x) / d x *)
      (* = a * x ^ (a - 1) * (1 - x) ^ b - x ^ a * b * (1 - x) ^ (b - 1) *)
      (* = x ^ (a - 1) * (1 - x) ^ (b - 1) * (a * (1 - x) - b * x) *)

      (* Therefore, turning point is the solution of *)
      (* x ^ (a - 1) * (1 - x) ^ (b - 1) * (a * (1 - x) - b * x) = 0 *)
      (* ==> a * (1 - x) - b * x = 0 *)
      (* ==> a - a * x - b * x = 0 *)
      (* ==> a = (a + b) * x *)
      (* ==> x = a / (a + b) *)

      (* When a + b is within EPSILON of zero the quotient is not used; *)
      (* the fallbacks pick 1.0, 0.0 or 0.5 from the signs of a and b.  *)
      val turningPoint =
          if Real.abs (a + b) >= EPSILON then a / (a + b)
          else if a >= EPSILON then 1.0
          else if b >= EPSILON then 0.0
          else 0.5

      (* State x is the lower end of interval i and x' its upper end.      *)
      (* The sample point px is the upper end in the first interval and    *)
      (* the lower end in the last, presumably to keep Math.ln away from   *)
      (* 0. Elsewhere it is the end of the interval nearer the turning     *)
      (* point, or the turning point itself when the interval contains it. *)
      fun mks (i,x) =
          let
            val x' = x + intervalWidth

            val px =
                if i = 0 then x'
                else if i = maxInterval then x
                else if x > turningPoint then x
                else if x' >= turningPoint then turningPoint
                else x'

            val p = logBeta px
          in
            (p,x')
          end

      val (p,x) = tabulatesPdf mks 0.0

(*GomiDebug
      val _ = Real.abs (1.0 - x) < EPSILON orelse raise Bug "bad final x"
*)

      val p = unlogPdf p

      val c = fromPdfCdf p
    in
      Probability c
    end
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.betaDistribution: " ^ bug);
*)

(*GomiDebug
val () =
    let
      val beta_1_1 = betaDistribution {alpha = 1.0, beta = 1.0}

      val _ = distanceLinf beta_1_1 uniformDistribution < EPSILON orelse
              raise Bug "{alpha = 1.0, beta = 1.0} <> uniform"
    in
      ()
    end
    handle Bug bug => raise Bug ("Probability: betaDistribution tests: " ^ bug);
*)

(* ------------------------------------------------------------------------- *)
(* Estimating probabilities from frequency counts.                           *)
(* ------------------------------------------------------------------------- *)

(* A frequency count: count occurrences out of total observations. *)
type frequency = {count : int, total : int};

(* The frequency with no observations. *)
val zeroFrequency : frequency = {count = 0, total = 0};

(* Estimate for a frequency count: the beta distribution with            *)
(* alpha = count + 1 and beta = total + 1 - count.                       *)
fun fromFrequency {count,total} =
    betaDistribution
      {alpha = Real.fromInt (count + 1),
       beta = Real.fromInt (total + 1 - count)};

(* ------------------------------------------------------------------------- *)
(* A (multi-)set of probability estimates.                                   *)
(* ------------------------------------------------------------------------- *)

(* A set of probability estimates, stored in an IntMap keyed by integer. *)
type set = probability IntMap.map;

(* The empty set. *)
val emptySet : set = IntMap.new ();

(* Inserts the pair i_p into the set with IntMap.insert. *)
fun addSet set i_p = IntMap.insert set i_p;

(* Estimate for the maximum over a whole set: at every midpoint the cdf *)
(* values of all members are multiplied together, starting from 1.0.    *)
fun maxSet set =
    let
      fun productAt i =
          IntMap.foldl (fn (_, Probability c, acc) => acc * midCdf c i) 1.0 set
    in
      tabulate productAt
    end;

(* Maps every member of a set to a real score computed by maxLessThan      *)
(* against the maximum of the whole set (maxSet). A singleton set maps     *)
(* its only member to 1.0. Debug builds reject the empty set.              *)
local
  fun maxDist set =
      let
        val m = maxSet set
      in
        IntMap.transform (maxLessThan m) set
      end;
in
  fun isMaxSet set =
      (case IntMap.size set of
         1 => IntMap.transform (K 1.0) set
(*GomiDebug
       | 0 => raise Bug "empty set"
*)
       | _ => maxDist set)
(*GomiDebug
      handle Bug bug => raise Bug ("Probability.isMaxSet: " ^ bug);
*)
end;

(* Walks an iterator of (key, weight) pairs, subtracting each weight from  *)
(* prob, and returns the first key at which prob drops to zero or below.   *)
(* Raises Bug if the iterator runs out first.                              *)
fun randomIterator NONE _ = raise Bug "total probability less than 1"
  | randomIterator (SOME iter) prob =
    let
      val (key,weight) = IntMap.readIterator iter
      val remaining = prob - weight
    in
      if remaining <= 0.0 then key
      else randomIterator (IntMap.advanceIterator iter) remaining
    end;

(* Picks a key from a map of weights by drawing Portable.randomReal () *)
(* and walking the map with randomIterator.                            *)
fun randomSet set =
    randomIterator (IntMap.mkIterator set) (Portable.randomReal ())
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.randomSet: " ^ bug);
*)

(* Picks a key from a set of estimates, using the scores computed by *)
(* isMaxSet as weights.                                              *)
fun randomMaxSet set =
    randomSet (isMaxSet set)
(*GomiDebug
    handle Bug bug => raise Bug ("Probability.randomMaxSet: " ^ bug);
*)

(* ------------------------------------------------------------------------- *)
(* Parsing and pretty printing.                                              *)
(* ------------------------------------------------------------------------- *)

(* Renders an estimate by passing its mean to percentToString. *)
val toString = percentToString o mean;

(* Pretty-printer built on toString. *)
val pp = Print.ppMap toString Print.ppString;

end
