diff --git a/Messages.hs b/Messages.hs index c6d6022bc5..c8395ff4c7 100644 --- a/Messages.hs +++ b/Messages.hs @@ -5,7 +5,7 @@ - Licensed under the GNU AGPL version 3 or higher. -} -{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE OverloadedStrings, ScopedTypeVariables #-} module Messages ( showStartMessage, @@ -56,6 +56,7 @@ import Control.Monad.IO.Class import qualified Data.ByteString as S import qualified Data.ByteString.Char8 as S8 import System.Exit +import qualified Control.Monad.Catch as M import Common import Types @@ -326,11 +327,14 @@ mkPrompter = getConcurrency >>= \case (\v -> putMVar l v >> cleanup) (const $ run a) -{- Catch all (non-async) exceptions and display, santizing any control - - characters in the exceptions. Exits nonzero on exception, so should only - - be used at topmost level. -} +{- Catch all (non-async and not ExitCode) exceptions and display, + - santizing any control characters in the exceptions. + - + - Exits nonzero on exception, so should only be used at topmost level. + -} sanitizeTopLevelExceptionMessages :: IO a -> IO a -sanitizeTopLevelExceptionMessages a = catchNonAsync a go +sanitizeTopLevelExceptionMessages a = a `catches` + ((M.Handler (\ (e :: ExitCode) -> throwM e)) : nonAsyncHandler go) where go e = do warningIO (show e) diff --git a/Utility/Exception.hs b/Utility/Exception.hs index 3d1a69912a..cf55c5fefc 100644 --- a/Utility/Exception.hs +++ b/Utility/Exception.hs @@ -1,6 +1,6 @@ {- Simple IO exception handling (and some more) - - - Copyright 2011-2016 Joey Hess + - Copyright 2011-2023 Joey Hess - - License: BSD-2-clause -} @@ -20,6 +20,7 @@ module Utility.Exception ( bracketIO, catchNonAsync, tryNonAsync, + nonAsyncHandler, tryWhenExists, catchIOErrorType, IOErrorType(..), @@ -28,8 +29,7 @@ module Utility.Exception ( import Control.Monad.Catch as X hiding (Handler) import qualified Control.Monad.Catch as M -import Control.Exception (IOException, AsyncException) -import Control.Exception (SomeAsyncException) +import Control.Exception (IOException, AsyncException, SomeAsyncException) import Control.Monad import Control.Monad.IO.Class (liftIO, MonadIO) import System.IO.Error (isDoesNotExistError, ioeGetErrorType) @@ -85,11 +85,7 @@ bracketIO setup cleanup = bracket (liftIO setup) (liftIO . cleanup) - ThreadKilled and UserInterrupt get through. -} catchNonAsync :: MonadCatch m => m a -> (SomeException -> m a) -> m a -catchNonAsync a onerr = a `catches` - [ M.Handler (\ (e :: AsyncException) -> throwM e) - , M.Handler (\ (e :: SomeAsyncException) -> throwM e) - , M.Handler (\ (e :: SomeException) -> onerr e) - ] +catchNonAsync a onerr = a `catches` (nonAsyncHandler onerr) tryNonAsync :: MonadCatch m => m a -> m (Either SomeException a) tryNonAsync a = go `catchNonAsync` (return . Left) @@ -98,6 +94,13 @@ tryNonAsync a = go `catchNonAsync` (return . Left) v <- a return (Right v) +nonAsyncHandler :: MonadCatch m => (SomeException -> m a) -> [M.Handler m a] +nonAsyncHandler onerr = + [ M.Handler (\ (e :: AsyncException) -> throwM e) + , M.Handler (\ (e :: SomeAsyncException) -> throwM e) + , M.Handler (\ (e :: SomeException) -> onerr e) + ] + {- Catches only DoesNotExist exceptions, and lets all others through. -} tryWhenExists :: MonadCatch m => m a -> m (Maybe a) tryWhenExists a = do