{-# LANGUAGE LambdaCase #-}
module Math.Haskellator.Internal.AstProcessingSteps.Evaluate (
EvalValue
, evaluate
, execute
, mergeUnits
, subtractUnits
) where
import Control.Monad.Except
import Data.Functor
import Math.Haskellator.Internal.Expr
import Math.Haskellator.Internal.Operators
import Math.Haskellator.Internal.Units
import Math.Haskellator.Internal.Utils.Error
-- | The result of evaluating an expression: a 'Double' magnitude paired with
-- the 'Dimension' (a list of unit exponents) it is expressed in.
type EvalValue = Value Dimension
-- | Evaluate an expression and keep only its numeric magnitude; the unit of
-- the result is discarded. Errors from 'execute' are passed through.
evaluate :: Expr -> Either Error Double
evaluate = (<&> value) . execute
-- | Evaluate an expression to a value together with its unit.
--
-- Unit entries whose power ended up as zero (e.g. a unit that was multiplied
-- and then divided away) are removed from the final result.
execute :: Expr -> Either Error EvalValue
execute expr = dropZeroPowers <$> runAstFold (execute' expr)
  where
    dropZeroPowers result = result { unit = filterZeroPower (unit result) }
-- | Fold an expression into a value inside 'SimpleAstFold', handing each kind
-- of node (literal value, binary op, unary op, conversion, variable bindings,
-- variable reference) to the matching @exec*@ handler below.
execute' :: Expr -> SimpleAstFold EvalValue
execute' expr =
    partiallyFoldExprM
        execVal
        execBinOp
        execUnaryOp
        execConversion
        execVarBinds
        execVar
        expr
-- | A literal value evaluates to itself.
execVal :: EvalValue -> SimpleAstFold EvalValue
execVal = pure
-- | Apply a binary operator to two evaluated operands.
--
-- * 'Plus' and 'Minus' require both operands to have the same dimension.
--   Dimensions are compared with 'dimensionsAgree', i.e. independently of the
--   order of their units and of zero-power entries, so operands whose units
--   were merged in a different order (or that carry cancelled units) can be
--   added. The result keeps the left operand's unit list.
-- * 'Mult' and 'Div' combine the magnitudes and add / subtract the unit
--   exponents via 'mergeUnits' / 'subtractUnits'.
-- * 'Pow' requires a dimensionless exponent (zero-power units are ignored).
--   Every unit exponent of the base is multiplied by the exponent; that
--   product must be a whole number for each unit, otherwise a 'RuntimeError'
--   is thrown instead of silently rounding the unit's power.
--
-- Any other operator is reported as an 'ImplementationError'.
execBinOp :: EvalValue -> Op -> EvalValue -> SimpleAstFold EvalValue
execBinOp lhs Plus rhs
    | dimensionsAgree (unit lhs) (unit rhs) =
        -- Give rhs the lhs unit list so 'combineValues', which compares unit
        -- lists structurally, accepts the pair.
        return $ combineValues (+) lhs (rhs { unit = unit lhs })
    | otherwise = throwError $ Error RuntimeError $
        "Cannot add units " ++ show (unit lhs) ++ " and " ++ show (unit rhs)
execBinOp lhs Minus rhs
    | dimensionsAgree (unit lhs) (unit rhs) =
        return $ combineValues (-) lhs (rhs { unit = unit lhs })
    | otherwise = throwError $ Error RuntimeError $
        "Cannot subtract units " ++ show (unit lhs) ++ " and " ++ show (unit rhs)
execBinOp lhs Mult rhs =
    return $ Value (value lhs * value rhs) (mergeUnits (unit lhs) (unit rhs))
execBinOp lhs Div rhs =
    return $ Value (value lhs / value rhs) (subtractUnits (unit lhs) (unit rhs))
execBinOp lhs Pow rhs
    | not (null (filterZeroPower (unit rhs))) =
        throwError $ Error RuntimeError "Exponentiation of units is not supported"
    | otherwise = case traverse scalePower (unit lhs) of
        Just scaledUnits -> return $ Value (value lhs ** expo) scaledUnits
        Nothing -> throwError $ Error RuntimeError $
            "Cannot raise units " ++ show (unit lhs) ++ " to the power " ++ show expo
  where
    expo = value rhs
    -- Multiply one unit exponent by the (dimensionless) exponent, failing when
    -- the result is not a whole number. A zero power stays zero for any
    -- exponent (this also avoids 0 * Infinity = NaN).
    scalePower ue
        | power ue == 0 = Just ue
        | fromIntegral scaled == p = Just (ue { power = scaled })
        | otherwise = Nothing
      where
        p = fromIntegral (power ue) * expo
        scaled = round p :: Int
execBinOp _ op _ = throwError $ Error ImplementationError $
    "Unknown binary operator " ++ show op

-- | Whether two dimensions describe the same physical dimension, ignoring the
-- order of their units and any zero-power entries.
--
-- NOTE(review): this assumes each unit occurs at most once per dimension
-- (which 'mergeUnits' / 'subtractUnits' preserve for such inputs); confirm
-- for dimensions constructed elsewhere.
dimensionsAgree :: Dimension -> Dimension -> Bool
dimensionsAgree a b = null (filterZeroPower (subtractUnits a b))
-- | Apply a unary operator to an evaluated operand.
--
-- Only 'UnaryMinus' is supported: it replaces the magnitude @x@ with @0 - x@
-- and keeps the unit. Any other operator is an 'ImplementationError'.
execUnaryOp :: Op -> EvalValue -> SimpleAstFold EvalValue
execUnaryOp UnaryMinus operand = pure (mapValue (\x -> 0 - x) operand)
execUnaryOp op _ =
    throwError . Error ImplementationError $ "Unknown unary operator " ++ show op
-- | Unit conversions are not performed by this evaluator; reaching this
-- handler always raises an 'ImplementationError'.
execConversion :: EvalValue -> Dimension -> SimpleAstFold EvalValue
execConversion _value _target =
    throwError (Error ImplementationError "Conversion is handled elsewhere")
-- | Evaluate @body@ in a new scope that contains the given bindings.
--
-- Each bound expression is stored as an unevaluated 'Expr' thunk; 'execVar'
-- evaluates it on first lookup.
execVarBinds :: Bindings Expr -> Expr -> SimpleAstFold EvalValue
execVarBinds bindings body = runInNewScope (bindVars thunks >> execute' body)
  where
    thunks = fmap (fmap Expr) bindings
-- | Look up a variable by name.
--
-- A binding that has already been evaluated ('Result') is returned directly.
-- An unevaluated binding ('Expr') is evaluated with 'execute'', stored back
-- via 'bindVar' as a 'Result' so later lookups reuse it, and then returned.
--
-- NOTE(review): the expression is evaluated in the scope that is active at
-- lookup time, not necessarily the one it was bound in. Confirm that this
-- matches the intended scoping rules.
execVar :: String -> SimpleAstFold EvalValue
execVar name = do
    thunk <- getVarBinding name
    case thunk of
        Result cached -> return cached
        Expr pending  -> do
            computed <- execute' pending
            bindVar name (Result computed)
            return computed
-- | Dimension of a product: exponents of units present on both sides are
-- added. Units present on only one side follow, left-only units before
-- right-only units.
mergeUnits :: Dimension -> Dimension -> Dimension
mergeUnits lhs rhs = concat [summed, lhsOnly, rhsOnly]
  where
    (matched, (lhsOnly, rhsOnly)) = findPairs lhs rhs
    summed = map (\(a, b) -> a { power = power a + power b }) matched
-- | Dimension of a quotient: for units present on both sides, the right
-- exponent is subtracted from the left. Left-only units follow unchanged,
-- then right-only units with negated exponents.
subtractUnits :: Dimension -> Dimension -> Dimension
subtractUnits lhs rhs = concat [differences, lhsOnly, map flipPower rhsOnly]
  where
    (matched, (lhsOnly, rhsOnly)) = findPairs lhs rhs
    differences = map (\(a, b) -> a { power = power a - power b }) matched
-- | Negate the exponent of a unit, keeping the unit itself.
flipPower :: UnitExp -> UnitExp
flipPower ue = ue { power = negate (power ue) }
-- | Match each unit of the first dimension with the first not-yet-matched
-- unit of the second dimension that has the same 'dimUnit'.
--
-- Returns @(pairs, (unmatchedLeft, unmatchedRight))@. Each list keeps the
-- original relative order of its elements.
findPairs :: Dimension -> Dimension -> ([(UnitExp, UnitExp)], (Dimension, Dimension))
findPairs [] remaining = ([], ([], remaining))
findPairs (x:xs) ys =
    let (hit, (leftover, rest)) = findPair x ys
        (hits, (leftovers, unmatchedRight)) = findPairs xs rest
    in (hit ++ hits, (leftover ++ leftovers, unmatchedRight))
-- | Search the list for the first unit with the same 'dimUnit' as @x@.
--
-- On a match, returns @([(x, y)], ([], others))@, where @others@ is the
-- list with @y@ removed. Without a match, returns @([], ([x], ys))@.
findPair :: UnitExp -> Dimension -> ([(UnitExp, UnitExp)], (Dimension, Dimension))
findPair x [] = ([], ([x], []))
findPair x (y:ys)
    | dimUnit x == dimUnit y = ([(x, y)], ([], ys))
    | otherwise = (found, (unmatched, y : rest))
  where
    (found, (unmatched, rest)) = findPair x ys
-- | Drop every unit whose exponent is zero.
filterZeroPower :: Dimension -> Dimension
filterZeroPower units = [ue | ue <- units, power ue /= 0]
-- | Apply a function to the magnitude of a value, leaving its unit untouched.
mapValue :: (Double -> Double) -> Value u -> Value u
mapValue f v = v { value = f (value v) }
-- | Combine the magnitudes of two values that carry equal units, keeping that
-- unit.
--
-- Calls 'error' when the units differ; callers are expected to check unit
-- equality first.
combineValues :: Eq u => (Double -> Double -> Double) -> Value u -> Value u -> Value u
combineValues f a b
    | unit a == unit b = Value (f (value a) (value b)) (unit a)
    | otherwise        = error "Cannot map values with different units"