-- | Models an expression tree
--
-- Examples:
--
-- >>> show $ BinOp (Val $ Value 1.0 $ meter 1) Plus (BinOp (Val $ Value 2 $ multiplier 1) Mult (Val $ Value 3.0 $ meter 1))
-- "(1.0 m + (2.0 * 3.0 m))"
--
-- >>> show $ BinOp (BinOp (Val $ Value 1.0 $ meter 1) Plus (Val $ Value 2.0 $ meter 1)) Mult (Val $ Value 3.0 $ multiplier 1)
-- "((1.0 m + 2.0 m) * 3.0)"

module Math.Haskellator.Internal.Expr (
      AstFold
    , AstValue
    , Bindings
    , Expr (..)
    , SimpleAstFold
    , Thunk (..)
    , Value (..)
    , bindVar
    , bindVars
    , foldExpr
    , getVarBinding
    , partiallyFoldExprM
    , runAstFold
    , runInNewScope
    ) where

import Control.Applicative
import Control.Monad.Except
import Control.Monad.State

import Data.List (intercalate)
import Data.Map (Map, insert, (!?))

import Math.Haskellator.Internal.Operators
import Math.Haskellator.Internal.TH.UnitGeneration
import Math.Haskellator.Internal.Units
import Math.Haskellator.Internal.Utils.Composition
import Math.Haskellator.Internal.Utils.Error
import Math.Haskellator.Internal.Utils.Stack

-- | The specific 'Value' type used in the expression tree: a 'Value'
-- parameterised by 'Dimension'
type AstValue = Value Dimension

-- | An association list of variable bindings, each pairing a variable name
-- with an arbitrary value of type @a@
type Bindings a = [(String, a)]

-- | The expression tree
data Expr = Val AstValue -- ^ a literal value
          | BinOp Expr Op Expr -- ^ a binary expression (like +, -, *, /, ^)
          | UnaryOp Op Expr -- ^ a unary expression (like -)
          | Conversion Expr Dimension -- ^ a conversion (1m [km]). If present, this node is the root of the tree.
          | VarBindings (Bindings Expr) Expr -- ^ a variable binding expression: the bindings and the expression they are attached to
          | Var String -- ^ a reference to a variable by name

-- | Either an unevaluated expression or a value of type @a@, to allow for
-- lazy evaluation
data Thunk a = Expr Expr -- ^ The unevaluated expression
             | Result a -- ^ A value of type @a@ (presumably the already evaluated result)

-- | Folds an expression tree bottom-up: every sub-expression is folded first
-- and the handler for the enclosing node is then applied to the folded
-- children.
foldExpr :: (AstValue -> a)        -- ^ function that folds a value
         -> (a -> Op -> a -> a)    -- ^ function that folds a binary expression
         -> (Op -> a -> a)         -- ^ function that folds a unary expression
         -> (a -> Dimension -> a)  -- ^ function that folds a conversion expression
         -> (Bindings a -> a -> a) -- ^ function that folds variable bindings
         -> (String -> a)          -- ^ function that folds a variable
         -> Expr                   -- ^ the 'Expr' to fold over
         -> a                      -- ^ the resulting value
foldExpr fv fb fu fc fvb fvn = go
  where
    go (Val v)            = fv v
    go (BinOp lhs op rhs) = fb (go lhs) op (go rhs)
    go (UnaryOp op e)     = fu op (go e)
    go (Conversion e u)   = fc (go e) u
    -- Both the bound expressions and the body are folded; names are kept.
    go (VarBindings bs e) = fvb [(name, go bound) | (name, bound) <- bs] (go e)
    go (Var name)         = fvn name

-- | Encapsulates the result @b@ of folding an expression tree and holds the current
-- state of variable bindings to values of type @a@. The state is a 'Stack' of
-- maps from variable names to 'Thunk's; failures are reported as 'Error's.
type AstFold a b = ExceptT Error (State (Stack (Map String (Thunk a)))) b

-- | Simplified version of 'AstFold' whose result type is the same type @a@
-- that it binds to variables
type SimpleAstFold a = AstFold a a

-- | Retrieves the 'Thunk' bound to a variable name
--
-- Every map on the scope stack is consulted in the stack's fold order and the
-- first one that contains the name wins. If no map contains the name, a
-- 'RuntimeError' is thrown.
getVarBinding :: String              -- ^ the variable name
              -> AstFold a (Thunk a) -- ^ the 'Thunk' bound to the variable
getVarBinding n = do
    scopes <- get
    -- A right fold yields the same first match as a left fold over '<|>',
    -- but can stop as soon as a binding has been found.
    case foldr (\scope found -> (scope !? n) <|> found) Nothing scopes of
        Just thunk -> return thunk
        Nothing    -> throwError $ Error RuntimeError $ "Variable '" ++ n ++ "' not in scope"

-- | Binds a 'Thunk' to a variable name
--
-- The binding is inserted into the map selected by 'mapTop' (presumably the
-- top-most scope of the stack). An existing binding for the same name in that
-- map is replaced.
bindVar :: String       -- ^ the variable name
        -> Thunk a      -- ^ the 'Thunk' to bind
        -> AstFold a ()
bindVar name thunk = modify $ mapTop $ insert name thunk

-- | Binds multiple variable names
--
-- The bindings are applied one after another in list order via 'bindVar'.
bindVars :: Bindings (Thunk a) -- ^ the names and the 'Thunk's to bind to them
         -> AstFold a ()
bindVars = mapM_ (uncurry bindVar)

-- | Evaluates a 'SimpleAstFold' inside a new (and empty) scope
--
-- An empty map is pushed onto the scope stack, the computation is run, and
-- the stack is popped again before the computation's result is returned.
--
-- NOTE(review): if the computation throws, the pop is skipped, so the pushed
-- scope stays on the stack for any handler that later catches the error.
runInNewScope :: SimpleAstFold a -- ^ the computation to run
              -> SimpleAstFold a -- ^ the computation's result
runInNewScope f = enterScope *> f <* leaveScope
  where
    enterScope = modify (push mempty)
    leaveScope = modify (snd . pop)

-- | Runs a 'SimpleAstFold' computation
--
-- The initial state is the stack obtained by pushing one empty map onto
-- 'mempty'. The final state is discarded.
runAstFold :: SimpleAstFold a -- ^ the computation to run
           -> Either Error a  -- ^ the computation's result or an error
runAstFold f = evalState (runExceptT f) (push mempty mempty)

-- | Like 'foldExpr', but does not fold into variable bindings and returns a monadic
-- result
--
-- Children of binary, unary and conversion nodes are folded before the node's
-- handler runs (left operand before right operand). A 'VarBindings' node is
-- handed to its handler unfolded.
partiallyFoldExprM :: (AstValue -> SimpleAstFold a)               -- ^ function that folds a value
                   -> (a -> Op -> a -> SimpleAstFold a)           -- ^ function that folds a binary expression
                   -> (Op -> a -> SimpleAstFold a)                -- ^ function that folds a unary expression
                   -> (a -> Dimension -> SimpleAstFold a)         -- ^ function that folds a conversion expression
                   -> (Bindings Expr -> Expr -> SimpleAstFold a)  -- ^ function that handles the unfolded variable bindings and their expression
                   -> (String -> SimpleAstFold a)                 -- ^ function that folds a variable
                   -> Expr                                        -- ^ the 'Expr' to fold over
                   -> SimpleAstFold a                             -- ^ the resulting computation
partiallyFoldExprM fv fb fu fc fbv fvar = go
    where
        go (Val v) = fv v
        go (BinOp lhs op rhs) = do
            -- Effects of the left operand run before those of the right one.
            l <- go lhs
            r <- go rhs
            fb l op r
        go (UnaryOp op e) = go e >>= fu op
        go (Conversion e u) = do
            v <- go e
            fc v u
        -- Bindings are passed on as raw expressions; the handler decides
        -- whether and when to evaluate them.
        go (VarBindings bs expr) = fbv bs expr
        go (Var n) = fvar n

-- | Two expressions are considered equal when their 'show' renderings are
-- identical.
instance Eq Expr where
    lhs == rhs = show lhs == show rhs

-- | Renders an expression tree as a string by folding it with 'foldExpr'.
--
-- Binary and unary expressions as well as variable bindings are wrapped in
-- parentheses, a conversion appends its target in square brackets, and a
-- variable is rendered as its name.
instance Show Expr where
    show = foldExpr show renderBinOp renderUnaryOp renderConversion renderBindings id
      where
        -- "(lhs op rhs)"
        renderBinOp lhs op rhs = concat ["(", lhs, " ", show op, " ", rhs, ")"]
        -- "(op operand)" with no space between operator and operand
        renderUnaryOp op operand = concat ["(", show op, operand, ")"]
        -- "expr[unit]"
        renderConversion inner unit = concat [inner, "[", show unit, "]"]
        -- "(n1 = e1, n2 = e2 -> body)"
        renderBindings bs body =
            concat ["(", intercalate ", " (map renderBinding bs), " -> ", body, ")"]
        renderBinding (name, bound) = name ++ " = " ++ bound