(* ========================================================================= *)
(* DATABASE MAPPING PATTERNS TO FORMULA PROBABILITIES                        *)
(* Copyright (c) 2005 Joe Leslie-Hurd, distributed under the MIT license     *)
(* ========================================================================= *)

structure Database :> Database =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* Constants.                                                                *)
(* ------------------------------------------------------------------------- *)

(* Tolerance used by equalProbability: two estimates whose Linf distance *)
(* is strictly below this value are treated as the same estimate. *)
val EPSILON = 1e~6;

(* ------------------------------------------------------------------------- *)
(* Probability estimates.                                                    *)
(* ------------------------------------------------------------------------- *)

(* Two probability estimates count as equal when the distance between *)
(* them, as computed by Probability.distanceLinf, is strictly less than *)
(* EPSILON. *)
fun equalProbability p1 p2 =
    let
      val gap = Probability.distanceLinf p1 p2
    in
      gap < EPSILON
    end;

(* ------------------------------------------------------------------------- *)
(* A type of pattern set.                                                    *)
(* ------------------------------------------------------------------------- *)

(* A set of patterns. *)
type patterns = PatternSet.set;

(* A change to a pattern set: the patterns to add and the ones to remove *)
(* (see updatePatterns for how the two fields are applied). *)
type patternsUpdate = {add : patterns, remove : patterns};

(* The pattern set with no elements. *)
val noPatterns : patterns = PatternSet.empty;

(* The update that adds nothing and removes nothing. *)
val emptyPatternsUpdate = {add = noPatterns, remove = noPatterns};

(* Test whether a pattern set is empty; this is PatternSet.null under a *)
(* local name. *)
val nullPatterns = PatternSet.null;

(* An update is null when both its add set and its remove set are empty. *)
fun nullPatternsUpdate (upd : patternsUpdate) =
    nullPatterns (#add upd) andalso nullPatterns (#remove upd);

(* Apply an update to a pattern set: first take away the patterns in the *)
(* remove field, then union in the patterns in the add field. *)
fun updatePatterns pats ({add,remove} : patternsUpdate) =
    let
      val kept = PatternSet.difference pats remove
    in
      PatternSet.union kept add
    end;

(* Build the update relating two pattern sets: add holds the patterns in *)
(* new but not in old, remove holds the patterns in old but not in new. *)
fun subtractPatterns new old =
    let
      val gained = PatternSet.difference new old

      val lost = PatternSet.difference old new
    in
      {add = gained, remove = lost}
    end;

(* ------------------------------------------------------------------------- *)
(* A type of formula frequencies.                                            *)
(* ------------------------------------------------------------------------- *)

(* Maps each formula to a Probability.frequency. *)
type formulaFrequency = Probability.frequency FormulaMap.map;

(* The formula frequency map with no entries. *)
val emptyFormulaFrequency : formulaFrequency = FormulaMap.new ();

(* Turn a formula set into a formula frequency map in which every formula *)
(* of the set is bound to Probability.zeroFrequency. *)
val fromFormulaSetFormulaFrequency =
    FormulaSet.foldl
      (fn (fm,freqs) => FormulaMap.insert freqs (fm, Probability.zeroFrequency))
      emptyFormulaFrequency;

(* ------------------------------------------------------------------------- *)
(* A type of estimated formula probabilities.                                *)
(* ------------------------------------------------------------------------- *)

(* Maps each formula to a Probability.probability. *)
type estimate = Probability.probability FormulaMap.map;

(* A change to an estimate.  Judging by updateEstimate and *)
(* subtractEstimate, SOME p carries a formula's new probability and NONE *)
(* marks a formula that is no longer present -- confirm against *)
(* FormulaMap.merge. *)
type estimateUpdate = Probability.probability option FormulaMap.map;

(* The estimate with no entries. *)
val emptyEstimate : estimate = FormulaMap.new ();

(* Convert a formula frequency map into an estimate by applying *)
(* Probability.fromFrequency to every stored frequency. *)
fun fromFormulaFrequencyEstimate ff =
    FormulaMap.transform Probability.fromFrequency ff;

(* Union of two estimates.  When a formula occurs in both, its two *)
(* probabilities are merged with Probability.combine. *)
val combineEstimate =
    FormulaMap.union (fn ((_,p),(_,q)) => SOME (Probability.combine p q));

(* The estimate update with no entries. *)
val emptyEstimateUpdate : estimateUpdate = FormulaMap.new ();

(* Apply an estimate update to an estimate with FormulaMap.merge: *)
(*   - a formula found only in est keeps its probability; *)
(*   - a formula found only in estUp takes the optional value stored in *)
(*     the update; *)
(*   - a formula found in both also takes the optional value stored in the *)
(*     update, ignoring the probability in est. *)
fun updateEstimate est estUp =
    let
      fun keep (_,p) = SOME p

      fun fresh (_,change) = change

      fun replace (_,(_,change)) = change
    in
      FormulaMap.merge {first = keep, second = fresh, both = replace} est estUp
    end;

(* Build the estimate update relating the estimate old to the estimate *)
(* new, by merging old (first) with new (second): *)
(*   - a formula found only in old is mapped to NONE; *)
(*   - a formula found only in new is mapped to SOME of its probability; *)
(*   - a formula found in both is left out when the two probabilities are *)
(*     equal according to equalProbability, and is otherwise mapped to *)
(*     SOME of the probability in new. *)
fun subtractEstimate new old =
    let
      fun dropped _ = SOME NONE

      fun added (_,p) = SOME (SOME p)

      fun changed ((_,was),(_,now)) =
          if equalProbability was now then NONE else SOME (SOME now)
    in
      FormulaMap.merge {first = dropped, second = added, both = changed} old new
    end;

(* ------------------------------------------------------------------------- *)
(* A type of pattern database.                                               *)
(* ------------------------------------------------------------------------- *)

(* A database maps each of its patterns to a formula frequency map. *)
datatype database =
    Database of
      {patternEstimate : formulaFrequency PatternMap.map};

local
  (* The formulas tracked by a fresh database are: *)
  (*   - BlackTerritory(p), for every point p visited by Dimensions.fold; *)
  (*   - BlackTerritory(p) ==> BlackTerritory(q), for every such p and each *)
  (*     neighbour q visited by Dimensions.foldNeighbours at p. *)

  fun addImplication antecedent (nbr,fms) =
      FormulaSet.add fms
        (Formula.Implies (antecedent, Formula.isBlackTerritory nbr));

  fun addPoint dim (pt,fms) =
      let
        val territory = Formula.isBlackTerritory pt
      in
        Dimensions.foldNeighbours
          (addImplication territory) (FormulaSet.add fms territory) dim pt
      end;

  fun interestingFormulas dim =
      Dimensions.fold (addPoint dim) FormulaSet.empty dim;

  (* The initial map has the single pattern Pattern.alwaysTrue, bound to a *)
  (* zero frequency for every interesting formula. *)
  fun newPatternEstimate dim =
      PatternMap.singleton
        (Pattern.alwaysTrue,
         fromFormulaSetFormulaFrequency (interestingFormulas dim));
in
  (* Create the initial database for the dimensions given in the position *)
  (* parameters. *)
  fun new ({dimensions,...} : Position.parameters) =
      Database {patternEstimate = newPatternEstimate dimensions};
end;

(* The set of all patterns that have an entry in the database. *)
fun patterns (Database {patternEstimate}) =
    let
      val entries = PatternMap.toList patternEstimate

      val keys = List.map (fn (pat,_) => pat) entries
    in
      PatternSet.fromList keys
    end
(*GomiDebug
    handle Bug bug => raise Bug ("Database.patterns: " ^ bug);
*)

local
  (* Add every formula key of one pattern's frequency map to the set. *)
  fun collect (_,ff,fms) =
      FormulaMap.foldl (fn (fm,_,acc) => FormulaSet.add acc fm) fms ff;
in
  (* The set of all formulas that occur under any pattern of the database. *)
  fun formulas (Database {patternEstimate}) =
      PatternMap.foldl collect FormulaSet.empty patternEstimate
(*GomiDebug
      handle Bug bug => raise Bug ("Database.formulas: " ^ bug);
*)
end;

local
  (* Add the number of formulas stored under one pattern to the total. *)
  fun addConnections (_,ff,n) = n + FormulaMap.size ff;
in
  (* Size statistics for a database: *)
  (*   patterns    - the number of patterns; *)
  (*   connections - the number of (pattern, formula) entries over all *)
  (*                 patterns; *)
  (*   formulas    - the number of distinct formulas. *)
  fun size (database as Database {patternEstimate}) =
      {patterns = PatternMap.size patternEstimate,
       connections = PatternMap.foldl addConnections 0 patternEstimate,
       formulas = FormulaSet.size (formulas database)};
end;

(* Look up the formula frequency map stored for a pattern, returning *)
(* whatever PatternMap.peek gives for that pattern. *)
fun peekPattern (Database {patternEstimate}) pat =
    PatternMap.peek patternEstimate pat;

(* Like peekPattern, but raises Bug when the pattern has no entry in the *)
(* database. *)
fun getPattern database pat =
    let
      val found = peekPattern database pat
    in
      case found of
        NONE => raise Bug "Database.getPattern"
      | SOME ff => ff
    end;

(* Given a database, returns a function from a pattern set to an estimate: *)
(* each pattern's frequencies are converted with *)
(* fromFormulaFrequencyEstimate and folded into the result with *)
(* combineEstimate, starting from emptyEstimate.  Raises Bug (through *)
(* getPattern) when a pattern has no entry in the database. *)
(* TODO: For each formula, this should filter out subsumed patterns before *)
(* computing the estimate. *)
fun estimate database =
    PatternSet.foldl
      (fn (pat,acc) =>
          let
            val patEst = fromFormulaFrequencyEstimate (getPattern database pat)
          in
            combineEstimate acc patEst
          end)
      emptyEstimate
(*GomiDebug
    handle Bug bug => raise Bug ("Database.estimate: " ^ bug);
*)

(* Compute the estimate update caused by applying patsUpdate to pats.  A *)
(* null update gives emptyEstimateUpdate without recomputing anything; *)
(* otherwise the estimate of the updated pattern set is computed and *)
(* compared with est using subtractEstimate.  est is presumably the *)
(* estimate for pats -- the function does not check this. *)
fun estimateAfter database pats est patsUpdate =
    if not (nullPatternsUpdate patsUpdate) then
      subtractEstimate (estimate database (updatePatterns pats patsUpdate)) est
    else
      emptyEstimateUpdate;

(* Record one observation against every pattern in pats.  For each such *)
(* pattern, every formula stored under it has its total incremented, and *)
(* its count incremented as well when fmPred holds for the formula. *)
(* Returns the new database paired with a patterns update, which is *)
(* always emptyPatternsUpdate here.  Raises Bug if a pattern in pats has *)
(* no entry in the database. *)
fun learn database pats fmPred =
    let
      (* One more observation of fm: a hit iff fmPred fm. *)
      fun observe (fm,{count,total}) =
          {count = if fmPred fm then count + 1 else count,
           total = total + 1}

      fun bump (pat,pm) =
          case PatternMap.peek pm pat of
            NONE => raise Bug "not a database pattern"
          | SOME ff => PatternMap.insert pm (pat, FormulaMap.map observe ff)

      val Database {patternEstimate = pm0} = database

      val pm1 = PatternSet.foldl bump pm0 pats
    in
      (Database {patternEstimate = pm1}, emptyPatternsUpdate)
    end
(*GomiDebug
    handle Bug bug => raise Bug ("Database.learn: " ^ bug);
*)

(* ------------------------------------------------------------------------- *)
(* Pretty printing.                                                          *)
(* ------------------------------------------------------------------------- *)

(* One-line summary of a database built from its size statistics, of the *)
(* form "{P patterns} ---C--> {F formulas}", where the plural "s" is *)
(* omitted when the corresponding number is exactly 1. *)
fun toString database =
    let
      val {patterns = np, connections = nc, formulas = nf} = size database

      fun plural 1 = ""
        | plural _ = "s"
    in
      String.concat
        ["{", Int.toString np, " pattern", plural np,
         "} ---", Int.toString nc, "--> {", Int.toString nf, " formula",
         plural nf, "}"]
    end;

(* Pretty-printer for databases, obtained by passing toString and *)
(* Print.ppString to Print.ppMap. *)
val pp = Print.ppMap toString Print.ppString;

end
