From 48eb6671e489effd4a796b6564a4971fcf6a566d Mon Sep 17 00:00:00 2001 From: Joey Hess Date: Mon, 22 Jul 2024 16:23:08 -0400 Subject: [PATCH] improve clientGet types --- Command/P2PHttp.hs | 2 +- P2P/Http.hs | 63 +++++++++++++++++++++++++++------------------- 2 files changed, 38 insertions(+), 27 deletions(-) diff --git a/Command/P2PHttp.hs b/Command/P2PHttp.hs index d565169d96..622dceffd1 100644 --- a/Command/P2PHttp.hs +++ b/Command/P2PHttp.hs @@ -169,7 +169,7 @@ testGet = do [] Nothing Nothing - Nothing + "outfile" liftIO $ print res testPut = do diff --git a/P2P/Http.hs b/P2P/Http.hs index 65433439c3..f09ef222da 100644 --- a/P2P/Http.hs +++ b/P2P/Http.hs @@ -38,7 +38,6 @@ import qualified Servant.Types.SourceT as S import qualified Data.ByteString as B import qualified Data.ByteString.Lazy as L import qualified Data.ByteString.Lazy.Internal as LI -import Data.Char import Control.Concurrent.STM import Control.Concurrent.Async import Control.Concurrent @@ -207,7 +206,7 @@ serveGet st su apiver (B64Key k) cu bypass baf startat sec auth = do -- to the client that it's not valid. , return ([], B.take (B.length b - 1) b) ) - nextchunk szv checkvalid (b:bs) = do + nextchunk szv _checkvalid (b:bs) = do updateszv szv b return (bs, b) nextchunk _szv checkvalid [] = do @@ -269,11 +268,13 @@ clientGet -> B64UUID ClientSide -> [B64UUID Bypass] -> Maybe B64FilePath - -> Maybe Offset -> Maybe Auth - -> IO () -clientGet clientenv (ProtocolVersion ver) k su cu bypass af o auth = - withClientM (cli k cu bypass af o auth) clientenv $ \case + -> RawFilePath + -> IO Validity +clientGet clientenv (ProtocolVersion ver) k su cu bypass af auth dest = do + sz <- tryWhenExists $ getFileSize dest + let mo = fmap (Offset . fromIntegral) sz + withClientM (cli k cu bypass af mo auth) clientenv $ \case Left err -> throwM err Right respheaders -> do let dl = case lookupResponseHeader @DataLengthHeader' respheaders of @@ -281,8 +282,15 @@ clientGet clientenv (ProtocolVersion ver) k su cu bypass af o auth = _ -> error "missing data length header" liftIO $ print ("datalength", dl :: DataLength) b <- S.unSourceT (getResponse respheaders) gatherByteString - liftIO $ print "got it all, writing to file 'got'" - L.writeFile "got" b + liftIO $ withBinaryFile (fromRawFilePath dest) WriteMode $ \h -> do + case sz of + Just sz' | sz' /= 0 -> + hSeek h AbsoluteSeek sz' + _ -> noop + L.writeFile (fromRawFilePath dest) b + -- TODO compare dl with the number of bytes written + -- to the file + return Valid where cli =case ver of 3 -> v3 su V3 @@ -663,25 +671,28 @@ clientPut clientenv (ProtocolVersion ver) k su cu bypass auth moffset af content v <- newMVar (0, filter (not . B.null) (L.toChunks bl)) a (go v) where - go v = S.fromActionStep B.null $ - modifyMVar v $ \case - (n, (b:[])) -> do - let !n' = n + B.length b - ifM (checkvalid n') - ( return ((n', []), b) - -- The key's content is invalid, but - -- the amount of data is the same as the - -- DataLengthHeader indicates. Truncate - -- the stream by one byte to indicate - -- to the server that it's not valid. - , return ((n' - 1, []), B.take (B.length b - 1) b) + go v = S.fromActionStep B.null $ modifyMVar v $ \case + (n, (b:[])) -> do + let !n' = n + B.length b + ifM (checkvalid n') + ( return ((n', []), b) + -- The key's content is invalid, but + -- the amount of data is the same as + -- the DataLengthHeader indicates. + -- Truncate the stream by one byte to + -- indicate to the server that it's + -- not valid. + , return + ( (n' - 1, []) + , B.take (B.length b - 1) b ) - (n, []) -> do - void $ checkvalid n - return ((n, []), mempty) - (n, (b:bs)) -> - let !n' = n + B.length b - in return ((n', bs), b) + ) + (n, []) -> do + void $ checkvalid n + return ((n, []), mempty) + (n, (b:bs)) -> + let !n' = n + B.length b + in return ((n', bs), b) checkvalid n = do void $ liftIO $ atomically $ tryPutTMVar checkv ()