From 74c617579513c0f79db1f410e6c96d9871ae9300 Mon Sep 17 00:00:00 2001 From: Joey Hess Date: Thu, 11 Jul 2024 09:55:17 -0400 Subject: [PATCH] fix serveGet early handle close Needed that waitv after all.. --- Annex/Proxy.hs | 4 ++-- Command/P2PHttp.hs | 2 +- P2P/Http.hs | 31 ++++++++++++++--------------- P2P/Http/State.hs | 49 ++++++++++++++++++++++++++++------------------ P2P/IO.hs | 19 ++++++++++-------- 5 files changed, 59 insertions(+), 46 deletions(-) diff --git a/Annex/Proxy.hs b/Annex/Proxy.hs index ffa732a08d..4563579ef2 100644 --- a/Annex/Proxy.hs +++ b/Annex/Proxy.hs @@ -59,8 +59,8 @@ proxySpecialRemoteSide clientmaxversion r = mkRemoteSide r $ do let remoteconn = P2PConnection { connRepo = Nothing , connCheckAuth = const False - , connIhdl = P2PHandleTMVar ihdl (Just iwaitv) - , connOhdl = P2PHandleTMVar ohdl (Just owaitv) + , connIhdl = P2PHandleTMVar ihdl iwaitv + , connOhdl = P2PHandleTMVar ohdl owaitv , connIdent = ConnIdent (Just (Remote.name r)) } let closeremoteconn = do diff --git a/Command/P2PHttp.hs b/Command/P2PHttp.hs index c7b0d05e7a..2f5eb11114 100644 --- a/Command/P2PHttp.hs +++ b/Command/P2PHttp.hs @@ -163,7 +163,7 @@ testGet = do burl <- liftIO $ parseBaseUrl "http://localhost:8080/" res <- liftIO $ clientGet (mkClientEnv mgr burl) (P2P.ProtocolVersion 3) - (B64Key (fromJust $ deserializeKey ("WORM-s3218-m1720641607--passwd" :: String))) + (B64Key (fromJust $ deserializeKey ("SHA256E-s1048576000--e3b67ce72aa2571c799d6419e3e36828461ac1c78f8ef300c7f9c8ae671c517f" :: String))) (B64UUID (toUUID ("cu" :: String))) (B64UUID (toUUID ("f11773f0-11e1-45b2-9805-06db16768efe" :: String))) [] diff --git a/P2P/Http.hs b/P2P/Http.hs index c8b3d438a2..8f55f65e26 100644 --- a/P2P/Http.hs +++ b/P2P/Http.hs @@ -32,7 +32,6 @@ import Utility.Metered import Servant import Servant.Client.Streaming -import Servant.API import qualified Servant.Types.SourceT as S import qualified Data.ByteString as B import qualified Data.ByteString.Lazy as L @@ -150,8 +149,7 @@ serveGet -> Maybe Auth -> Handler (Headers '[DataLengthHeader] (S.SourceT IO B.ByteString)) serveGet st apiver (B64Key k) cu su bypass baf startat sec auth = do - (runst, conn, releaseconn) <- - getP2PConnection apiver st cu su bypass sec auth ReadAction + conn <- getP2PConnection apiver st cu su bypass sec auth ReadAction bsv <- liftIO newEmptyTMVarIO endv <- liftIO newEmptyTMVarIO validityv <- liftIO newEmptyTMVarIO @@ -160,15 +158,17 @@ serveGet st apiver (B64Key k) cu su bypass baf startat sec auth = do let storer _offset len = sendContentWith $ \bs -> do liftIO $ atomically $ putTMVar bsv (len, bs) liftIO $ atomically $ takeTMVar endv + liftIO $ signalFullyConsumedByteString $ + connOhdl $ serverP2PConnection conn return $ \v -> do liftIO $ atomically $ putTMVar validityv v return True v <- enteringStage (TransferStage Upload) $ - runFullProto runst conn $ + runFullProto (clientRunState conn) (clientP2PConnection conn) $ void $ receiveContent Nothing nullMeterUpdate sizer storer getreq return v - liftIO $ forkIO $ waitfinal endv finalv releaseconn annexworker + void $ liftIO $ forkIO $ waitfinal endv finalv conn annexworker (Len len, bs) <- liftIO $ atomically $ takeTMVar bsv bv <- liftIO $ newMVar (L.toChunks bs) let streamer = S.SourceT $ \s -> s =<< return @@ -206,7 +206,7 @@ serveGet st apiver (B64Key k) cu su bypass baf startat sec auth = do , pure mempty ) - waitfinal endv finalv releaseconn annexworker = do + waitfinal endv finalv conn annexworker = do -- Wait for everything to be transferred before -- stopping the annexworker. The validityv will usually -- be written to at the end. If the client disconnects @@ -215,8 +215,8 @@ serveGet st apiver (B64Key k) cu su bypass baf startat sec auth = do -- Make sure the annexworker is not left blocked on endv -- if the client disconnected early. liftIO $ atomically $ tryPutTMVar endv () - void $ tryNonAsync $ wait annexworker - void $ tryNonAsync releaseconn + void $ void $ tryNonAsync $ wait annexworker + void $ tryNonAsync $ releaseP2PConnection conn sizer = pure $ Len $ case startat of Just (Offset o) -> fromIntegral o @@ -301,8 +301,7 @@ serveCheckPresent -> Handler CheckPresentResult serveCheckPresent st apiver (B64Key k) cu su bypass sec auth = do res <- withP2PConnection apiver st cu su bypass sec auth ReadAction - $ \runst conn -> - liftIO $ runNetProto runst conn $ checkPresent k + $ \conn -> liftIO $ proxyClientNetProto conn $ checkPresent k case res of Right b -> return (CheckPresentResult b) Left err -> throwError $ err500 { errBody = encodeBL err } @@ -354,8 +353,8 @@ serveRemove -> Handler t serveRemove st resultmangle apiver (B64Key k) cu su bypass sec auth = do res <- withP2PConnection apiver st cu su bypass sec auth RemoveAction - $ \runst conn -> - liftIO $ runNetProto runst conn $ remove Nothing k + $ \conn -> + liftIO $ proxyClientNetProto conn $ remove Nothing k case res of (Right b, plusuuids) -> return $ resultmangle $ RemoveResultPlus b (map B64UUID (fromMaybe [] plusuuids)) @@ -411,8 +410,8 @@ serveRemoveBefore -> Handler RemoveResultPlus serveRemoveBefore st apiver (B64Key k) cu su bypass (Timestamp ts) sec auth = do res <- withP2PConnection apiver st cu su bypass sec auth RemoveAction - $ \runst conn -> - liftIO $ runNetProto runst conn $ + $ \conn -> + liftIO $ proxyClientNetProto conn $ removeBeforeRemoteEndTime ts k case res of (Right b, plusuuids) -> return $ @@ -464,8 +463,8 @@ serveGetTimestamp -> Handler GetTimestampResult serveGetTimestamp st apiver cu su bypass sec auth = do res <- withP2PConnection apiver st cu su bypass sec auth ReadAction - $ \runst conn -> - liftIO $ runNetProto runst conn getTimestamp + $ \conn -> + liftIO $ proxyClientNetProto conn getTimestamp case res of Right ts -> return $ GetTimestampResult (Timestamp ts) Left err -> throwError $ diff --git a/P2P/Http/State.hs b/P2P/Http/State.hs index 8e9cae3fa6..8e90bc3025 100644 --- a/P2P/Http/State.hs +++ b/P2P/Http/State.hs @@ -17,6 +17,7 @@ import Annex.Common import qualified Annex import P2P.Http.Types import qualified P2P.Protocol as P2P +import qualified P2P.IO as P2P import P2P.IO import P2P.Annex import Annex.UUID @@ -62,15 +63,14 @@ withP2PConnection -> IsSecure -> Maybe Auth -> ActionClass - -> (RunState -> P2PConnection -> Handler (Either ProtoFailure a)) + -> (P2PConnectionPair -> Handler (Either ProtoFailure a)) -> Handler a withP2PConnection apiver st cu su bypass sec auth actionclass connaction = do - (runst, conn, releaseconn) <- - getP2PConnection apiver st cu su bypass sec auth actionclass - connaction' runst conn - `finally` liftIO releaseconn + conn <- getP2PConnection apiver st cu su bypass sec auth actionclass + connaction' conn + `finally` liftIO (releaseP2PConnection conn) where - connaction' runst conn = connaction runst conn >>= \case + connaction' conn = connaction conn >>= \case Right r -> return r Left err -> throwError $ err500 { errBody = encodeBL (describeProtoFailure err) } @@ -85,7 +85,7 @@ getP2PConnection -> IsSecure -> Maybe Auth -> ActionClass - -> Handler (RunState, P2PConnection, ReleaseP2PConnection) + -> Handler P2PConnectionPair getP2PConnection apiver st cu su bypass sec auth actionclass = case (getServerMode st sec auth, actionclass) of (Just P2P.ServeReadWrite, _) -> go P2P.ServeReadWrite @@ -130,16 +130,20 @@ data ConnectionProblem | TooManyConnections deriving (Show, Eq) -type AcquireP2PConnection = - ConnectionParams -> IO - ( Either ConnectionProblem - ( RunState - , P2PConnection - , ReleaseP2PConnection -- ^ release connection - ) - ) +data P2PConnectionPair = P2PConnectionPair + { clientRunState :: RunState + , clientP2PConnection :: P2PConnection + , serverP2PConnection :: P2PConnection + , releaseP2PConnection :: IO () + } -type ReleaseP2PConnection = IO () +proxyClientNetProto :: P2PConnectionPair -> P2P.Proto a -> IO (Either P2P.ProtoFailure a) +proxyClientNetProto conn = runNetProto + (clientRunState conn) (clientP2PConnection conn) + +type AcquireP2PConnection + = ConnectionParams + -> IO (Either ConnectionProblem P2PConnectionPair) {- Acquire P2P connections to the local repository. -} -- TODO need worker pool, this can only service a single request at @@ -177,8 +181,10 @@ withLocalP2PConnections a = do else do hdl1 <- liftIO newEmptyTMVarIO hdl2 <- liftIO newEmptyTMVarIO - let h1 = P2PHandleTMVar hdl1 Nothing - let h2 = P2PHandleTMVar hdl2 Nothing + wait1 <- liftIO newEmptyTMVarIO + wait2 <- liftIO newEmptyTMVarIO + let h1 = P2PHandleTMVar hdl1 wait1 + let h2 = P2PHandleTMVar hdl2 wait2 let serverconn = P2PConnection Nothing (const True) h1 h2 (ConnIdent (Just "http server")) @@ -196,7 +202,12 @@ withLocalP2PConnections a = do =<< forkState protorunner let releaseconn = atomically $ putTMVar relv $ join (liftIO (wait asyncworker)) - return $ Right (clientrunst, clientconn, releaseconn) + return $ Right $ P2PConnectionPair + { clientRunState = clientrunst + , clientP2PConnection = clientconn + , serverP2PConnection = serverconn + , releaseP2PConnection = releaseconn + } liftIO $ atomically $ putTMVar respvar resp mkserverrunst connparams = do diff --git a/P2P/IO.hs b/P2P/IO.hs index 66a4c08fea..7f0955250a 100644 --- a/P2P/IO.hs +++ b/P2P/IO.hs @@ -25,6 +25,7 @@ module P2P.IO , describeProtoFailure , runNetProto , runNet + , signalFullyConsumedByteString ) where import Common @@ -79,7 +80,12 @@ mkRunState mk = do data P2PHandle = P2PHandle Handle - | P2PHandleTMVar (TMVar (Either L.ByteString Message)) (Maybe (TMVar ())) + | P2PHandleTMVar (TMVar (Either L.ByteString Message)) (TMVar ()) + +signalFullyConsumedByteString :: P2PHandle -> IO () +signalFullyConsumedByteString (P2PHandle _) = return () +signalFullyConsumedByteString (P2PHandleTMVar _ waitv) = + atomically $ putTMVar waitv () data P2PConnection = P2PConnection { connRepo :: Maybe Repo @@ -246,14 +252,11 @@ runNet runst conn runner f = case f of Right False -> return $ Left $ ProtoFailureMessage "short data write" Left e -> return $ Left $ ProtoFailureException e - P2PHandleTMVar mv mwaitv -> do + P2PHandleTMVar mv waitv -> do liftIO $ atomically $ putTMVar mv (Left b) - case mwaitv of - -- Wait for the whole bytestring to - -- be processed. - Just waitv -> liftIO $ atomically $ - takeTMVar waitv - Nothing -> return () + -- Wait for the whole bytestring to + -- be processed. + liftIO $ atomically $ takeTMVar waitv runner next ReceiveBytes len p next -> case connIhdl conn of