Servant, persistent, and DSLs

A few days ago I hinted on reddit at how I use servant with persistent. My approach relies on a few DSLs to simplify the definition of the web application, and to make sure that all important side effects (authentication, access control, database failures, etc.) are handled in a single place. In this post I will define one DSL for access rights management and another for writing the webservice itself.

I tried to extract the gist of my engine in this post, so it might seem a bit overengineered for such a simple example. Also, I only made sure it compiled and did not test it, so it might be buggy :)

This post is a literate Haskell file. It assumes you already know servant, persistent and operational (for the last one, familiarity with free is enough).

{-# LANGUAGE KitchenSink #-}

I am using servant, persistent and lens in this example, so there is quite a bit of boilerplate, as expected.

{-# LANGUAGE ConstraintKinds            #-}
{-# LANGUAGE DataKinds                  #-}
{-# LANGUAGE DeriveGeneric              #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE OverloadedStrings          #-}
{-# LANGUAGE PolyKinds                  #-}
{-# LANGUAGE QuasiQuotes                #-}
{-# LANGUAGE TemplateHaskell            #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE TypeOperators              #-}

module Main where

import           Control.Lens
import           Control.Monad
import           Control.Monad.Error.Class
import           Control.Monad.IO.Class
import           Control.Monad.Logger
import qualified Control.Monad.Operational  as O
import           Control.Monad.Operational  hiding (view)
import           Control.Monad.Reader       (ask)
import           Control.Monad.Trans.Either
import           Data.Aeson
import           Data.Int
import qualified Data.Foldable              as F
import           Database.Persist.Sql
import           Database.Persist.Sqlite
import           Database.Persist.TH
import           Data.Text.Lens
import           Data.Text                  (Text)
import qualified Network.Wai.Handler.Warp
import           Servant

import AccessType

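The AccessType module just exports the following definition, along with all the instances required to use it with servant and persistent. Its Ord instance matters, because access levels are compared with >= below, so the constructors go from the weakest to the strongest right: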

data AccessType = NoAccess | ReadOnly | ReadWrite | Owner

Persistent model

A person is a user of the blog. They can have a special attribute that marks them as administrators, meaning they have all rights.

share [ mkPersist sqlSettings { mpsGenerateLenses = True }
      , mkMigrate "migrateAll"]
      [persistLowerCase|
Person json
    name Text
    isAdmin Bool
    UniqueName name
    deriving Show

The BlogPost definition is pretty basic.

BlogPost json
    title Text
    content Text
    deriving Show

Finally we have a table that associates an access right to a person. When no matching record exists, the access right is equivalent to NoAccess.

PostRights json
    person PersonId
    post BlogPostId
    access AccessType
    UniqueRight person post
    deriving Show
|]
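Two properties of this schema matter for what follows. First, a missing PostRights record means NoAccess. Second, access levels are compared with >=, which relies on AccessType deriving Ord in declaration order. That second point is an assumption about the AccessType module, which is not shown. The sketch below models the rights table in plain Haskell to make both properties visible. PersonId and PostId are hypothetical Int stand-ins for persistent's keys. A Data.Map plays the role of the table, and its pair key mirrors UniqueRight.

```haskell
import qualified Data.Map.Strict as Map

-- Assumed to match the AccessType module: the derived Ord follows the
-- declaration order, so NoAccess < ReadOnly < ReadWrite < Owner.
data AccessType = NoAccess | ReadOnly | ReadWrite | Owner
  deriving (Show, Read, Eq, Ord, Enum, Bounded)

-- Hypothetical stand-ins for PersonId and BlogPostId.
type PersonId = Int
type PostId   = Int

-- In-memory model of the PostRights table; the pair plays the role of UniqueRight.
type Rights = Map.Map (PersonId, PostId) AccessType

-- A missing record is equivalent to NoAccess.
accessOf :: Rights -> PersonId -> PostId -> AccessType
accessOf rights person post = Map.findWithDefault NoAccess (person, post) rights

-- True when the actual access level is at least the required one.
atLeast :: AccessType -> AccessType -> Bool
atLeast required actual = actual >= required

exampleRights :: Rights
exampleRights = Map.fromList [((1, 10), Owner), ((2, 10), ReadOnly)]
```

Consider person 3 and post 10. There is no record for that pair, so the person gets NoAccess. As a result, every atLeast check fails for them except atLeast NoAccess.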

API definition

I use a few tricks to define the API. The first one was stolen from this reddit comment:

type CRUD a = DN :> ReqBody '[JSON] a :> Post '[JSON] (MKey a) -- create
         :<|> DN :> Capture "id" (MKey a) :> Get '[JSON] a -- read
         :<|> DN :> Capture "id" (MKey a) :> ReqBody '[JSON] a :> Put '[JSON] () -- update
         :<|> DN :> Capture "id" (MKey a) :> Delete '[JSON] () -- delete

It lets you factor out the four endpoints you need for a basic CRUD interface, and will prove handy when defining the API. It uses a few helper types:

type DN = Header "dn" Text

DN is a header provided by the front-end web server. This web-facing server handles TLS client authentication. When authentication succeeds, it adds the dn header; when it fails, it drops the connection. The header contains the distinguished name of the certificate the user presented. I can trust this value as much as I can trust my PKI, so I do not verify it in the application code.

newtype MKey a = MKey { getMKey :: Int64 }
               deriving (FromJSON, ToJSON, FromText)

_MKey :: ToBackendKey SqlBackend a => Iso' (MKey a) (Key a)
_MKey = iso (toSqlKey . getMKey) (MKey . fromSqlKey)

MKey is a newtype that is isomorphic to persistent's Key, and _MKey is the lens Iso between the two. The newtype is needed because Key does not have the instances required to use it with servant. In this case the user-facing id of each entity is its primary key in the database. Don't try this at home!
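The isomorphism only wraps and unwraps the same Int64. Meanwhile, the phantom parameter a keeps the ids of different entities apart. The sketch below shows that shape without persistent or lens. Key is a hypothetical stand-in for persistent's Key, which is really an associated data family. The functions toKey and fromKey play the roles of the two directions of _MKey, that is, toSqlKey . getMKey and MKey . fromSqlKey.

```haskell
import Data.Int (Int64)

-- Hypothetical stand-in for persistent's Key: an Int64 tagged with the entity type.
newtype Key a = Key Int64 deriving (Show, Eq)

-- The user-facing id, carrying the same phantom parameter.
newtype MKey a = MKey { getMKey :: Int64 } deriving (Show, Eq)

-- Empty types standing in for the persistent entities.
data Person
data BlogPost

-- Forward direction of the isomorphism (toSqlKey . getMKey in the post).
toKey :: MKey a -> Key a
toKey = Key . getMKey

-- Backward direction (MKey . fromSqlKey in the post).
fromKey :: Key a -> MKey a
fromKey (Key n) = MKey n
```

Both functions keep the same a. Passing an MKey Person where a Key BlogPost is expected is therefore a type error.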

The API itself is just made of CRUD endpoints for Person and BlogPost:

type MyApi = "person" :> CRUD Person
        :<|> "post"   :> CRUD BlogPost

myApi :: Proxy MyApi
myApi = Proxy

Permission checking DSL

The first DSL I am going to introduce is for permission checking. It will look overengineered in this toy program, but my actual use case turns out to have some rather funky permissions.

I use the operational package, mostly because I am used to it. It is overkill, as this particular DSL probably only requires an Applicative instance. But since the next DSL requires a full Monad instance, both are defined using the same machinery to keep things consistent.

Let's start by defining the actions:

type PermProgram = Program PermCheck
data PermCheck a where
    IsAdmin       :: PermCheck Bool
    BlogPostRight :: Key BlogPost -> PermCheck AccessType

I only have two of them here, but you should add one for each new kind of access right you need to enforce.

Now here are a few helper functions:

-- checks that the current user is an administrator
isAdmin :: PermProgram Bool
isAdmin = singleton IsAdmin

-- returns the access right of the current user on a particular blog post
blogPostRight :: Key BlogPost -> PermProgram AccessType
blogPostRight = singleton . BlogPostRight

-- checks that an `AccessType` is at least `ReadOnly`
ro :: PermProgram AccessType -> PermProgram Bool
ro = fmap (>= ReadOnly)

-- checks that an `AccessType` is at least `ReadWrite`
rw :: PermProgram AccessType -> PermProgram Bool
rw = fmap (>= ReadWrite)

-- checks that an `AccessType` is at least `Owner`
owner :: PermProgram AccessType -> PermProgram Bool
owner = fmap (>= Owner)

-- helper function for actions that everyone can perform
always :: PermProgram Bool
always = pure True

-- operator for "or-ing" two permission checking actions
(.||) :: PermProgram Bool -> PermProgram Bool -> PermProgram Bool
(.||) = liftM2 (||)

-- operator for "and-ing" two permission checking actions
(.&&) :: PermProgram Bool -> PermProgram Bool -> PermProgram Bool
(.&&) = liftM2 (&&)
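The .|| and .&& operators are built with liftM2. As a consequence, both operands are always run, even when the first one already decides the result. Once checkPerms interprets the program, isAdmin .|| ro (blogPostRight k) still looks up the blog post rights of an administrator. The answer is the same; only the number of queries differs. The sketch below makes this visible with a tracing interpreter. It makes three substitutions:

- operational is replaced with a simplified Program type, which is not operational's actual implementation;
- a hypothetical Int takes the place of Key BlogPost;
- orElse is added as a hypothetical short-circuiting alternative, which the post does not use.

```haskell
{-# LANGUAGE GADTs #-}
import Control.Monad (ap, liftM, liftM2, (>=>))

-- Simplified stand-in for operational's Program, in its "view" form:
-- either a final value or an instruction followed by a continuation.
data Program instr a where
  Return :: a -> Program instr a
  (:>>=) :: instr b -> (b -> Program instr a) -> Program instr a

instance Functor (Program instr) where
  fmap = liftM

instance Applicative (Program instr) where
  pure  = Return
  (<*>) = ap

instance Monad (Program instr) where
  Return a   >>= k = k a
  (i :>>= f) >>= k = i :>>= (f >=> k)

singleton :: instr a -> Program instr a
singleton i = i :>>= Return

data AccessType = NoAccess | ReadOnly | ReadWrite | Owner
  deriving (Show, Eq, Ord)

type PostId = Int -- hypothetical stand-in for Key BlogPost

data PermCheck a where
  IsAdmin       :: PermCheck Bool
  BlogPostRight :: PostId -> PermCheck AccessType

type PermProgram = Program PermCheck

isAdmin :: PermProgram Bool
isAdmin = singleton IsAdmin

blogPostRight :: PostId -> PermProgram AccessType
blogPostRight = singleton . BlogPostRight

ro :: PermProgram AccessType -> PermProgram Bool
ro = fmap (>= ReadOnly)

-- As in the post: both operands are run, whatever the first one returns.
(.||) :: PermProgram Bool -> PermProgram Bool -> PermProgram Bool
(.||) = liftM2 (||)

-- Hypothetical variant: the second operand only runs when the first is False.
orElse :: PermProgram Bool -> PermProgram Bool -> PermProgram Bool
orElse a b = a >>= \x -> if x then pure True else b

-- Pure interpreter answering from fixed facts and logging every instruction it runs.
runTraced :: Bool -> (PostId -> AccessType) -> PermProgram a -> ([String], a)
runTraced _ _ (Return a) = ([], a)
runTraced admin rights (IsAdmin :>>= k) =
  let (logs, r) = runTraced admin rights (k admin)
  in ("IsAdmin" : logs, r)
runTraced admin rights (BlogPostRight p :>>= k) =
  let (logs, r) = runTraced admin rights (k (rights p))
  in (("BlogPostRight " ++ show p) : logs, r)
```

orElse needs monadic bind to decide whether to run its second operand. The liftM2 combinators, by contrast, could equally be written with liftA2, which is why the permission DSL could get away with an Applicative instance.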

And here is the kind of expression you can write. commentRight is not defined in this post; it stands for a similar instruction for comments.

isAdmin .|| (ro (blogPostRight postid) .&& owner (commentRight commentid))
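The sketch below evaluates such an expression without a database. It interprets PermProgram purely, answering each instruction from a record that describes the current user. commentRight is not defined in the post, so the sketch adds a hypothetical CommentRight instruction. This also shows the cost of extending the DSL: every interpreter, checkPerms included, then needs a case for the new instruction. The Program type is a simplified stand-in for operational's, and the Int keys are hypothetical stand-ins for persistent's keys.

```haskell
{-# LANGUAGE GADTs #-}
import Control.Monad (ap, liftM, liftM2, (>=>))

-- Simplified stand-in for operational's Program (not its actual implementation).
data Program instr a where
  Return :: a -> Program instr a
  (:>>=) :: instr b -> (b -> Program instr a) -> Program instr a

instance Functor (Program instr) where
  fmap = liftM

instance Applicative (Program instr) where
  pure  = Return
  (<*>) = ap

instance Monad (Program instr) where
  Return a   >>= k = k a
  (i :>>= f) >>= k = i :>>= (f >=> k)

singleton :: instr a -> Program instr a
singleton i = i :>>= Return

data AccessType = NoAccess | ReadOnly | ReadWrite | Owner
  deriving (Show, Eq, Ord)

-- Hypothetical stand-ins for the entity keys.
type PostId    = Int
type CommentId = Int

data PermCheck a where
  IsAdmin       :: PermCheck Bool
  BlogPostRight :: PostId -> PermCheck AccessType
  CommentRight  :: CommentId -> PermCheck AccessType -- hypothetical new instruction

type PermProgram = Program PermCheck

isAdmin :: PermProgram Bool
isAdmin = singleton IsAdmin

blogPostRight :: PostId -> PermProgram AccessType
blogPostRight = singleton . BlogPostRight

commentRight :: CommentId -> PermProgram AccessType
commentRight = singleton . CommentRight

ro, owner :: PermProgram AccessType -> PermProgram Bool
ro    = fmap (>= ReadOnly)
owner = fmap (>= Owner)

(.||), (.&&) :: PermProgram Bool -> PermProgram Bool -> PermProgram Bool
(.||) = liftM2 (||)
(.&&) = liftM2 (&&)

-- Everything the pure interpreter knows about the current user.
data User = User
  { userIsAdmin   :: Bool
  , postAccess    :: PostId -> AccessType
  , commentAccess :: CommentId -> AccessType
  }

-- Answers each instruction from the User record instead of the database.
runPure :: User -> PermProgram a -> a
runPure _ (Return a)               = a
runPure u (IsAdmin :>>= k)         = runPure u (k (userIsAdmin u))
runPure u (BlogPostRight p :>>= k) = runPure u (k (postAccess u p))
runPure u (CommentRight c :>>= k)  = runPure u (k (commentAccess u c))

-- The expression from the text, wrapped in a function.
canEditComment :: PostId -> CommentId -> PermProgram Bool
canEditComment postid commentid =
  isAdmin .|| (ro (blogPostRight postid) .&& owner (commentRight commentid))
```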

Web application DSL

This DSL is a bit more complicated. It exposes primitives for interacting safely with the database, and a way to cancel the computation and throw errors at the user. We expect the error-throwing part to cancel the current SQL transaction and send an error message back to the webservice user.

The instructions map directly to their counterparts in persistent. The WebService monad is not an instance of MonadError ServantErr, because "catching" an exception would not make sense (it would require nested SQL transactions, which I know nothing about). Persistent requires all sorts of constraints on our values, so I created the PC constraint synonym to keep it happy.

In real life, you will probably need more versatile database access (arbitrary selects and deletes using esqueleto), logging and other goodies. They are not required for this simple example.

type WebService = Program WebAction
type PC val = (PersistEntityBackend val ~ SqlBackend, PersistEntity val)
data WebAction a where
    Throw :: ServantErr               -> WebAction a
    Get   :: PC val => Key val        -> WebAction (Maybe val)
    Del   :: PC val => Key val        -> WebAction ()
    GetBy :: PC val => Unique val     -> WebAction (Maybe (Entity val))
    New   :: PC val =>            val -> WebAction (Key val)
    Upd   :: PC val => Key val -> val -> WebAction ()

-- throws an error
throw :: ServantErr -> WebService a
throw = singleton . Throw

-- dual of `persistent`'s `get`
mget :: PC val => Key val -> WebService (Maybe val)
mget = singleton . Get

-- dual of `persistent`'s `getBy`
mgetBy :: PC val => Unique val ->  WebService (Maybe (Entity val))
mgetBy = singleton . GetBy

-- dual of `persistent`'s `insert`
mnew :: PC val => val ->  WebService (Key val)
mnew = singleton . New

-- dual of `persistent`'s `replace` (overwrites the whole record)
mupd :: PC val => Key val -> val -> WebService ()
mupd k v = singleton (Upd k v)

-- dual of `persistent`'s `delete`
mdel :: PC val => Key val -> WebService ()
mdel = singleton . Del

-- like `mget` but throws a 404 if it could not find the corresponding record
mgetOr404 :: PC val => Key val -> WebService val
mgetOr404 = mget >=> maybe (throw err404) return

-- like `mgetBy` but throws a 404 if it could not find the corresponding record
mgetByOr404 :: PC val => Unique val -> WebService (Entity val)
mgetByOr404 = mgetBy >=> maybe (throw err404) return

Evaluating the permissions checking DSL

Now that we have defined the web application DSL, it is trivial to evaluate our permission checking DSL into it.

-- Given the current user, runs a `PermProgram` in the `WebService` monad.
checkPerms :: Entity Person -> PermProgram a -> WebService a
checkPerms ent cnd = eval (O.view cnd)
    where
        usr = entityVal ent
        userkey = entityKey ent
        eval :: ProgramView PermCheck a -> WebService a
        eval (Return a) = return a
        eval (IsAdmin :>>= nxt) =
              checkPerms ent (nxt (_personIsAdmin usr))

There might be no matching record when retrieving the access rights a given user has on a given blog post. In that case, NoAccess should be used.

        eval (BlogPostRight k :>>= nxt) =
              mgetBy (UniqueRight userkey k)
                  >>= checkPerms ent
                    . nxt
                    . maybe NoAccess (_postRightsAccess . entityVal)

Evaluating the web application DSL

Now we are going to turn the WebService type into something that can be used with persistent:

type ServantIO a = SqlPersistT (LoggingT (EitherT ServantErr IO)) a

The conversion practically writes itself, except for the handling of rollbacks. There is only one thing you should take care of: never use a bare throwError, except in the Throw handler. The whole point of this exercise is to ensure that transactions are properly rolled back in case of errors.

runServant :: WebService a -> ServantIO a
runServant ws = case O.view ws of
                  Return a -> return a
                  a :>>= f -> runM a f

runM :: WebAction a -> (a -> WebService b) -> ServantIO b
runM x f = case x of
    Throw rr@(ServantErr c rs _ _) -> do
                  conn <- ask
                  liftIO $ connRollback conn (getStmtConn conn)
                  logOtherNS "WS" LevelError (show (c,rs) ^. packed)
                  throwError rr
    Get k    -> get k       >>= tsf
    New v    -> insert v    >>= tsf
    Del k    -> delete k    >>= tsf
    GetBy u  -> getBy u     >>= tsf
    Upd k v  -> replace k v >>= tsf
  where
      tsf = runServant . f
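The rollback guarantee can be modelled without a database. The sketch below is a pure analogue of runServant, not the post's code. It makes the following simplifications:

- a simplified stand-in replaces operational's Program;
- a cut-down WebAction covers a single hypothetical table of String rows keyed by Int;
- a status code stands in for ServantErr;
- a Data.Map serves as the database.

Each request runs against a working copy that is kept only if the program finishes. A Throw discards that copy, which is the behaviour the explicit connRollback in the Throw case provides for the SQL transaction.

```haskell
{-# LANGUAGE GADTs #-}
import Control.Monad (ap, liftM, (>=>))
import qualified Data.Map.Strict as Map

-- Simplified stand-in for operational's Program (not its actual implementation).
data Program instr a where
  Return :: a -> Program instr a
  (:>>=) :: instr b -> (b -> Program instr a) -> Program instr a

instance Functor (Program instr) where
  fmap = liftM

instance Applicative (Program instr) where
  pure  = Return
  (<*>) = ap

instance Monad (Program instr) where
  Return a   >>= k = k a
  (i :>>= f) >>= k = i :>>= (f >=> k)

singleton :: instr a -> Program instr a
singleton i = i :>>= Return

-- Cut-down WebAction for one table of String rows keyed by Int;
-- the Int in Throw stands in for ServantErr's status code.
data WebAction a where
  Throw :: Int -> WebAction a
  Get   :: Int -> WebAction (Maybe String)
  New   :: String -> WebAction Int
  Del   :: Int -> WebAction ()

type WebService = Program WebAction
type Db = Map.Map Int String

throw :: Int -> WebService a
throw = singleton . Throw

mget :: Int -> WebService (Maybe String)
mget = singleton . Get

mnew :: String -> WebService Int
mnew = singleton . New

mdel :: Int -> WebService ()
mdel = singleton . Del

-- Runs a request against a working copy of the database.
-- Throw returns only the error code, so the working copy is lost.
runTransaction :: Db -> WebService a -> Either Int (a, Db)
runTransaction db (Return a)       = Right (a, db)
runTransaction _  (Throw c :>>= _) = Left c
runTransaction db (Get k :>>= f)   = runTransaction db (f (Map.lookup k db))
runTransaction db (New v :>>= f)   =
  let k = nextKey db in runTransaction (Map.insert k v db) (f k)
runTransaction db (Del k :>>= f)   = runTransaction (Map.delete k db) (f ())

-- Next free primary key, like an auto-increment column.
nextKey :: Db -> Int
nextKey db = maybe 1 ((+ 1) . fst) (Map.lookupMax db)

-- Commits the working copy on success and keeps the old database on failure.
commit :: Db -> WebService a -> (Either Int a, Db)
commit db ws = case runTransaction db ws of
  Left c         -> (Left c, db)
  Right (a, db') -> (Right a, db')

-- Example request: a successful insert followed by an error.
insertThenFail :: WebService ()
insertThenFail = do
  _ <- mnew "draft"
  throw 500
```

Handlers written in WebService cannot call throwError directly. Every error therefore goes through the single Throw case, which is where the rollback happens.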

Implementing the API

I start by defining a record that holds the permission checking functions for my four CRUD actions:

data PermsFor a = PermsFor { _newPerms :: PermProgram Bool
                           , _getPerms :: Key a -> PermProgram Bool
                           , _updPerms :: Key a -> PermProgram Bool
                           , _delPerms :: Key a -> PermProgram Bool
                           }

adminOnly :: PermsFor a
adminOnly = PermsFor isAdmin (const isAdmin) (const isAdmin) (const isAdmin)

The runCrud function handles most of the logic. It takes as arguments a connection pool, the previously defined record, and two optional functions.

The first optional function is used when creating a new item in the database. It sets up the appropriate rights for the object that was just created (that is, it sets its owner).

The second optional function is used to perform additional cleanup before object deletion. Its main use is to remove cross references.

As a result, you get all four CRUD actions defined and ready to be used by the serve function!

runCrud :: (PersistEntity a, ToBackendKey SqlBackend a, PC b)
        => ConnectionPool -- ^ Connection pool
        -> PermsFor a -- ^ Permission checking record
        -> Maybe (Key Person -> Key a -> AccessType -> b)
           -- ^ Extra actions after creation
        -> Maybe (Key a -> WebService ()) -- ^ Extra actions before deletion
        -> (      Maybe Text           -> a -> EitherT ServantErr IO (MKey a))
           :<|> ((Maybe Text -> MKey a      -> EitherT ServantErr IO a)
           :<|> ((Maybe Text -> MKey a -> a -> EitherT ServantErr IO ())
           :<|> ( Maybe Text -> MKey a      -> EitherT ServantErr IO ()))
           )
runCrud pool (PermsFor pnew pget pupd pdel) rightConstructor predelete =
          runnew :<|> runget :<|> runupd :<|> rundel
    where
        auth Nothing _ = throw err401
        auth (Just dn) perm = do
            user <- mgetBy (UniqueName dn) >>= maybe (throw err403) return
            check <- checkPerms user perm
            unless check (throw err403)
            return user
        runget dn mk = runQuery $ do
            let k = mk ^. _MKey
            void $ auth dn (pget k)
            mgetOr404 k
        runnew dn val = runQuery $ do
            usr <- auth dn pnew
            k <- mnew val
            F.mapM_ (\c -> mnew (c (entityKey usr) k Owner)) rightConstructor
            return (k ^. from _MKey)
        runupd dn mk val = runQuery $ do
            let k = mk ^. _MKey
            void $ auth dn (pupd k)
            mupd k val
        rundel dn mk = runQuery $ do
            let k = mk ^. _MKey
            void $ auth dn (pdel k)
            F.mapM_ ($ k) predelete
            mdel k
        runQuery :: WebService a -> EitherT ServantErr IO a
        runQuery ws = runStderrLoggingT $ runSqlPool (runServant ws) pool

The last touch is a default null action, used when you do not need to set up anything extra after object creation. You cannot just pass Nothing, because type inference would fail. The signature of runCrud contains a type variable b that must be known to solve the PC b constraint, even when b is not used. I arbitrarily picked PostRights, one of the available types that satisfy this constraint.

-- A default action for when you need not run additional actions after creation
noCreateRightAdjustment :: Maybe (Key Person -> Key a -> AccessType -> PostRights)
noCreateRightAdjustment = Nothing
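The ambiguity can be reproduced without persistent. In the sketch below, HasTable is a hypothetical class standing in for the PC constraint. The function run mirrors runCrud, where b only appears inside the optional function. Applying run to a bare Nothing leaves b undetermined, so GHC reports an ambiguous type variable. A Nothing whose type is fixed, like noCreateRightAdjustment, compiles.

```haskell
-- Hypothetical class standing in for the PC constraint on b.
class HasTable b where
  tableName :: b -> String

data PostRights = PostRights

instance HasTable PostRights where
  tableName _ = "post_rights"

-- Mirrors runCrud: b only appears inside the optional function.
run :: HasTable b => Maybe (Int -> b) -> String
run Nothing  = "no extra action"
run (Just f) = "insert into " ++ tableName (f 0)

-- A Nothing with a fixed type, like noCreateRightAdjustment.
noAdjustment :: Maybe (Int -> PostRights)
noAdjustment = Nothing

-- run Nothing  -- rejected: ambiguous type variable, HasTable b cannot be solved
```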

Serving the API

The server function takes a ConnectionPool, and its result can be passed directly to... the serve function from servant. Adding more CRUD endpoints should be painless.

server :: ConnectionPool -> Server MyApi
server pool =
      runCrud pool adminOnly noCreateRightAdjustment Nothing
 :<|> defaultCrud blogPostRight PostRights Nothing
   where
     editRights c cid = rw (c cid) .|| isAdmin
     delRights c cid = owner (c cid) .|| isAdmin
     defaultPermissions c =
          PermsFor always
                  (const always)
                  (editRights c)
                  (delRights c)
     defaultCrud c r d = runCrud pool (defaultPermissions c)
                                 (Just r) d

Finally, the main function creates the connection pool, runs the migrations and starts the web service.

main :: IO ()
main = do
  pool <- runStderrLoggingT $ do
      p <- createSqlitePool ":memory:" 1
      runSqlPool (runMigration migrateAll) p
      return p
  Network.Wai.Handler.Warp.run 8080 (serve myApi (server pool))

Conclusion

I am not sure how this post turned out. When I started it, I expected it to be just a few lines long. In the end I had to add all kinds of parts to produce a complete example that somebody else could compile and run. Hopefully it demonstrates how easy it is to integrate different libraries with the help of DSL glue.

What it did not show is how useful this can be for testing. You "just" need to write a pure variant of runServant that mocks the database endpoints to verify your logic.
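As a sketch of that idea, the interpreter below runs a WebService program against in-memory data. It also records which operations were attempted, so a test can check the access-control logic as well as the result. Everything here is a hypothetical simplification:

- a small stand-in replaces operational's Program;
- users are a Data.Map from dn to an admin flag;
- auth mirrors the auth helper of runCrud, returning 401 without a dn and 403 for an unknown user or a failed check;
- the permission check is reduced to that admin flag.

```haskell
{-# LANGUAGE GADTs #-}
import Control.Monad (ap, liftM, unless, (>=>))
import qualified Data.Map.Strict as Map

-- Simplified stand-in for operational's Program (not its actual implementation).
data Program instr a where
  Return :: a -> Program instr a
  (:>>=) :: instr b -> (b -> Program instr a) -> Program instr a

instance Functor (Program instr) where
  fmap = liftM

instance Applicative (Program instr) where
  pure  = Return
  (<*>) = ap

instance Monad (Program instr) where
  Return a   >>= k = k a
  (i :>>= f) >>= k = i :>>= (f >=> k)

singleton :: instr a -> Program instr a
singleton i = i :>>= Return

-- Cut-down WebAction: users are looked up by dn, posts are String rows keyed by Int.
data WebAction a where
  Throw   :: Int -> WebAction a                -- status code in place of ServantErr
  GetUser :: String -> WebAction (Maybe Bool)  -- Just isAdmin when the user exists
  DelPost :: Int -> WebAction ()

type WebService = Program WebAction

throw :: Int -> WebService a
throw = singleton . Throw

-- Mirrors runCrud's auth helper, with the permission check reduced to "is admin".
auth :: Maybe String -> WebService ()
auth Nothing   = throw 401
auth (Just dn) = do
  user <- singleton (GetUser dn)
  case user of
    Nothing      -> throw 403
    Just isAdmin -> unless isAdmin (throw 403)

-- The handler under test: only administrators may delete posts.
deletePost :: Maybe String -> Int -> WebService ()
deletePost dn k = auth dn >> singleton (DelPost k)

data MockDb = MockDb
  { users :: Map.Map String Bool
  , posts :: Map.Map Int String
  } deriving (Show, Eq)

-- Pure replacement for runServant: logs every operation and returns the outcome.
-- A Throw aborts the run and drops the modified database, as a rollback would.
runMock :: MockDb -> WebService a -> ([String], Either Int (a, MockDb))
runMock db (Return a)          = ([], Right (a, db))
runMock _  (Throw c :>>= _)    = (["throw " ++ show c], Left c)
runMock db (GetUser dn :>>= f) =
  record ("getUser " ++ dn) (runMock db (f (Map.lookup dn (users db))))
runMock db (DelPost k :>>= f)  =
  record ("delPost " ++ show k) (runMock (db { posts = Map.delete k (posts db) }) (f ()))

-- Prepends an operation to the log of the rest of the run.
record :: String -> ([String], r) -> ([String], r)
record op (ops, r) = (op : ops, r)
```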

This might also look overengineered, but I wrote all this machinery for a webservice that has around 100 endpoints, so the ability to concisely describe the endpoint logic and expected access rights pays off well.