(* ========================================================================= *)
(* IMAGES                                                                    *)
(* Copyright (c) 2004 Joe Leslie-Hurd, distributed under the MIT license     *)
(* ========================================================================= *)

structure Image :> Image =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* Image types.                                                              *)
(* ------------------------------------------------------------------------- *)

(* An RGB colour. Each channel is expected to lie in [0,1]; this is the *)
(* invariant the BasicDebug checkColour function enforces.             *)
datatype colour = Colour of {red : real, green : real, blue : real};

(* A colour together with an opacity: 0.0 is fully transparent (see    *)
(* transparent) and 1.0 is fully opaque (see solid).                    *)
datatype pixel = Pixel of {colour : colour, opacity : real};

(* A pixel position: x is the column and y is the row, both counted    *)
(* from 0 at the top-left corner.                                      *)
datatype coord = Coord of {x : int, y : int};

(* An image is a width x height rectangle described lazily by a         *)
(* function from coordinates to pixels. Consumers in this file such as  *)
(* calculate and toPlainPpm only query it for 0 <= x < width and        *)
(* 0 <= y < height.                                                     *)
datatype image = Image of {width : int, height : int, pixels : coord -> pixel};

(* ------------------------------------------------------------------------- *)
(* Checking invariants.                                                      *)
(* ------------------------------------------------------------------------- *)

(*BasicDebug
fun checkInterval r =
    (if not (Real.isFinite r) then raise Bug "not finite"
     else if r < 0.0 then raise Bug ("negative: " ^ Real.toString r)
     else if r > 1.0 then raise Bug ("overflow: " ^ Real.toString r)
     else ())
    handle Bug bug => raise Bug ("Image.checkInterval:\n" ^ bug);

fun checkColour c =
    let
      val Colour {red,green,blue} = c

      val () =
          checkInterval red
          handle Bug bug => raise Bug ("red component:\n" ^ bug)

      val () =
          checkInterval green
          handle Bug bug => raise Bug ("green component:\n" ^ bug)

      val () =
          checkInterval blue
          handle Bug bug => raise Bug ("blue component:\n" ^ bug)
    in
      ()
    end
    handle Bug bug => raise Bug ("Image.checkColour:\n" ^ bug);

fun checkPixel p =
    let
      val Pixel {colour,opacity} = p

      val () =
          checkColour colour
          handle Bug bug => raise Bug ("colour component:\n" ^ bug)

      val () =
          checkInterval opacity
          handle Bug bug => raise Bug ("opacity component:\n" ^ bug)
    in
      ()
    end
    handle Bug bug => raise Bug ("Image.checkPixel:\n" ^ bug);
*)

(* ------------------------------------------------------------------------- *)
(* Colours.                                                                  *)
(* ------------------------------------------------------------------------- *)

(* Builds the grey colour whose red, green and blue channels all equal  *)
(* the given level.                                                     *)
val grey = fn level => Colour {blue = level, green = level, red = level};

(* Black and white are the two extremes of the grey scale. *)
val (black, white) = (grey 0.0, grey 1.0);

(* The three primary colours, each with a single channel at full        *)
(* intensity and the other two at zero.                                 *)
val red   = Colour {blue = 0.0, green = 0.0, red = 1.0}
and green = Colour {blue = 0.0, green = 1.0, red = 0.0}
and blue  = Colour {blue = 1.0, green = 0.0, red = 0.0};

(* Root-mean-square difference between the channels of two colours:    *)
(* 0.0 for identical colours and 1.0 between black and white.           *)
fun colourDiff (Colour c1) (Colour c2) =
    let
      fun sq a b = (b - a) * (b - a)

      val total =
          sq (#red c1) (#red c2) + sq (#green c1) (#green c2) +
          sq (#blue c1) (#blue c2)
    in
      Math.sqrt (total / 3.0)
    end;

(* Channel-wise product of two colours, i.e. viewing one colour through *)
(* a filter of the other.                                               *)
fun colourMult (Colour {red = r1, green = g1, blue = b1})
               (Colour {red = r2, green = g2, blue = b2}) =
    Colour {red = r1 * r2, green = g1 * g2, blue = b1 * b2};

(* Builds a colour from integer channel values by dividing each by 255, *)
(* so 0..255 maps onto [0,1]. The inputs are not range-checked.         *)
fun fromRgbColour (r,g,b) =
    let
      fun scale n = Real.fromInt n / 255.0
    in
      Colour {red = scale r, green = scale g, blue = scale b}
    end;

(* ------------------------------------------------------------------------- *)
(* Pixels.                                                                   *)
(* ------------------------------------------------------------------------- *)

(* A fully transparent pixel (opacity 0.0) whose nominal colour is black. *)
val transparent = Pixel {opacity = 0.0, colour = black};

(* Wraps a colour as a fully opaque pixel (opacity 1.0). *)
val solid = fn colour => Pixel {opacity = 1.0, colour = colour};

(* Applies f to a pixel's colour, leaving its opacity unchanged. *)
fun pixelMap f (Pixel {colour, opacity}) =
    Pixel {colour = f colour, opacity = opacity};

local
  (* Splits a total weight of 1.0 between two opacities p and q in      *)
  (* proportion to their sizes; if p + q is not positive the weight is   *)
  (* shared equally.                                                      *)
  fun normalize p q =
      let
        val total = p + q
      in
        if total <= 0.0 then (0.5,0.5)
        else (fn share => (share, 1.0 - share)) (p / total)
      end;

(*BasicDebug
  val normalize = fn p => fn q =>
      let
        val () =
            if Real.isFinite p then ()
            else raise Bug "argument p not finite"

        val () =
            if Real.isFinite q then ()
            else raise Bug "argument q not finite"

        val r as (t,t') = normalize p q

        val () =
            if Real.isFinite t then ()
            else raise Bug "result t not finite"

        val () =
            if Real.isFinite t' then ()
            else raise Bug "result t' not finite"
      in
        r
      end
      handle Bug bug =>
        raise Bug ("Image.superimposePixels.normalize:\n" ^ bug);
*)
in
  (* Composites superP over subP with the standard "over" operator: the *)
  (* result opacity is superZ + (1 - superZ) * subZ, and each channel is *)
  (* the opacity-weighted average of the two colours, using the part of  *)
  (* subP still visible beneath superP.                                  *)
  fun superimposePixels superP subP =
      let
        val Pixel {colour = Colour top, opacity = topZ} = superP
        val Pixel {colour = Colour bottom, opacity = bottomZ} = subP

        (* Opacity the sub pixel still contributes once the super pixel *)
        (* has covered a fraction topZ of it.                            *)
        val visibleZ = (1.0 - topZ) * bottomZ

        val (w, w') = normalize topZ visibleZ

        fun blend pick = w * pick top + w' * pick bottom
      in
        Pixel
          {colour = Colour {red = blend #red, green = blend #green,
                            blue = blend #blue},
           opacity = topZ + visibleZ}
      end;
end;

(*BasicDebug
val superimposePixels = fn superP => fn subP =>
    let
      val () =
          checkPixel superP
          handle Bug bug => raise Bug ("argument superP:\n" ^ bug)

      val () =
          checkPixel subP
          handle Bug bug => raise Bug ("argument subP:\n" ^ bug)

      val p = superimposePixels superP subP

      val () =
          checkPixel p
          handle Bug bug => raise Bug ("result:\n" ^ bug)
    in
      p
    end
    handle Bug bug => raise Bug ("Image.superimposePixels:\n" ^ bug);
*)

(* ------------------------------------------------------------------------- *)
(* Image transformations.                                                    *)
(* ------------------------------------------------------------------------- *)

(* Evaluates every pixel of an image once, in row-major order, and       *)
(* caches the results in an array. Later lookups then become array reads *)
(* instead of re-running the underlying pixel function.                  *)
fun calculate (Image {width, height, pixels = compute}) =
    let
      val cache =
          Array2.tabulate Array2.RowMajor
            (height, width, fn (row,col) => compute (Coord {x = col, y = row}))
    in
      Image
        {width = width, height = height,
         pixels = fn (Coord {x,y}) => Array2.sub (cache,y,x)}
    end;

(* Maps f over every pixel; f also receives the pixel's coordinate. *)
fun cmap f (Image {width, height, pixels}) =
    Image {width = width, height = height, pixels = fn v => f v (pixels v)};

(* Maps f over every pixel, ignoring coordinates. *)
fun map f = cmap (fn _ => f);

(* The sub-image starting at topLeft. bottomRight is exclusive, so the  *)
(* result is (x2 - x1) wide and (y2 - y1) high.                         *)
fun crop {topLeft = Coord {x = left, y = top},
          bottomRight = Coord {x = right, y = bottom}} image =
    let
      val Image {pixels = source, ...} = image

      fun shifted (Coord {x,y}) = source (Coord {x = left + x, y = top + y})
    in
      Image {width = right - left, height = bottom - top, pixels = shifted}
    end;

(* Makes the upper-right triangle (pixels with x > y) transparent,      *)
(* keeping the main diagonal and everything below-left of it.           *)
fun cropDiagonal (Image {width, height, pixels}) =
    let
      fun keep (c as Coord {x,y}) = if y < x then transparent else pixels c
    in
      Image {width = width, height = height, pixels = keep}
    end;

(* Removes size rows starting at row top; the rows below move up to     *)
(* close the gap, so the result is size rows shorter.                   *)
fun chopVertical {top,size} (Image {width, height, pixels}) =
    let
      fun skipRows (Coord {x,y}) =
          pixels (Coord {x = x, y = if top <= y then y + size else y})
    in
      Image {width = width, height = height - size, pixels = skipRows}
    end;

(* Removes size columns starting at column left; the columns to the     *)
(* right move left to close the gap, so the result is size narrower.    *)
fun chopHorizontal {left,size} (Image {width, height, pixels}) =
    let
      fun skipColumns (Coord {x,y}) =
          pixels (Coord {x = if left <= x then x + size else x, y = y})
    in
      Image {width = width - size, height = height, pixels = skipColumns}
    end;

(* Overlays super onto sub, with super's top-left corner placed at      *)
(* topLeft in sub's coordinates. The result has sub's dimensions, so    *)
(* any part of super falling outside sub is dropped.                    *)
fun superimpose {sub, topLeft, super} =
    let
      val Coord {x = left, y = top} = topLeft
      val Image {width = superW, height = superH, pixels = superPixels} = super
      val Image {width, height, pixels = subPixels} = sub

      fun compose v =
          let
            val below = subPixels v
            val Coord {x,y} = v
          in
            if left <= x andalso top <= y andalso
               x - left < superW andalso y - top < superH
            then
              superimposePixels
                (superPixels (Coord {x = x - left, y = y - top})) below
            else below
          end
    in
      Image {width = width, height = height, pixels = compose}
    end;

(* Superimposes super onto the middle of sub. When a size difference is *)
(* odd, the offset is rounded towards negative infinity (SML div).      *)
fun centre {sub, super} =
    let
      val Image {width = subW, height = subH, ...} = sub
      and Image {width = superW, height = superH, ...} = super

      fun offset (outer, inner) = (outer - inner) div 2
    in
      superimpose
        {sub = sub,
         topLeft = Coord {x = offset (subW, superW), y = offset (subH, superH)},
         super = super}
    end;

(* Surrounds an image with a band border pixels wide on every side,     *)
(* filled with borderPixel. The result is 2 * border larger each way.   *)
fun frame (borderPixel,border) (Image {width = w, height = h, pixels = p}) =
    let
      fun inside (Coord {x,y}) =
          border <= x andalso x < w + border andalso
          border <= y andalso y < h + border

      fun framed (c as Coord {x,y}) =
          if inside c then p (Coord {x = x - border, y = y - border})
          else borderPixel
    in
      Image {width = w + 2 * border, height = h + 2 * border, pixels = framed}
    end;

(* Makes transparent every pixel whose colour lies strictly within      *)
(* threshold of colour' (as measured by colourDiff); all other pixels   *)
(* are left unchanged.                                                  *)
fun colourScreen (colour',threshold) =
    map (fn p as Pixel {colour, ...} =>
            if threshold > colourDiff colour' colour then transparent else p);

(* Screens out pixels very close to pure green. *)
val greenScreen = colourScreen (green,0.01);

(* Computes a drop-shadow mask for an image.                             *)
(*                                                                       *)
(* Every result pixel has colour colour. Its opacity is the maximum of   *)
(* the pixel's own opacity and, for each other pixel n in the image,     *)
(* opacityScore n times the average coverage of n's four corners.        *)
(* A corner at distance d from this pixel's centre has coverage:         *)
(*   - 1 if d < distance,                                                *)
(*   - 0 if d > distance + 1,                                            *)
(*   - a linear fall-off from 1 to 0 in between, which anti-aliases the  *)
(*     shadow edge.                                                      *)
(* Pixels that are already fully opaque get opacity 1.0 directly.        *)
fun shadow {distance, opacity = opacityScore, colour} image =
    let
      (* Any pixel more than distance + 1.5 away along either axis has   *)
      (* all four corners beyond distance + 1, so it contributes nothing. *)
      (* This radius therefore covers every contributing pixel. The      *)
      (* clamp at 0 keeps interval's length non-negative.                *)
      val reach = Int.max (0, Real.floor (distance + 1.5))

      val Image {width,height,pixels} = image

      (* Coverage of one corner point (cx,cy) seen from centre (x1,y1). *)
      fun coverage (x1,y1) (cx,cy) =
          let
            val dx = cx - x1 and dy = cy - y1

            val dist = Math.sqrt (dx * dx + dy * dy)
          in
            if dist < distance then 1.0
            else if distance + 1.0 < dist then 0.0
            (* Fade from 1 at distance down to 0 at distance + 1; the *)
            (* previous d - distance ran the ramp the wrong way.      *)
            else distance + 1.0 - dist
          end

      (* Shadow contribution of pixel (x,y) at the point origin. *)
      fun calcOpacity origin (x,y) =
          let
            val within = coverage origin

            val xa = Real.fromInt x - 0.5 and ya = Real.fromInt y - 0.5

            val xb = xa + 1.0 and yb = ya + 1.0

            val score = opacityScore (pixels (Coord {x = x, y = y}))
          in
            score * (within (xa,ya) + within (xa,yb) +
                     within (xb,ya) + within (xb,yb)) / 4.0
          end

      fun maxOpacity v =
          let
            val Pixel {opacity,...} = pixels v
          in
            if opacity >= 1.0 then 1.0
            else
              let
                val Coord {x,y} = v

                val xcoords = interval (x - reach) (2 * reach + 1)
                and ycoords = interval (y - reach) (2 * reach + 1)

                (* In-bounds pixels other than v itself. *)
                fun filt (x',y') =
                    0 <= x' andalso x' < width andalso 0 <= y'
                    andalso y' < height andalso (x <> x' orelse y <> y')

                val coords = List.filter filt (cart xcoords ycoords)

                val origin = (Real.fromInt x, Real.fromInt y)

                fun f (x_y',m) = Real.max (m, calcOpacity origin x_y')
              in
                List.foldl f opacity coords
              end
          end

      fun p v = Pixel {colour = colour, opacity = maxOpacity v}
    in
      Image {width = width, height = height, pixels = p}
    end;

(* Flips an image top-to-bottom: row y of the result is row             *)
(* height - 1 - y of the original.                                      *)
fun mirrorHorizontal (Image {width, height, pixels}) =
    let
      val lastRow = height - 1
    in
      Image
        {width = width, height = height,
         pixels = fn (Coord {x,y}) => pixels (Coord {x = x, y = lastRow - y})}
    end;

(* Flips an image left-to-right: column x of the result is column       *)
(* width - 1 - x of the original.                                       *)
fun mirrorVertical (Image {width, height, pixels}) =
    let
      val lastColumn = width - 1
    in
      Image
        {width = width, height = height,
         pixels = fn (Coord {x,y}) => pixels (Coord {x = lastColumn - x, y = y})}
    end;

(* Transposes an image about its main diagonal: pixel (x,y) of the      *)
(* result is pixel (y,x) of the original, so width and height swap.     *)
fun mirrorDiagonal (Image {width, height, pixels}) =
    Image
      {width = height, height = width,
       pixels = fn (Coord {x,y}) => pixels (Coord {x = y, y = x})};

(* ------------------------------------------------------------------------- *)
(* I/O.                                                                      *)
(* ------------------------------------------------------------------------- *)

local
  open Parse;

  infixr 9 >>++
  infixr 8 ++
  infixr 7 >>
  infixr 6 ||

  (* Removes a "#" comment from one line of input. A comment runs to the *)
  (* end of its line, so the line break it swallowed is put back. Without *)
  (* it, the last token before the comment would run into the first token *)
  (* of the next line, e.g. a line 255#c followed by a line starting with *)
  (* 0 would lex as the single number 2550.                               *)
  fun stripComments s =
      case hdTl (String.fields (equal #"#") s) of
        (_,[]) => s
      | (code,_) => code ^ "\n";

  (* Turns each input line into its characters, with any comment removed. *)
  val charParser = any >> (String.explode o stripComments);

  (* Character-level parser for a non-empty run of decimal digits,       *)
  (* converted to an int; a failed conversion is reported as NoParse.    *)
  val numParser =
      atLeastOne (some Char.isDigit) >>
      (fn l =>
       case Int.fromString (String.implode l) of
         NONE => raise NoParse
       | SOME i => i);

  (* Lexical tokens of a plain PPM file: single letters (such as the P   *)
  (* of the P3 magic number) and non-negative integers.                  *)
  datatype tok = Alpha of char | Num of int;

  fun destNum (Num x) = x | destNum _ = raise Error "Image.destNum";

  val isNum = can destNum;

  (* A single token. A character that is neither a digit nor a letter   *)
  (* raises Error immediately rather than failing with NoParse.         *)
  val tokParser =
      numParser >> Num
      || (some Char.isAlpha >> Alpha)
      || (any >> (fn c => raise Error ("bad char '" ^ str c ^ "'")));

  (* A token preceded by optional whitespace. *)
  val tokenParser = (many (some Char.isSpace) ++ tokParser) >> snd;

  (* Produces a list: a singleton for a token, or [] for a run of        *)
  (* whitespace not followed by a token (i.e. at the end of the input).  *)
  (* This is the list-valued shape the everything calls below consume.   *)
  val tokensParser =
      (tokenParser >> singleton)
      || atLeastOne (some Char.isSpace) >> (K []);

  (* Token-level numParser: accepts a Num token and returns its value.  *)
  (* It shadows the character-level numParser, which tokParser has       *)
  (* already captured, so the two do not interfere.                      *)
  val numParser = some isNum >> destNum;

  (* Header of a plain PPM file: the magic number P3 (lexed as Alpha P    *)
  (* followed by Num 3), then width, height and the maximum channel       *)
  (* value. Because ++ is right-associative the parsed value nests as     *)
  (* (magic,(w,(h,m))).                                                   *)
  val headerParser =
      ((some (equal (Alpha #"P")) ++ some (equal (Num 3)))
       ++ numParser ++ numParser ++ numParser)
      >> (fn (_,(w,(h,m))) => {width = w, height = h, maxval = m});

  (* Parses one pixel as three channel values (red, green, blue), each   *)
  (* scaled from 0..maxval into [0,1].                                   *)
  (* Raises Error in two cases:                                          *)
  (*   - maxval is not positive, since scaling would divide by zero or   *)
  (*     flip signs and silently produce NaN or negative channels;       *)
  (*   - a value lies outside 0..maxval.                                  *)
  fun pixelParser maxval =
      let
        val _ = 0 < maxval orelse raise Error "pixelParser: maxval not positive"

        val m = Real.fromInt maxval

        fun scale x =
            if x < 0 then raise Error "pixelParser: negative value"
            else if maxval < x then raise Error "pixelParser: value too big"
            else Real.fromInt x / m
      in
        (numParser ++ numParser ++ numParser)
        >> (fn (r,(g,b)) =>
            Colour {red = scale r, green = scale g, blue = scale b})
      end;

  (* Builds an opaque image from a stream of colours in row-major order   *)
  (* (left to right, then top to bottom). For width > 0 the stream must   *)
  (* hold exactly height full rows of width colours; otherwise Error is   *)
  (* raised with "too many pixels" or "not enough pixels".                *)
  fun mkImage width height pixelStream =
      let
        val grid = Array2.array (height,width,black)

        (* Stores colour c at column x of row y and returns where the next *)
        (* colour goes.                                                   *)
        fun store (c,(x,y)) =
            if y = height then raise Error "too many pixels"
            else
              (Array2.update (grid,y,x,c);
               if x = width - 1 then (0, y + 1) else (x + 1, y))

        val (_,rows) = Stream.foldl store (0,0) pixelStream
      in
        if rows <> height then raise Error "not enough pixels"
        else
          Image
            {width = width, height = height,
             pixels = fn (Coord {x,y}) => solid (Array2.sub (grid,y,x))}
      end;
in
  (* Reads an image from a plain (P3) PPM file; every pixel is fully     *)
  (* opaque. A NoParse failure is turned into Error "parse error", and   *)
  (* every Error is re-raised with the prefix "Image.fromPlainPpm: ".    *)
  fun fromPlainPpm filename =
      let
        fun load () =
            let
              val tokens =
                  everything tokensParser
                    (everything charParser (Stream.fromTextFile filename))

              val ({width,height,maxval}, rest) = headerParser tokens

              val colours = everything (pixelParser maxval >> singleton) rest
            in
              mkImage width height colours
            end
      in
        (load () handle NoParse => raise Error "parse error")
        handle Error err => raise Error ("Image.fromPlainPpm: " ^ err)
      end
end;

(* Writes an image as a plain (P3) PPM file with maximum channel value   *)
(* 255. Each pixel is first composited over an opaque background colour *)
(* and then written on its own line as three space-separated values.     *)
fun toPlainPpm {filename,background,image} =
    let
      val maxValue = 255

      val Image {width,height,pixels} = image

      val backdrop = solid background

      fun channel r = Int.toString (Real.round (r * Real.fromInt maxValue))

      fun pixelLine (x,y) =
          let
            val Pixel {colour = Colour {red, green, blue}, ...} =
                superimposePixels (pixels (Coord {x = x, y = y})) backdrop
          in
            String.concatWith " " (List.map channel [red, green, blue]) ^ "\n"
          end

      (* Lazily emits the pixel lines from (x,y) onwards in row-major order. *)
      fun rest (x,y) () =
          if y = height then Stream.Nil
          else if x = width then rest (0, y + 1) ()
          else Stream.cons (pixelLine (x,y)) (rest (x + 1, y))

      val header =
          String.concat
            ["P3\n", Int.toString width, " ", Int.toString height, " ",
             Int.toString maxValue, "\n"]
    in
      Stream.toTextFile {filename = filename} (Stream.cons header (rest (0,0)))
    end;

end
