improve clientGet types

This commit is contained in:
Joey Hess 2024-07-22 16:23:08 -04:00
parent b697c6b9da
commit 48eb6671e4
No known key found for this signature in database
GPG key ID: DB12DB0FF05F8F38
2 changed files with 38 additions and 27 deletions

View file

@ -169,7 +169,7 @@ testGet = do
[]
Nothing
Nothing
Nothing
"outfile"
liftIO $ print res
testPut = do

View file

@ -38,7 +38,6 @@ import qualified Servant.Types.SourceT as S
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Lazy.Internal as LI
import Data.Char
import Control.Concurrent.STM
import Control.Concurrent.Async
import Control.Concurrent
@ -207,7 +206,7 @@ serveGet st su apiver (B64Key k) cu bypass baf startat sec auth = do
-- to the client that it's not valid.
, return ([], B.take (B.length b - 1) b)
)
nextchunk szv checkvalid (b:bs) = do
nextchunk szv _checkvalid (b:bs) = do
updateszv szv b
return (bs, b)
nextchunk _szv checkvalid [] = do
@ -269,11 +268,13 @@ clientGet
-> B64UUID ClientSide
-> [B64UUID Bypass]
-> Maybe B64FilePath
-> Maybe Offset
-> Maybe Auth
-> IO ()
clientGet clientenv (ProtocolVersion ver) k su cu bypass af o auth =
withClientM (cli k cu bypass af o auth) clientenv $ \case
-> RawFilePath
-> IO Validity
clientGet clientenv (ProtocolVersion ver) k su cu bypass af auth dest = do
sz <- tryWhenExists $ getFileSize dest
let mo = fmap (Offset . fromIntegral) sz
withClientM (cli k cu bypass af mo auth) clientenv $ \case
Left err -> throwM err
Right respheaders -> do
let dl = case lookupResponseHeader @DataLengthHeader' respheaders of
@ -281,8 +282,15 @@ clientGet clientenv (ProtocolVersion ver) k su cu bypass af o auth =
_ -> error "missing data length header"
liftIO $ print ("datalength", dl :: DataLength)
b <- S.unSourceT (getResponse respheaders) gatherByteString
liftIO $ print "got it all, writing to file 'got'"
L.writeFile "got" b
liftIO $ withBinaryFile (fromRawFilePath dest) WriteMode $ \h -> do
case sz of
Just sz' | sz' /= 0 ->
hSeek h AbsoluteSeek sz'
_ -> noop
L.writeFile (fromRawFilePath dest) b
-- TODO compare dl with the number of bytes written
-- to the file
return Valid
where
cli =case ver of
3 -> v3 su V3
@ -663,18 +671,21 @@ clientPut clientenv (ProtocolVersion ver) k su cu bypass auth moffset af content
v <- newMVar (0, filter (not . B.null) (L.toChunks bl))
a (go v)
where
go v = S.fromActionStep B.null $
modifyMVar v $ \case
go v = S.fromActionStep B.null $ modifyMVar v $ \case
(n, (b:[])) -> do
let !n' = n + B.length b
ifM (checkvalid n')
( return ((n', []), b)
-- The key's content is invalid, but
-- the amount of data is the same as the
-- DataLengthHeader indicates. Truncate
-- the stream by one byte to indicate
-- to the server that it's not valid.
, return ((n' - 1, []), B.take (B.length b - 1) b)
-- the amount of data is the same as
-- the DataLengthHeader indicates.
-- Truncate the stream by one byte to
-- indicate to the server that it's
-- not valid.
, return
( (n' - 1, [])
, B.take (B.length b - 1) b
)
)
(n, []) -> do
void $ checkvalid n