module Math.Haskellator.Internal.Expr (
AstFold
, AstValue
, Bindings
, Expr (..)
, SimpleAstFold
, Thunk (..)
, Value (..)
, bindVar
, bindVars
, foldExpr
, getVarBinding
, partiallyFoldExprM
, runAstFold
, runInNewScope
) where
import Control.Applicative
import Control.Monad.Except
import Control.Monad.State
import Data.List (intercalate)
import Data.Map (Map, insert, (!?))
import Math.Haskellator.Internal.Operators
import Math.Haskellator.Internal.TH.UnitGeneration
import Math.Haskellator.Internal.Units
import Math.Haskellator.Internal.Utils.Composition
import Math.Haskellator.Internal.Utils.Error
import Math.Haskellator.Internal.Utils.Stack
-- | The literal values carried by 'Val' leaves: a 'Value' specialised to
-- 'Dimension'.
type AstValue = Value Dimension
-- | An association list pairing variable names with the thing bound to them
-- (an 'Expr' inside the AST, or any other payload @a@ during a fold).
type Bindings a = [(String, a)]
-- | The expression AST.
--
-- Constructor continuation lines are indented so that the layout rule keeps
-- them inside this declaration.
data Expr
    = Val AstValue
      -- ^ A literal value.
    | BinOp Expr Op Expr
      -- ^ A binary operation: left operand, operator, right operand.
    | UnaryOp Op Expr
      -- ^ A unary operation: operator, operand.
    | Conversion Expr Dimension
      -- ^ An expression paired with a target 'Dimension' (rendered by the
      -- 'Show' instance as @e[u]@).
    | VarBindings (Bindings Expr) Expr
      -- ^ A list of named bindings followed by the body expression.
    | Var String
      -- ^ A reference to a variable by name.
-- | What a variable name can be bound to in the fold state: either an 'Expr'
-- or a value of the fold's result type @a@.
--
-- NOTE(review): presumably 'Expr' holds a not-yet-evaluated binding and
-- 'Result' a cached evaluation -- confirm at the use sites.
data Thunk a
    = Expr Expr
      -- ^ Wraps an expression.
    | Result a
      -- ^ Wraps a value of the result type.
-- | Structural fold over an 'Expr'.
--
-- Every constructor is replaced by the matching handler; sub-expressions are
-- folded first and their results are passed to the handler. For
-- 'VarBindings' both the bound expressions and the body are folded.
foldExpr :: (AstValue -> a)          -- ^ handler for 'Val'
         -> (a -> Op -> a -> a)      -- ^ handler for 'BinOp' (folded lhs, operator, folded rhs)
         -> (Op -> a -> a)           -- ^ handler for 'UnaryOp' (operator, folded operand)
         -> (a -> Dimension -> a)    -- ^ handler for 'Conversion' (folded expression, dimension)
         -> (Bindings a -> a -> a)   -- ^ handler for 'VarBindings' (folded bindings, folded body)
         -> (String -> a)            -- ^ handler for 'Var'
         -> Expr                     -- ^ expression to fold
         -> a
foldExpr fv fb fu fc fvb fvn = doIt
  where
    doIt (Val v)            = fv v
    doIt (BinOp e1 o e2)    = fb (doIt e1) o (doIt e2)
    doIt (UnaryOp o e)      = fu o (doIt e)
    doIt (Conversion e u)   = fc (doIt e) u
    -- 'fmap' on a pair maps the second component, i.e. the bound expression.
    doIt (VarBindings bs e) = fvb (map (fmap doIt) bs) (doIt e)
    doIt (Var n)            = fvn n
-- | The monad used while folding an AST: 'ExceptT' 'Error' layered over a
-- 'State' whose state is a 'Stack' of maps from variable names to
-- @'Thunk' a@. The fold yields a @b@.
type AstFold a b = ExceptT Error (State (Stack (Map String (Thunk a)))) b
-- | An 'AstFold' whose result type equals the type stored in its thunks.
type SimpleAstFold a = AstFold a a
-- | Look up the thunk bound to a variable name.
--
-- Every map on the scope stack is consulted and the first hit in the
-- stack's fold order is returned.
-- NOTE(review): presumably that order is innermost scope first -- confirm
-- against the 'Foldable' instance of 'Stack'.
--
-- Throws a 'RuntimeError' if no scope binds the name.
getVarBinding :: String
              -> AstFold a (Thunk a)
getVarBinding n = do
    context <- get
    -- A right fold with '<|>' yields the same first hit as the left fold it
    -- replaces, but can stop as soon as a binding is found.
    case foldr (\scope rest -> (scope !? n) <|> rest) Nothing context of
        Just v  -> return v
        Nothing -> throwError $ Error RuntimeError $ "Variable '" ++ n ++ "' not in scope"
-- | Bind a name to a thunk in the map on top of the scope stack, replacing
-- any binding of the same name in that map.
bindVar :: String -> Thunk a -> AstFold a ()
bindVar name thunk = modify (mapTop (insert name thunk))
-- | Bind every @(name, thunk)@ pair, in list order, via 'bindVar'. When a
-- name occurs more than once the last occurrence wins.
bindVars :: Bindings (Thunk a) -> AstFold a ()
bindVars = mapM_ (uncurry bindVar)
-- | Run a fold inside a fresh, empty scope.
--
-- An empty map is pushed onto the scope stack before the action runs and
-- popped again afterwards. The scope is also popped when the action throws,
-- so a caller that recovers with 'catchError' does not see leftover inner
-- bindings; the error itself is re-thrown unchanged.
runInNewScope :: SimpleAstFold a
              -> SimpleAstFold a
runInNewScope f = do
    modify $ push mempty
    result <- catchError f (\err -> leaveScope >> throwError err)
    leaveScope
    return result
  where
    -- Drop the top-most scope, keeping the rest of the stack.
    leaveScope :: AstFold b ()
    leaveScope = modify $ snd . pop
-- | Run a fold to completion.
--
-- The initial state is an empty stack with one empty scope pushed onto it.
-- The final state is discarded; the result is either the thrown 'Error' or
-- the fold's value.
runAstFold :: SimpleAstFold a
           -> Either Error a
runAstFold fold = evalState (runExceptT fold) initialScopes
  where
    initialScopes = push mempty mempty
-- | Monadic fold over an 'Expr' inside 'SimpleAstFold'.
--
-- Operands of 'BinOp' (left first, then right), 'UnaryOp' and 'Conversion'
-- are folded before their handler runs. 'VarBindings' is the exception: its
-- bindings and body are handed to the handler unfolded, so the handler
-- decides when and in which scope to evaluate them.
partiallyFoldExprM :: (AstValue -> SimpleAstFold a)            -- ^ handler for 'Val'
                   -> (a -> Op -> a -> SimpleAstFold a)        -- ^ handler for 'BinOp'
                   -> (Op -> a -> SimpleAstFold a)             -- ^ handler for 'UnaryOp'
                   -> (a -> Dimension -> SimpleAstFold a)      -- ^ handler for 'Conversion'
                   -> (Bindings Expr -> Expr -> SimpleAstFold a) -- ^ handler for 'VarBindings' (raw bindings and body)
                   -> (String -> SimpleAstFold a)              -- ^ handler for 'Var'
                   -> Expr                                     -- ^ expression to fold
                   -> SimpleAstFold a
partiallyFoldExprM fv fb fu fc fbv fvar = doIt
  where
    doIt (Val v) = fv v
    doIt (BinOp lhs op rhs) = do
        l <- doIt lhs
        r <- doIt rhs
        fb l op r
    doIt (UnaryOp op e) = do
        v <- doIt e
        fu op v
    doIt (Conversion e u) = doIt e >>= \v -> fc v u
    doIt (VarBindings bs expr) = fbv bs expr
    doIt (Var n) = fvar n
-- | Two expressions are equal when their 'show' renderings are equal.
--
-- NOTE(review): this is textual, not structural, equality -- distinct trees
-- that render to the same string compare equal (e.g. 'Var' names are shown
-- verbatim).
instance Eq Expr where
    e1 == e2 = show e1 == show e2
-- | Renders an expression as text by folding it with 'foldExpr':
--
-- * values use their own 'show';
-- * binary operations become @(lhs op rhs)@;
-- * unary operations become @(op operand)@ with no space in between;
-- * conversions become @expr[dimension]@;
-- * variable bindings become @(n1 = e1, n2 = e2 -> body)@;
-- * variables are rendered as their bare name.
instance Show Expr where
    show = foldExpr show showBinOp showUnaryOp showConversion showVarBinds id
      where
        showBinOp :: String -> Op -> String -> String
        showBinOp e1 o e2 = "(" ++ e1 ++ " " ++ show o ++ " " ++ e2 ++ ")"

        showUnaryOp :: Op -> String -> String
        showUnaryOp o e = "(" ++ show o ++ e ++ ")"

        showConversion :: String -> Dimension -> String
        showConversion e u = e ++ "[" ++ show u ++ "]"

        showVarBinds :: Bindings String -> String -> String
        showVarBinds bs e = "(" ++ intercalate ", " (showVarBind <$> bs) ++ " -> " ++ e ++ ")"

        showVarBind :: (String, String) -> String
        showVarBind (n, e) = n ++ " = " ++ e