(* ========================================================================= *)
(* PATTERNS OVER GO POSITIONS                                                *)
(* Copyright (c) 2005 Joe Leslie-Hurd, distributed under the MIT license     *)
(* ========================================================================= *)

structure Pattern :> Pattern =
struct

open Useful;

infixr ==

(* Physical (pointer) equality.  The transform functions in this structure
   use it to detect sub-terms a symmetry left untouched, so those sub-terms
   can be shared instead of rebuilt. *)
fun x == y = Portable.pointerEqual (x,y);

(* ------------------------------------------------------------------------- *)
(* Board edges.                                                              *)
(* ------------------------------------------------------------------------- *)

datatype edge = LeftEdge | RightEdge | TopEdge | BottomEdge;

(* Total order on edges: LeftEdge < RightEdge < TopEdge < BottomEdge. *)
local
  fun edgeOrder edge =
      case edge of
        LeftEdge => 0
      | RightEdge => 1
      | TopEdge => 2
      | BottomEdge => 3;
in
  fun compareEdge (e1,e2) = Int.compare (edgeOrder e1, edgeOrder e2);
end;

(* Encode an edge as an integer in 0..3, listing the edges in the order
   RightEdge, TopEdge, LeftEdge, BottomEdge.  transform' depends on this
   numbering: it adds the symmetry's rotate count to it mod 4. *)
fun toIntEdge e =
    case e of
      RightEdge => 0
    | TopEdge => 1
    | LeftEdge => 2
    | BottomEdge => 3;

(* Inverse of toIntEdge.  Raises Bug, naming this function and the
   offending integer, for any value outside 0..3. *)
fun fromIntEdge i =
    case i of
      0 => RightEdge
    | 1 => TopEdge
    | 2 => LeftEdge
    | 3 => BottomEdge
    | _ => raise Bug ("Pattern.fromIntEdge: " ^ Int.toString i);

(* ------------------------------------------------------------------------- *)
(* Integers associated with go positions.                                    *)
(* ------------------------------------------------------------------------- *)

(* Integer-valued expressions associated with go positions.  Point arguments
   are checked against the board dimensions by isValidInteger and are moved
   by transformInteger under a symmetry. *)
datatype integer =
  (* Primitive integer operations *)
    Integer of int
  | Negate of integer
  | Add of integer * integer
  | Multiply of integer * integer

  (* The number of stones in a block (presumably the block containing the
     given point - confirm) *)
  | StonesBlock of Point.point

  (* The number of liberty edges (ledges) from a block (presumably the block
     containing the given point - confirm) *)
  | LedgesBlock of Point.point;

local
  (* Position of each constructor in the datatype declaration; used to
     order integers built from different constructors. *)
  fun integerRank integer =
      case integer of
        Integer _ => 0
      | Negate _ => 1
      | Add _ => 2
      | Multiply _ => 3
      | StonesBlock _ => 4
      | LedgesBlock _ => 5;
in
  (* Structural total order on integer expressions: first by constructor
     (in declaration order), then by the constructor arguments. *)
  fun compareInteger (x,y) =
      case (x,y) of
        (Integer m, Integer n) => Int.compare (m,n)
      | (Negate a, Negate b) => compareInteger (a,b)
      | (Add a, Add b) => prodCompare compareInteger compareInteger (a,b)
      | (Multiply a, Multiply b) =>
        prodCompare compareInteger compareInteger (a,b)
      | (StonesBlock p, StonesBlock q) => Point.compare (p,q)
      | (LedgesBlock p, LedgesBlock q) => Point.compare (p,q)
      | _ => Int.compare (integerRank x, integerRank y);
end;

(* Check that every point mentioned in an integer expression satisfies
   Dimensions.member for the given dimensions. *)
fun isValidInteger dim =
    let
      fun onBoard p = Dimensions.member p dim

      fun check (Integer _) = true
        | check (Negate a) = check a
        | check (Add (a,b)) = check a andalso check b
        | check (Multiply (a,b)) = check a andalso check b
        | check (StonesBlock p) = onBoard p
        | check (LedgesBlock p) = onBoard p
    in
      check
    end;

(* Apply a symmetry to an integer expression.  Any sub-expression the
   symmetry leaves physically unchanged (checked with ==) is reused, so an
   unaffected expression comes back as the very same object. *)
fun transformInteger' sym =
    let
      fun rebuildIfMoved orig mk (old,new) =
          if new == old then orig else mk new

      fun go integer =
          case integer of
            Negate a => rebuildIfMoved integer Negate (a, go a)
          | Add (a,b) => goPair integer Add (a,b)
          | Multiply (a,b) => goPair integer Multiply (a,b)
          | StonesBlock p =>
            rebuildIfMoved integer StonesBlock (p, Symmetry.transformPoint sym p)
          | LedgesBlock p =>
            rebuildIfMoved integer LedgesBlock (p, Symmetry.transformPoint sym p)
          | Integer _ => integer

      and goPair orig mk (a,b) =
          let
            val a' = go a
            val b' = go b
          in
            if a' == a andalso b' == b then orig else mk (a',b')
          end
    in
      go
    end;

(* As transformInteger', but returns the input untouched for the identity
   symmetry. *)
fun transformInteger sym integer =
    if not (Symmetry.isIdentity sym) then transformInteger' sym integer
    else integer;

(* ------------------------------------------------------------------------- *)
(* Sides in go positions.                                                    *)
(* ------------------------------------------------------------------------- *)

(* Side-valued expressions in go positions.  Side NONE is the "no side"
   value (noSide); mkEmpty builds SideEqual (SidePoint p, noSide). *)
datatype side =
  (* Primitive side operations *)
    Side of Side.side option
  (* presumably the opposing side of its argument - confirm *)
  | Opponent of side

  (* Side to move *)
  | SideToMove

  (* Side of a point *)
  | SidePoint of Point.point;

local
  (* Position of each constructor in the datatype declaration; used to
     order side expressions built from different constructors. *)
  fun sideRank side =
      case side of
        Side _ => 0
      | Opponent _ => 1
      | SideToMove => 2
      | SidePoint _ => 3;
in
  (* Structural total order on side expressions: first by constructor
     (in declaration order), then by the constructor arguments. *)
  fun compareSide (x,y) =
      case (x,y) of
        (Side a, Side b) => optionCompare Side.compare (a,b)
      | (Opponent a, Opponent b) => compareSide (a,b)
      | (SideToMove, SideToMove) => EQUAL
      | (SidePoint p, SidePoint q) => Point.compare (p,q)
      | _ => Int.compare (sideRank x, sideRank y);
end;

(* Check that every point mentioned in a side expression satisfies
   Dimensions.member for the given dimensions. *)
fun isValidSide dim =
    let
      fun check (Opponent s) = check s
        | check (SidePoint p) = Dimensions.member p dim
        | check _ = true
    in
      check
    end;

(* Apply a symmetry to a side expression.  Any sub-expression the symmetry
   leaves physically unchanged (checked with ==) is reused rather than
   rebuilt. *)
fun transformSide' sym =
    let
      fun rebuildIfMoved orig mk (old,new) =
          if new == old then orig else mk new

      fun go side =
          case side of
            Side (SOME s) =>
            rebuildIfMoved side (fn s' => Side (SOME s'))
              (s, Symmetry.transformSide sym s)
          | Opponent s => rebuildIfMoved side Opponent (s, go s)
          | SidePoint p =>
            rebuildIfMoved side SidePoint (p, Symmetry.transformPoint sym p)
          | _ => side
    in
      go
    end;

(* As transformSide', but returns the input untouched for the identity
   symmetry. *)
fun transformSide sym side =
    if not (Symmetry.isIdentity sym) then transformSide' sym side else side;

(* Shared side constants: the two players, and "no side". *)
val blackSide = Side (SOME Side.Black);

val whiteSide = Side (SOME Side.White);

val noSide = Side NONE;

(* ------------------------------------------------------------------------- *)
(* Patterns in go positions.                                                 *)
(* ------------------------------------------------------------------------- *)

(* Boolean-valued patterns over go positions, built from logical
   connectives, integer comparisons, side equalities, block connectivity
   and edge constraints. *)
datatype pattern =
  (* Primitive logical operations *)
    Boolean of bool
  | Not of pattern
  | And of pattern * pattern
  | Or of pattern * pattern
  | Implies of pattern * pattern
  | Iff of pattern * pattern

  (* Integer patterns *)
  | LessThan of integer * integer
  | LessEqual of integer * integer
  | IntegerEqual of integer * integer
  | GreaterEqual of integer * integer
  | GreaterThan of integer * integer

  (* Side patterns *)
  | SideEqual of side * side

  (* Block patterns *)
  | ConnectedBlock of Point.point * Point.point

  (* Edge patterns: isValid accepts Edge (e,i) only when i is 0 for LeftEdge
     or BottomEdge, files - 1 for RightEdge, and ranks - 1 for TopEdge *)
  | Edge of edge * int;

local
  (* Position of each constructor in the datatype declaration; used to
     order patterns built from different constructors. *)
  fun patternRank pattern =
      case pattern of
        Boolean _ => 0
      | Not _ => 1
      | And _ => 2
      | Or _ => 3
      | Implies _ => 4
      | Iff _ => 5
      | LessThan _ => 6
      | LessEqual _ => 7
      | IntegerEqual _ => 8
      | GreaterEqual _ => 9
      | GreaterThan _ => 10
      | SideEqual _ => 11
      | ConnectedBlock _ => 12
      | Edge _ => 13;

  fun compareIntegers ij = prodCompare compareInteger compareInteger ij;
in
  (* Structural total order on patterns: first by constructor (in
     declaration order), then by the constructor arguments. *)
  fun compare (x,y) =
      case (x,y) of
        (Boolean a, Boolean b) => boolCompare (a,b)
      | (Not a, Not b) => compare (a,b)
      | (And a, And b) => prodCompare compare compare (a,b)
      | (Or a, Or b) => prodCompare compare compare (a,b)
      | (Implies a, Implies b) => prodCompare compare compare (a,b)
      | (Iff a, Iff b) => prodCompare compare compare (a,b)
      | (LessThan a, LessThan b) => compareIntegers (a,b)
      | (LessEqual a, LessEqual b) => compareIntegers (a,b)
      | (IntegerEqual a, IntegerEqual b) => compareIntegers (a,b)
      | (GreaterEqual a, GreaterEqual b) => compareIntegers (a,b)
      | (GreaterThan a, GreaterThan b) => compareIntegers (a,b)
      | (SideEqual a, SideEqual b) =>
        prodCompare compareSide compareSide (a,b)
      | (ConnectedBlock a, ConnectedBlock b) =>
        prodCompare Point.compare Point.compare (a,b)
      | (Edge a, Edge b) => prodCompare compareEdge Int.compare (a,b)
      | _ => Int.compare (patternRank x, patternRank y);
end;

(* Check a pattern against board dimensions: every point it mentions must
   satisfy Dimensions.member, and every Edge (e,i) must carry the coordinate
   of that edge (0 for LeftEdge/BottomEdge, files - 1 for RightEdge,
   ranks - 1 for TopEdge). *)
fun isValid dim =
    let
      val {files,ranks} = dim

      fun onBoard p = Dimensions.member p dim

      fun integersOk (i,j) = isValidInteger dim i andalso isValidInteger dim j

      fun edgeOk (edge,i) =
          case edge of
            LeftEdge => i = 0
          | BottomEdge => i = 0
          | RightEdge => i = files - 1
          | TopEdge => i = ranks - 1

      fun check pattern =
          case pattern of
            Boolean _ => true
          | Not p => check p
          | And pq => both pq
          | Or pq => both pq
          | Implies pq => both pq
          | Iff pq => both pq
          | LessThan ij => integersOk ij
          | LessEqual ij => integersOk ij
          | IntegerEqual ij => integersOk ij
          | GreaterEqual ij => integersOk ij
          | GreaterThan ij => integersOk ij
          | SideEqual (s,t) => isValidSide dim s andalso isValidSide dim t
          | ConnectedBlock (p,q) => onBoard p andalso onBoard q
          | Edge ei => edgeOk ei

      and both (p,q) = check p andalso check q
    in
      check
    end;

(* Apply the symmetry sym to every point, side and edge mentioned in a
   pattern.  Whenever a sub-pattern comes back physically unchanged
   (checked with ==), the original object is reused instead of being
   rebuilt, so a pattern the symmetry does not touch is returned as-is.
   transform wraps this function and skips the identity symmetry. *)
fun transform' sym =
    let
      val Symmetry.Symmetry {rotate,reflect,files,ranks,...} = sym

      (* Map an Edge (edge,i) pattern through the symmetry.  Edges are
         numbered by toIntEdge (RightEdge 0, TopEdge 1, LeftEdge 2,
         BottomEdge 3).  The rotation adds rotate to that number mod 4, and
         i is negated when the edge moves between {RightEdge,TopEdge} and
         {LeftEdge,BottomEdge}.  A reflection then swaps RightEdge and
         LeftEdge (negating i again).  Finally, files (for RightEdge or
         LeftEdge) or ranks (for TopEdge or BottomEdge) is added to i.
         The original pattern is returned if neither the edge nor i changed.
         NOTE(review): the files/ranks offset is added even when i was not
         negated, and the results do not line up with the Edge constraints
         in isValid.  For example, LeftEdge with i = 0, rotate = 2 and no
         reflection becomes RightEdge with i = files, whereas isValid
         expects files - 1 (assuming sym's files/ranks are the board
         dimensions).  Confirm the intended arithmetic against Symmetry. *)
      fun transEdge pattern (edge,i) =
          let
            val e = toIntEdge edge
            val e' = (e + rotate) mod 4
            val i' = if (e' <= 1) = (e <= 1) then i else ~i
            val (e',i') =
                if reflect andalso (e' = 0 orelse e' = 2) then (2 - e', ~i')
                else (e',i')
            val i' = i' + (if e' = 0 orelse e' = 2 then files else ranks)
          in
            if e' = e andalso i' = i then pattern
            else Edge (fromIntEdge e', i')
          end

      (* Rebuild a pattern bottom-up.  Integers, sides and points are
         delegated to transformInteger', transformSide' and
         Symmetry.transformPoint; each node is rebuilt only if some child
         changed. *)
      fun trans pattern =
          case pattern of
            Boolean _ => pattern
          | Not p =>
            let
              val p' = trans p
            in
              if p' == p then pattern else Not p'
            end
          | And (p,q) =>
            let
              val p' = trans p
              and q' = trans q
            in
              if p' == p andalso q' == q then pattern else And (p',q')
            end
          | Or (p,q) =>
            let
              val p' = trans p
              and q' = trans q
            in
              if p' == p andalso q' == q then pattern else Or (p',q')
            end
          | Implies (p,q) =>
            let
              val p' = trans p
              and q' = trans q
            in
              if p' == p andalso q' == q then pattern else Implies (p',q')
            end
          | Iff (p,q) =>
            let
              val p' = trans p
              and q' = trans q
            in
              if p' == p andalso q' == q then pattern else Iff (p',q')
            end
          | LessThan (i,j) =>
            let
              val i' = transformInteger' sym i
              and j' = transformInteger' sym j
            in
              if i' == i andalso j' == j then pattern else LessThan (i',j')
            end
          | LessEqual (i,j) =>
            let
              val i' = transformInteger' sym i
              and j' = transformInteger' sym j
            in
              if i' == i andalso j' == j then pattern else LessEqual (i',j')
            end
          | IntegerEqual (i,j) =>
            let
              val i' = transformInteger' sym i
              and j' = transformInteger' sym j
            in
              if i' == i andalso j' == j then pattern else IntegerEqual (i',j')
            end
          | GreaterEqual (i,j) =>
            let
              val i' = transformInteger' sym i
              and j' = transformInteger' sym j
            in
              if i' == i andalso j' == j then pattern else GreaterEqual (i',j')
            end
          | GreaterThan (i,j) =>
            let
              val i' = transformInteger' sym i
              and j' = transformInteger' sym j
            in
              if i' == i andalso j' == j then pattern else GreaterThan (i',j')
            end
          | SideEqual (s,t) =>
            let
              val s' = transformSide' sym s
              and t' = transformSide' sym t
            in
              if s' == s andalso t' == t then pattern else SideEqual (s',t')
            end
          | ConnectedBlock (p,q) =>
            let
              val p' = Symmetry.transformPoint sym p
              and q' = Symmetry.transformPoint sym q
            in
              if p' == p andalso q' == q then pattern else ConnectedBlock (p',q')
            end
          | Edge e_i => transEdge pattern e_i
    in
      trans
    end;

(* Apply a symmetry to a pattern.  For the identity symmetry the pattern is
   returned unchanged, without traversing it. *)
fun transform sym pattern =
    if not (Symmetry.isIdentity sym) then transform' sym pattern else pattern;

(* The trivially true and trivially false patterns. *)
val alwaysTrue = Boolean true;

val alwaysFalse = Boolean false;

(* Placeholder printer: every pattern is rendered as the same string. *)
val toString : pattern -> string = fn _ => "<pattern>";

val pp = Print.ppMap toString Print.ppString;

(* ------------------------------------------------------------------------- *)
(* Booleans.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* Extract the truth value of a Boolean pattern; NONE for any other pattern. *)
fun destBoolean pat =
    case pat of
      Boolean b => SOME b
    | _ => NONE;

(* Is the pattern a Boolean constant? *)
fun isBoolean pat = Option.isSome (destBoolean pat);

(* Is the pattern exactly the Boolean constant b? *)
fun equalBoolean b pat =
    case destBoolean pat of
      SOME b' => b' = b
    | NONE => false;

val isTrue = equalBoolean true;

val isFalse = equalBoolean false;

(* ------------------------------------------------------------------------- *)
(* Side to move.                                                             *)
(* ------------------------------------------------------------------------- *)

(* Shared "side to move is ..." patterns for each player. *)
val blackToMove = SideEqual (SideToMove, Side (SOME Side.Black));

val whiteToMove = SideEqual (SideToMove, Side (SOME Side.White));

(* The shared side-to-move pattern for the given player. *)
fun mkSideToMove side =
    case side of
      Side.Black => blackToMove
    | Side.White => whiteToMove;

(* The side s in a SideEqual (SideToMove, Side s) pattern; NONE for any
   other pattern (and also when s itself is NONE). *)
fun destSideToMove pat =
    case pat of
      SideEqual (SideToMove, Side s) => s
    | _ => NONE;

(* Does the pattern say that a definite player is to move? *)
fun isSideToMove pat = Option.isSome (destSideToMove pat);

(* Does the pattern say that player s is to move? *)
fun equalSideToMove s pat =
    case destSideToMove pat of
      SOME s' => Side.equal s s'
    | NONE => false;

(* ------------------------------------------------------------------------- *)
(* Individual points.                                                        *)
(* ------------------------------------------------------------------------- *)

local
  (* "point holds colour", where colour is one of the shared side values. *)
  fun stoneOf colour point = SideEqual (SidePoint point, colour);
in
  fun mkBlackStone point = stoneOf blackSide point;

  fun mkWhiteStone point = stoneOf whiteSide point;

  fun mkEmpty point = stoneOf noSide point;
end;

(* The stone constructor for the given player. *)
fun mkSideStone side =
    case side of
      Side.Black => mkBlackStone
    | Side.White => mkWhiteStone;

(* A point holding a stone of either colour: the negation of mkEmpty. *)
fun mkStone point = Not (mkEmpty point);

(* ------------------------------------------------------------------------- *)
(* Normalized patterns.                                                      *)
(* ------------------------------------------------------------------------- *)

type normalizedPattern = pattern;

(* Order two conjuncts of a normalized pattern.  Side-to-move conjuncts sort
   before point conjuncts, and point conjuncts sort by Point.compare.  Two
   conjuncts about the same subject give SOME EQUAL if they name the same
   side and NONE if they name different sides.  Any other shape of pattern
   raises Bug. *)
fun compareNormalizedConj (pat1,pat2) =
    let
      fun agreeOrClash agree = if agree then SOME EQUAL else NONE

      fun compareMovers (Side (SOME s1), Side (SOME s2)) =
          agreeOrClash (Side.equal s1 s2)
        | compareMovers _ =
          raise Bug "unsupported SideEqual (SideToMove,*) pattern"

      fun comparePointSides (Side s1, Side s2) = agreeOrClash (s1 = s2)
        | comparePointSides _ =
          raise Bug "unsupported SideEqual (SidePoint _, *) pattern"
    in
      case (pat1,pat2) of
        (SideEqual (SideToMove,q1), SideEqual (SideToMove,q2)) =>
        compareMovers (q1,q2)
      | (SideEqual (SideToMove,_), SideEqual _) => SOME LESS
      | (SideEqual _, SideEqual (SideToMove,_)) => SOME GREATER
      | (SideEqual (SidePoint pt1, q1), SideEqual (SidePoint pt2, q2)) =>
        (case Point.compare (pt1,pt2) of
           EQUAL => comparePointSides (q1,q2)
         | ord => SOME ord)
      | (SideEqual _, SideEqual _) =>
        raise Bug "unsupported SideEqual (*,_) pattern"
      | _ => raise Bug "unsupported pattern"
    end
    handle Bug bug => raise Bug ("Pattern.compareNormalizedConj: " ^ bug);

(* Recognise a single conjunct of a normalized pattern.  It is either
   "side to move is s" for a definite side s, or "point p holds side s",
   where s may be NONE (the mkEmpty form). *)
fun isNormalizedConj pat =
    case pat of
      SideEqual (SideToMove, Side (SOME _)) => true
    | SideEqual (SidePoint _, Side _) => true
    | _ => false;

(* Does the normalized conjunction pat1 imply the normalized conjunction
   pat2?  Both are a single conjunct or an And-chain of conjuncts in strictly
   ascending compareNormalizedConj order.  pat1 implies pat2 exactly when
   every conjunct of pat2 also occurs in pat1, so the chains are walked in
   step.  Extra conjuncts in pat1 are skipped; a conjunct of pat2 missing
   from pat1 (or contradicted by it) gives false.  This matches the Boolean
   cases of strongerNormalized, where "stronger" means "implies". *)
fun strongerNormalizedConj pat1 pat2 =
    case (pat1,pat2) of
      (And (p1,q1), And (p2,q2)) =>
      (case compareNormalizedConj (p1,p2) of
         SOME EQUAL => strongerNormalizedConj q1 q2
         (* p1 is an extra constraint of pat1: skip it *)
       | SOME LESS => strongerNormalizedConj q1 pat2
         (* p2 is absent from pat1, or contradicted by p1 *)
       | _ => false)
    | (And (p1,q1), _) =>
      (case compareNormalizedConj (p1,pat2) of
         SOME EQUAL => true
       | SOME LESS => strongerNormalizedConj q1 pat2
       | _ => false)
      (* a single conjunct cannot imply two or more distinct conjuncts *)
    | (_, And _) => false
    | _ =>
      (case compareNormalizedConj (pat1,pat2) of
         SOME EQUAL => true
       | _ => false);

local
  (* conj is a normalized conjunct that sorts strictly after prev. *)
  fun ascendsFrom prev conj =
      isNormalizedConj conj andalso
      compareNormalizedConj (prev,conj) = SOME LESS;

  (* Every conjunct of pat is normalized and the chain strictly ascends,
     starting after prev. *)
  fun isNormalizedConjs prev pat =
      case pat of
        And (p,q) => ascendsFrom prev p andalso isNormalizedConjs p q
      | _ => ascendsFrom prev pat;
in
  (* A normalized pattern is a Boolean constant, or a single normalized
     conjunct, or an And-chain of normalized conjuncts in strictly ascending
     compareNormalizedConj order. *)
  fun isNormalized (Boolean _) = true
    | isNormalized (And (p,q)) = isNormalizedConj p andalso isNormalizedConjs p q
    | isNormalized pat = isNormalizedConj pat;
end;

local
  (* revAnd emitted tail: prepend the conjuncts in emitted (a list holding
     the most recently emitted conjunct first) to the conjunction tail,
     restoring their ascending order. *)
  fun revAnd [] tail = tail
    | revAnd (conj :: emitted) tail = revAnd' conj emitted tail

  (* revAnd' conj emitted tail: as revAnd emitted, with conj placed
     immediately in front of tail. *)
  and revAnd' conj emitted tail = revAnd emitted (And (conj,tail));

  (* Merge two normalized, non-Boolean patterns.  Each is a single conjunct
     or an And-chain of conjuncts in strictly ascending compareNormalizedConj
     order.  acc holds the conjuncts emitted so far, most recent first.
     Equal conjuncts are emitted once.  If compareNormalizedConj returns
     NONE (two conjuncts about the same subject naming different sides),
     the result is Boolean false. *)
  fun normAnd acc pat1 pat2 =
      case (pat1,pat2) of
        (And (p1,q1), And (p2,q2)) =>
        (case compareNormalizedConj (p1,p2) of
           NONE => Boolean false
         | SOME LESS => normAnd (p1 :: acc) q1 pat2
         | SOME EQUAL => normAnd (p1 :: acc) q1 q2
         | SOME GREATER => normAnd (p2 :: acc) pat1 q2)
      | (And (p1,q1), _) =>
        (case compareNormalizedConj (p1,pat2) of
           NONE => Boolean false
         | SOME LESS => normAnd (p1 :: acc) q1 pat2
         | SOME EQUAL => revAnd' p1 acc q1
         | SOME GREATER => revAnd' pat2 acc pat1)
      | (_, And (p2,q2)) =>
        (case compareNormalizedConj (pat1,p2) of
           NONE => Boolean false
         | SOME LESS => revAnd' pat1 acc pat2
         | SOME EQUAL => revAnd' pat1 acc q2
         | SOME GREATER => normAnd (p2 :: acc) pat1 q2)
      | _ =>
        (case compareNormalizedConj (pat1,pat2) of
           NONE => Boolean false
         | SOME LESS => revAnd' pat1 acc pat2
         | SOME EQUAL => revAnd acc pat1
           (* pat2 sorts first: emit it, then keep pat1 as the tail *)
         | SOME GREATER => revAnd' pat2 acc pat1);
in
  (* Conjoin two normalized patterns into a single normalized pattern.
     Boolean true is the unit and Boolean false absorbs. *)
  fun andNormalized pat1 pat2 =
      let
(*GomiDebug
        val _ = isNormalized pat1 orelse raise Bug "pat1 unnormalized"
        val _ = isNormalized pat2 orelse raise Bug "pat2 unnormalized"
*)
      in
        case pat1 of
          Boolean b => if b then pat2 else pat1
        | _ =>
          case pat2 of
            Boolean b => if b then pat1 else pat2
          | _ => normAnd [] pat1 pat2
      end
(*GomiDebug
      handle Bug bug => raise Bug ("Pattern.andNormalized: " ^ bug);
*)
end;

(* Normalize a pattern.  And nodes are normalized recursively (descending
   only through And), and their two sides are merged with andNormalized.
   Any other pattern is returned as given. *)
fun normalize pat =
    let
      val pat' =
          case pat of
            And (p,q) =>
            let
              val p' = normalize p
              val q' = normalize q
            in
              andNormalized p' q'
            end
          | _ => pat
(*GomiDebug
      val _ = isNormalized pat' orelse raise Bug "unnormalizable"
*)
    in
      pat'
    end
(*GomiDebug
    handle Bug bug => raise Bug ("Pattern.normalized: " ^ bug);
*)

(* Is the normalized pattern pat1 at least as strong as pat2?  When either
   side is a Boolean constant, this is logical implication pat1 => pat2
   (false implies everything; everything implies true).  Otherwise the
   answer comes from strongerNormalizedConj. *)
fun strongerNormalized pat1 pat2 =
    let
(*GomiDebug
      val _ = isNormalized pat1 orelse raise Bug "pat1 unnormalized"
      val _ = isNormalized pat2 orelse raise Bug "pat2 unnormalized"
*)
    in
      case (destBoolean pat1, destBoolean pat2) of
        (SOME b1, SOME b2) => b2 orelse not b1
      | (SOME b1, NONE) => not b1
      | (NONE, SOME b2) => b2
      | (NONE, NONE) => strongerNormalizedConj pat1 pat2
    end
(*GomiDebug
    handle Bug bug => raise Bug ("Pattern.strongerNormalized: " ^ bug);
*)

(* Partial order on normalized patterns by strength.  Returns SOME GREATER
   when only pat1 is stronger, SOME LESS when only pat2 is, SOME EQUAL when
   each is stronger than the other, and NONE when neither is. *)
fun compareNormalized (pat1,pat2) =
    (let
       val forward = strongerNormalized pat1 pat2
       and backward = strongerNormalized pat2 pat1
     in
       if forward then SOME (if backward then EQUAL else GREATER)
       else if backward then SOME LESS
       else NONE
     end)
(*GomiDebug
    handle Bug bug => raise Bug ("Pattern.compareNormalized: " ^ bug);
*)

(* Is pat1 strictly stronger than pat2? *)
fun strictlyStrongerNormalized pat1 pat2 =
    (compareNormalized (pat1,pat2) = SOME GREATER)
(*GomiDebug
    handle Bug bug => raise Bug ("Pattern.strictlyStrongerNormalized: " ^ bug);
*)

end

(* Patterns ordered by Pattern.compare, for use as map keys and set
   elements. *)
structure PatternOrdered =
struct
  type t = Pattern.pattern

  val compare = Pattern.compare
end

structure PatternMap = KeyMap (PatternOrdered);

structure PatternSet =
struct

local
  structure S = ElementSet (PatternMap);
in
  open S;
end;

(* Print a set with each element printed by Pattern.pp, separated by ","
   and wrapped in "{" "}". *)
local
  val ppElements = Print.ppOpList "," Pattern.pp;
in
  val pp = Print.ppMap toList (Print.ppBracket "{" "}" ppElements);
end;

end
