improve clientGet types

This commit is contained in:
Joey Hess 2024-07-22 16:23:08 -04:00
parent b697c6b9da
commit 48eb6671e4
No known key found for this signature in database
GPG key ID: DB12DB0FF05F8F38
2 changed files with 38 additions and 27 deletions

View file

@ -169,7 +169,7 @@ testGet = do
[] []
Nothing Nothing
Nothing Nothing
Nothing "outfile"
liftIO $ print res liftIO $ print res
testPut = do testPut = do

View file

@ -38,7 +38,6 @@ import qualified Servant.Types.SourceT as S
import qualified Data.ByteString as B import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Lazy.Internal as LI import qualified Data.ByteString.Lazy.Internal as LI
import Data.Char
import Control.Concurrent.STM import Control.Concurrent.STM
import Control.Concurrent.Async import Control.Concurrent.Async
import Control.Concurrent import Control.Concurrent
@ -207,7 +206,7 @@ serveGet st su apiver (B64Key k) cu bypass baf startat sec auth = do
-- to the client that it's not valid. -- to the client that it's not valid.
, return ([], B.take (B.length b - 1) b) , return ([], B.take (B.length b - 1) b)
) )
nextchunk szv checkvalid (b:bs) = do nextchunk szv _checkvalid (b:bs) = do
updateszv szv b updateszv szv b
return (bs, b) return (bs, b)
nextchunk _szv checkvalid [] = do nextchunk _szv checkvalid [] = do
@ -269,11 +268,13 @@ clientGet
-> B64UUID ClientSide -> B64UUID ClientSide
-> [B64UUID Bypass] -> [B64UUID Bypass]
-> Maybe B64FilePath -> Maybe B64FilePath
-> Maybe Offset
-> Maybe Auth -> Maybe Auth
-> IO () -> RawFilePath
clientGet clientenv (ProtocolVersion ver) k su cu bypass af o auth = -> IO Validity
withClientM (cli k cu bypass af o auth) clientenv $ \case clientGet clientenv (ProtocolVersion ver) k su cu bypass af auth dest = do
sz <- tryWhenExists $ getFileSize dest
let mo = fmap (Offset . fromIntegral) sz
withClientM (cli k cu bypass af mo auth) clientenv $ \case
Left err -> throwM err Left err -> throwM err
Right respheaders -> do Right respheaders -> do
let dl = case lookupResponseHeader @DataLengthHeader' respheaders of let dl = case lookupResponseHeader @DataLengthHeader' respheaders of
@ -281,8 +282,15 @@ clientGet clientenv (ProtocolVersion ver) k su cu bypass af o auth =
_ -> error "missing data length header" _ -> error "missing data length header"
liftIO $ print ("datalength", dl :: DataLength) liftIO $ print ("datalength", dl :: DataLength)
b <- S.unSourceT (getResponse respheaders) gatherByteString b <- S.unSourceT (getResponse respheaders) gatherByteString
liftIO $ print "got it all, writing to file 'got'" liftIO $ withBinaryFile (fromRawFilePath dest) WriteMode $ \h -> do
L.writeFile "got" b case sz of
Just sz' | sz' /= 0 ->
hSeek h AbsoluteSeek sz'
_ -> noop
L.writeFile (fromRawFilePath dest) b
-- TODO compare dl with the number of bytes written
-- to the file
return Valid
where where
cli =case ver of cli =case ver of
3 -> v3 su V3 3 -> v3 su V3
@ -663,25 +671,28 @@ clientPut clientenv (ProtocolVersion ver) k su cu bypass auth moffset af content
v <- newMVar (0, filter (not . B.null) (L.toChunks bl)) v <- newMVar (0, filter (not . B.null) (L.toChunks bl))
a (go v) a (go v)
where where
go v = S.fromActionStep B.null $ go v = S.fromActionStep B.null $ modifyMVar v $ \case
modifyMVar v $ \case (n, (b:[])) -> do
(n, (b:[])) -> do let !n' = n + B.length b
let !n' = n + B.length b ifM (checkvalid n')
ifM (checkvalid n') ( return ((n', []), b)
( return ((n', []), b) -- The key's content is invalid, but
-- The key's content is invalid, but -- the amount of data is the same as
-- the amount of data is the same as the -- the DataLengthHeader indicates.
-- DataLengthHeader indicates. Truncate -- Truncate the stream by one byte to
-- the stream by one byte to indicate -- indicate to the server that it's
-- to the server that it's not valid. -- not valid.
, return ((n' - 1, []), B.take (B.length b - 1) b) , return
( (n' - 1, [])
, B.take (B.length b - 1) b
) )
(n, []) -> do )
void $ checkvalid n (n, []) -> do
return ((n, []), mempty) void $ checkvalid n
(n, (b:bs)) -> return ((n, []), mempty)
let !n' = n + B.length b (n, (b:bs)) ->
in return ((n', bs), b) let !n' = n + B.length b
in return ((n', bs), b)
checkvalid n = do checkvalid n = do
void $ liftIO $ atomically $ tryPutTMVar checkv () void $ liftIO $ atomically $ tryPutTMVar checkv ()