improve clientGet types

This commit is contained in:
Joey Hess 2024-07-22 16:23:08 -04:00
parent b697c6b9da
commit 48eb6671e4
No known key found for this signature in database
GPG key ID: DB12DB0FF05F8F38
2 changed files with 38 additions and 27 deletions

View file

@ -169,7 +169,7 @@ testGet = do
[] []
Nothing Nothing
Nothing Nothing
Nothing "outfile"
liftIO $ print res liftIO $ print res
testPut = do testPut = do

View file

@ -38,7 +38,6 @@ import qualified Servant.Types.SourceT as S
import qualified Data.ByteString as B import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Lazy.Internal as LI import qualified Data.ByteString.Lazy.Internal as LI
import Data.Char
import Control.Concurrent.STM import Control.Concurrent.STM
import Control.Concurrent.Async import Control.Concurrent.Async
import Control.Concurrent import Control.Concurrent
@ -207,7 +206,7 @@ serveGet st su apiver (B64Key k) cu bypass baf startat sec auth = do
-- to the client that it's not valid. -- to the client that it's not valid.
, return ([], B.take (B.length b - 1) b) , return ([], B.take (B.length b - 1) b)
) )
nextchunk szv checkvalid (b:bs) = do nextchunk szv _checkvalid (b:bs) = do
updateszv szv b updateszv szv b
return (bs, b) return (bs, b)
nextchunk _szv checkvalid [] = do nextchunk _szv checkvalid [] = do
@ -269,11 +268,13 @@ clientGet
-> B64UUID ClientSide -> B64UUID ClientSide
-> [B64UUID Bypass] -> [B64UUID Bypass]
-> Maybe B64FilePath -> Maybe B64FilePath
-> Maybe Offset
-> Maybe Auth -> Maybe Auth
-> IO () -> RawFilePath
clientGet clientenv (ProtocolVersion ver) k su cu bypass af o auth = -> IO Validity
withClientM (cli k cu bypass af o auth) clientenv $ \case clientGet clientenv (ProtocolVersion ver) k su cu bypass af auth dest = do
sz <- tryWhenExists $ getFileSize dest
let mo = fmap (Offset . fromIntegral) sz
withClientM (cli k cu bypass af mo auth) clientenv $ \case
Left err -> throwM err Left err -> throwM err
Right respheaders -> do Right respheaders -> do
let dl = case lookupResponseHeader @DataLengthHeader' respheaders of let dl = case lookupResponseHeader @DataLengthHeader' respheaders of
@ -281,8 +282,15 @@ clientGet clientenv (ProtocolVersion ver) k su cu bypass af o auth =
_ -> error "missing data length header" _ -> error "missing data length header"
liftIO $ print ("datalength", dl :: DataLength) liftIO $ print ("datalength", dl :: DataLength)
b <- S.unSourceT (getResponse respheaders) gatherByteString b <- S.unSourceT (getResponse respheaders) gatherByteString
liftIO $ print "got it all, writing to file 'got'" liftIO $ withBinaryFile (fromRawFilePath dest) WriteMode $ \h -> do
L.writeFile "got" b case sz of
Just sz' | sz' /= 0 ->
hSeek h AbsoluteSeek sz'
_ -> noop
L.writeFile (fromRawFilePath dest) b
-- TODO compare dl with the number of bytes written
-- to the file
return Valid
where where
cli =case ver of cli =case ver of
3 -> v3 su V3 3 -> v3 su V3
@ -663,18 +671,21 @@ clientPut clientenv (ProtocolVersion ver) k su cu bypass auth moffset af content
v <- newMVar (0, filter (not . B.null) (L.toChunks bl)) v <- newMVar (0, filter (not . B.null) (L.toChunks bl))
a (go v) a (go v)
where where
go v = S.fromActionStep B.null $ go v = S.fromActionStep B.null $ modifyMVar v $ \case
modifyMVar v $ \case
(n, (b:[])) -> do (n, (b:[])) -> do
let !n' = n + B.length b let !n' = n + B.length b
ifM (checkvalid n') ifM (checkvalid n')
( return ((n', []), b) ( return ((n', []), b)
-- The key's content is invalid, but -- The key's content is invalid, but
-- the amount of data is the same as the -- the amount of data is the same as
-- DataLengthHeader indicates. Truncate -- the DataLengthHeader indicates.
-- the stream by one byte to indicate -- Truncate the stream by one byte to
-- to the server that it's not valid. -- indicate to the server that it's
, return ((n' - 1, []), B.take (B.length b - 1) b) -- not valid.
, return
( (n' - 1, [])
, B.take (B.length b - 1) b
)
) )
(n, []) -> do (n, []) -> do
void $ checkvalid n void $ checkvalid n