(* ========================================================================= *)
(* FORMULAS OVER TERMINAL GO POSITIONS                                       *)
(* Copyright (c) 2005 Joe Leslie-Hurd, distributed under the MIT license     *)
(* ========================================================================= *)

structure Formula :> Formula =
struct

open Useful;

(* Physical (pointer) equality, used by the transform functions below to
   detect subterms that a symmetry left unchanged so they can be shared. *)
infixr ==

fun x == y = Portable.pointerEqual (x,y);

(* ------------------------------------------------------------------------- *)
(* Integers associated with terminal go positions.                           *)
(* ------------------------------------------------------------------------- *)

datatype integer =
  (* Integer constants and arithmetic on integer terms *)
    Integer of int
  | Negate of integer
  | Add of integer * integer
  | Multiply of integer * integer

  (* Black territory points minus white territory points *)
  | PointsScore

  (* Territory points of the group identified by the given point *)
  | PointsGroup of Point.point;

local
  (* Position of each constructor in the total order on integer terms:
     terms built with different constructors are ordered by this rank. *)
  fun integerRank i =
      case i of
        Integer _ => 0
      | Negate _ => 1
      | Add _ => 2
      | Multiply _ => 3
      | PointsScore => 4
      | PointsGroup _ => 5;
in
  (* Total order on integer terms: first by constructor, then by the
     constructor arguments (lexicographically for pairs). *)
  fun compareInteger (x,y) =
      case (x,y) of
        (Integer a, Integer b) => Int.compare (a,b)
      | (Negate a, Negate b) => compareInteger (a,b)
      | (Add a, Add b) => prodCompare compareInteger compareInteger (a,b)
      | (Multiply a, Multiply b) =>
        prodCompare compareInteger compareInteger (a,b)
      | (PointsScore, PointsScore) => EQUAL
      | (PointsGroup a, PointsGroup b) => Point.compare (a,b)
      | _ => Int.compare (integerRank x, integerRank y);
end;

(* An integer term is valid for the given dimensions when every point it
   mentions (through PointsGroup) is a member of those dimensions. *)
fun isValidInteger dim =
    let
      fun ok (Integer _) = true
        | ok (Negate i) = ok i
        | ok (Add (i,j)) = ok i andalso ok j
        | ok (Multiply (i,j)) = ok i andalso ok j
        | ok PointsScore = true
        | ok (PointsGroup p) = Dimensions.member p dim
    in
      ok
    end;

(* Apply a symmetry to every point in an integer term, without testing
   whether the symmetry is the identity. Subterms that come back physically
   unchanged are shared with the input rather than rebuilt. *)
fun transformInteger' sym =
    let
      (* Transform both children of a binary node; return the node itself
         when neither child changed, otherwise rebuild it with con. *)
      fun binary con whole (i,j) =
          let
            val i' = trans i
            and j' = trans j
          in
            if i' == i andalso j' == j then whole else con (i',j')
          end

      and trans whole =
          case whole of
            Integer _ => whole
          | PointsScore => whole
          | Negate i =>
            let
              val i' = trans i
            in
              if i' == i then whole else Negate i'
            end
          | Add ij => binary Add whole ij
          | Multiply ij => binary Multiply whole ij
          | PointsGroup p =>
            let
              val p' = Symmetry.transformPoint sym p
            in
              if p' == p then whole else PointsGroup p'
            end
    in
      trans
    end;

(* Apply a symmetry to an integer term, short-circuiting the identity. *)
fun transformInteger sym integer =
    if not (Symmetry.isIdentity sym) then transformInteger' sym integer
    else integer;

(* ------------------------------------------------------------------------- *)
(* Point status in terminal go positions.                                    *)
(* ------------------------------------------------------------------------- *)

datatype status =
  (* A constant status value *)
    Status of Status.status

  (* The status of the given individual point *)
  | StatusPoint of Point.point;

(* Total order on status terms: constants before point statuses, then by
   the underlying status or point. *)
fun compareStatus (x,y) =
    case x of
      Status s1 =>
      (case y of
         Status s2 => Status.compare (s1,s2)
       | StatusPoint _ => LESS)
    | StatusPoint p1 =>
      (case y of
         Status _ => GREATER
       | StatusPoint p2 => Point.compare (p1,p2));

(* A status term is valid when any point it mentions lies within dim. *)
fun isValidStatus dim =
    fn Status _ => true
     | StatusPoint p => Dimensions.member p dim;

(* Apply a symmetry to a status term without an identity test, sharing the
   input term whenever its transformed component is physically unchanged. *)
fun transformStatus' sym =
    let
      fun trans (status as Status s) =
          let
            val s' = Status.transform sym s
          in
            if s' == s then status else Status s'
          end
        | trans (status as StatusPoint p) =
          let
            val p' = Symmetry.transformPoint sym p
          in
            if p' == p then status else StatusPoint p'
          end
    in
      trans
    end;

(* Apply a symmetry to a status term, short-circuiting the identity. *)
fun transformStatus sym status =
    if not (Symmetry.isIdentity sym) then transformStatus' sym status
    else status;

(* ------------------------------------------------------------------------- *)
(* Sides in terminal go positions.                                           *)
(* ------------------------------------------------------------------------- *)

datatype side =
  (* Primitive side operations: a constant side (NONE is used for "no side",
     see noSide), or the opponent of another side term *)
    Side of Side.side option
  | Opponent of side

  (* Side determined by a point status term; the Stone/Eye variants
     presumably consider only stone or eye statuses respectively --
     TODO confirm against the evaluator *)
  | SideStatus of status
  | SideStoneStatus of status
  | SideEyeStatus of status;

local
  (* Position of each constructor in the total order on side terms. *)
  fun sideRank side =
      case side of
        Side _ => 0
      | Opponent _ => 1
      | SideStatus _ => 2
      | SideStoneStatus _ => 3
      | SideEyeStatus _ => 4;
in
  (* Total order on side terms: first by constructor rank, then by the
     constructor argument. *)
  fun compareSide (x,y) =
      case (x,y) of
        (Side a, Side b) => optionCompare Side.compare (a,b)
      | (Opponent a, Opponent b) => compareSide (a,b)
      | (SideStatus a, SideStatus b) => compareStatus (a,b)
      | (SideStoneStatus a, SideStoneStatus b) => compareStatus (a,b)
      | (SideEyeStatus a, SideEyeStatus b) => compareStatus (a,b)
      | _ => Int.compare (sideRank x, sideRank y);
end;

(* A side term is valid when every status term inside it is valid for dim. *)
fun isValidSide dim =
    let
      val okStatus = isValidStatus dim

      fun ok (Side _) = true
        | ok (Opponent s) = ok s
        | ok (SideStatus s) = okStatus s
        | ok (SideStoneStatus s) = okStatus s
        | ok (SideEyeStatus s) = okStatus s
    in
      ok
    end;

(* Apply a symmetry to a side term without an identity test, sharing any
   subterm that comes back physically unchanged.

   Like transform', this calls the primed (no identity test) transformer
   for nested status terms: the identity check is done once by the
   unprimed wrapper, not repeated at every status node. *)
fun transformSide' sym =
    let
      (* Transform the status argument of a status-derived side; keep the
         original node when the status is unchanged, else rebuild with con. *)
      fun viaStatus con side s =
          let
            val s' = transformStatus' sym s
          in
            if s' == s then side else con s'
          end

      fun trans side =
          case side of
            Side NONE => side
          | Side (SOME s) =>
            let
              val s' = Symmetry.transformSide sym s
            in
              if s' == s then side else Side (SOME s')
            end
          | Opponent s =>
            let
              val s' = trans s
            in
              if s' == s then side else Opponent s'
            end
          | SideStatus s => viaStatus SideStatus side s
          | SideStoneStatus s => viaStatus SideStoneStatus side s
          | SideEyeStatus s => viaStatus SideEyeStatus side s
    in
      trans
    end;

(* Apply a symmetry to a side term, short-circuiting the identity. *)
fun transformSide sym side =
    if not (Symmetry.isIdentity sym) then transformSide' sym side
    else side;

(* Constant side terms. *)
val blackSide = Side (SOME Side.Black);

val whiteSide = Side (SOME Side.White);

val noSide = Side NONE;

(* ------------------------------------------------------------------------- *)
(* Formulas over terminal go positions.                                      *)
(* ------------------------------------------------------------------------- *)

datatype formula =
  (* Propositional constants and connectives *)
    Boolean of bool
  | Not of formula
  | And of formula * formula
  | Or of formula * formula
  | Implies of formula * formula
  | Iff of formula * formula

  (* Comparisons between integer terms *)
  | LessThan of integer * integer
  | LessEqual of integer * integer
  | IntegerEqual of integer * integer
  | GreaterEqual of integer * integer
  | GreaterThan of integer * integer

  (* Status tests: equality, and membership in a status set *)
  | StatusEqual of status * status
  | StatusMember of status * StatusSet.set

  (* Equality between side terms *)
  | SideEqual of side * side

  (* Group predicates over board points *)
  | ConnectedGroup of Point.point * Point.point
  | SekiGroup of Point.point;

local
  (* Position of each constructor in the total order on formulas. *)
  fun formulaRank formula =
      case formula of
        Boolean _ => 0
      | Not _ => 1
      | And _ => 2
      | Or _ => 3
      | Implies _ => 4
      | Iff _ => 5
      | LessThan _ => 6
      | LessEqual _ => 7
      | IntegerEqual _ => 8
      | GreaterEqual _ => 9
      | GreaterThan _ => 10
      | StatusEqual _ => 11
      | StatusMember _ => 12
      | SideEqual _ => 13
      | ConnectedGroup _ => 14
      | SekiGroup _ => 15;
in
  (* Total order on formulas: first by constructor rank, then by the
     constructor arguments (lexicographically for pairs). *)
  fun compare (f1,f2) =
      case (f1,f2) of
        (Boolean a, Boolean b) => boolCompare (a,b)
      | (Not a, Not b) => compare (a,b)
      | (And a, And b) => prodCompare compare compare (a,b)
      | (Or a, Or b) => prodCompare compare compare (a,b)
      | (Implies a, Implies b) => prodCompare compare compare (a,b)
      | (Iff a, Iff b) => prodCompare compare compare (a,b)
      | (LessThan a, LessThan b) =>
        prodCompare compareInteger compareInteger (a,b)
      | (LessEqual a, LessEqual b) =>
        prodCompare compareInteger compareInteger (a,b)
      | (IntegerEqual a, IntegerEqual b) =>
        prodCompare compareInteger compareInteger (a,b)
      | (GreaterEqual a, GreaterEqual b) =>
        prodCompare compareInteger compareInteger (a,b)
      | (GreaterThan a, GreaterThan b) =>
        prodCompare compareInteger compareInteger (a,b)
      | (StatusEqual a, StatusEqual b) =>
        prodCompare compareStatus compareStatus (a,b)
      | (StatusMember a, StatusMember b) =>
        prodCompare compareStatus StatusSet.compare (a,b)
      | (SideEqual a, SideEqual b) =>
        prodCompare compareSide compareSide (a,b)
      | (ConnectedGroup a, ConnectedGroup b) =>
        prodCompare Point.compare Point.compare (a,b)
      | (SekiGroup a, SekiGroup b) => Point.compare (a,b)
      | _ => Int.compare (formulaRank f1, formulaRank f2);
end;

(* A formula is valid for the given dimensions when every point it mentions,
   directly or through its integer, status and side subterms, is a member of
   those dimensions. Status sets contribute no points. *)
fun isValid dim =
    let
      val okInteger = isValidInteger dim
      and okStatus = isValidStatus dim
      and okSide = isValidSide dim

      fun onBoard p = Dimensions.member p dim

      (* Both components of a pair satisfy ok (left checked first). *)
      fun both ok (x,y) = ok x andalso ok y

      fun ok formula =
          case formula of
            Boolean _ => true
          | Not p => ok p
          | And pq => both ok pq
          | Or pq => both ok pq
          | Implies pq => both ok pq
          | Iff pq => both ok pq
          | LessThan ij => both okInteger ij
          | LessEqual ij => both okInteger ij
          | IntegerEqual ij => both okInteger ij
          | GreaterEqual ij => both okInteger ij
          | GreaterThan ij => both okInteger ij
          | StatusEqual st => both okStatus st
          | StatusMember (s,_) => okStatus s
          | SideEqual st => both okSide st
          | ConnectedGroup pq => both onBoard pq
          | SekiGroup p => onBoard p
    in
      ok
    end;

(* Apply a symmetry to every point, status and side inside a formula,
   without testing whether the symmetry is the identity. Any subterm that
   comes back physically unchanged is shared with the input. *)
fun transform' sym =
    let
      (* Transformers for embedded terms; these partial applications only
         build closures. *)
      val onInteger = transformInteger' sym
      and onStatus = transformStatus' sym
      and onSide = transformSide' sym

      fun onPoint p = Symmetry.transformPoint sym p

      fun onStatusSet set = StatusSet.transform sym set

      (* Map f over the left and g over the right component of a binary
         node (left first); return the node itself when neither component
         changed, otherwise rebuild it with con. *)
      fun binary f g con node (x,y) =
          let
            val x' = f x
            and y' = g y
          in
            if x' == x andalso y' == y then node else con (x',y')
          end

      fun trans formula =
          case formula of
            Boolean _ => formula
          | Not p =>
            let
              val p' = trans p
            in
              if p' == p then formula else Not p'
            end
          | And pq => binary trans trans And formula pq
          | Or pq => binary trans trans Or formula pq
          | Implies pq => binary trans trans Implies formula pq
          | Iff pq => binary trans trans Iff formula pq
          | LessThan ij => binary onInteger onInteger LessThan formula ij
          | LessEqual ij => binary onInteger onInteger LessEqual formula ij
          | IntegerEqual ij =>
            binary onInteger onInteger IntegerEqual formula ij
          | GreaterEqual ij =>
            binary onInteger onInteger GreaterEqual formula ij
          | GreaterThan ij =>
            binary onInteger onInteger GreaterThan formula ij
          | StatusEqual st => binary onStatus onStatus StatusEqual formula st
          | StatusMember st =>
            binary onStatus onStatusSet StatusMember formula st
          | SideEqual st => binary onSide onSide SideEqual formula st
          | ConnectedGroup pq =>
            binary onPoint onPoint ConnectedGroup formula pq
          | SekiGroup p =>
            let
              val p' = onPoint p
            in
              if p' == p then formula else SekiGroup p'
            end
    in
      trans
    end;

(* Apply a symmetry to a formula, short-circuiting the identity. *)
fun transform sym formula =
    if not (Symmetry.isIdentity sym) then transform' sym formula
    else formula;

(* The constant formulas. *)
val alwaysTrue = Boolean true;

val alwaysFalse = Boolean false;

(* ------------------------------------------------------------------------- *)
(* Game result.                                                              *)
(* ------------------------------------------------------------------------- *)

(* Black wins when the points score reaches the minimum the komi requires
   for a black win. *)
fun isBlackWin komi =
    GreaterEqual (PointsScore, Integer (Komi.minPointsBlackWin komi));

(* White wins when the points score is at most the maximum the komi allows
   for a white win. *)
fun isWhiteWin komi =
    LessEqual (PointsScore, Integer (Komi.maxPointsWhiteWin komi));

(* The win condition for the given side under the given komi. *)
fun isSideWin komi Side.Black = isBlackWin komi
  | isSideWin komi Side.White = isWhiteWin komi;

(* A draw is the points score hitting the komi's exact draw value; when the
   komi admits no such value the formula is constantly false. *)
fun isDraw komi =
    let
      val drawPoints = Komi.exactPointsDraw komi
    in
      case drawPoints of
        SOME points => IntegerEqual (PointsScore, Integer points)
      | NONE => alwaysFalse
    end;

(* Some side wins exactly when the game is not a draw. *)
fun isWin komi =
    let
      val draw = isDraw komi
    in
      Not draw
    end;

(* ------------------------------------------------------------------------- *)
(* Individual point status.                                                  *)
(* ------------------------------------------------------------------------- *)

local
  (* Stone and eye statuses for each side. *)
  val blackStone = Status.Stone Side.Black
  val whiteStone = Status.Stone Side.White
  val blackEye = Status.Eye Side.Black
  val whiteEye = Status.Eye Side.White

  (* Status sets used by the membership tests below. *)
  val allStones = StatusSet.fromList [blackStone,whiteStone]
  val allEyes = StatusSet.fromList [blackEye,whiteEye]
  val blackTerritory = StatusSet.fromList [blackStone,blackEye]
  val whiteTerritory = StatusSet.fromList [whiteStone,whiteEye]

  (* Formula: the status of point equals st. *)
  fun pointIs st point = StatusEqual (StatusPoint point, Status st)

  (* Formula: the status of point lies in set. *)
  fun pointIn set point = StatusMember (StatusPoint point, set)
in
  fun isBlackStone point = pointIs blackStone point

  fun isWhiteStone point = pointIs whiteStone point

  fun isSideStone side =
      case side of
        Side.Black => isBlackStone
      | Side.White => isWhiteStone

  fun isStone point = pointIn allStones point

  fun isBlackEye point = pointIs blackEye point

  fun isWhiteEye point = pointIs whiteEye point

  fun isSideEye side =
      case side of
        Side.Black => isBlackEye
      | Side.White => isWhiteEye

  fun isEye point = pointIn allEyes point

  fun isBlackTerritory point = pointIn blackTerritory point

  fun isWhiteTerritory point = pointIn whiteTerritory point

  fun isSideTerritory side =
      case side of
        Side.Black => isBlackTerritory
      | Side.White => isWhiteTerritory

  fun isSekiPoint point = pointIs Status.Seki point

  (* A point is treated as territory exactly when it is not seki. *)
  fun isTerritory point = Not (isSekiPoint point)
end;

end

(* Formulas ordered by Formula.compare, and finite maps/sets keyed on them. *)
structure FormulaOrdered =
struct
  type t = Formula.formula

  val compare = Formula.compare
end

structure FormulaMap = KeyMap (FormulaOrdered);

structure FormulaSet = ElementSet (FormulaMap);
