diff --git a/Assistant/Threads/PushNotifier.hs b/Assistant/Threads/PushNotifier.hs index 088e97ec4b..2784012f24 100644 --- a/Assistant/Threads/PushNotifier.hs +++ b/Assistant/Threads/PushNotifier.hs @@ -16,15 +16,17 @@ import Assistant.DaemonStatus import Assistant.Pushes import Assistant.Sync import qualified Remote +import Utility.FileMode +import Utility.SRV import Network.Protocol.XMPP import Network import Control.Concurrent import qualified Data.Text as T import qualified Data.Set as S -import Utility.FileMode import qualified Git.Branch import Data.XML.Types +import Control.Exception as E thisThread :: ThreadName thisThread = "PushNotifier" @@ -33,27 +35,16 @@ pushNotifierThread :: ThreadState -> DaemonStatusHandle -> PushNotifier -> Named pushNotifierThread st dstatus pushnotifier = NamedThread thisThread $ do v <- runThreadState st $ getXMPPCreds case v of - Nothing -> nocreds - Just c -> case parseJID (xmppJID c) of - Nothing -> nocreds - Just jid -> void $ client c jid - where - nocreds = do - error "no creds" -- TODO alert - return () -- exit thread - - client c jid = runClient server jid (xmppUsername c) (xmppPassword c) $ do - void $ bindJID jid + Nothing -> do + return () -- no creds? exit thread + Just c -> void $ connectXMPP c $ \jid -> do + fulljid <- bindJID jid + liftIO $ debug thisThread ["XMPP connected", show fulljid] s <- getSession _ <- liftIO $ forkOS $ void $ runXMPP s $ receivenotifications sendnotifications - where - server = Server - (JID Nothing (jidDomain jid) Nothing) - (xmppHostname c) - (PortNumber $ fromIntegral $ xmppPort c) - + where sendnotifications = forever $ do us <- liftIO $ waitPush pushnotifier let payload = [extendedAway, encodePushNotification us] @@ -78,12 +69,43 @@ data XMPPCreds = XMPPCreds , xmppPassword :: T.Text , xmppHostname :: HostName , xmppPort :: Int - {- Something like username@hostname, but not necessarily the same - - username or hostname used to connect to the server. -} , xmppJID :: T.Text } deriving (Read, Show) +{- Note that this must be run in a bound thread; gnuTLS requires it. -} +connectXMPP :: XMPPCreds -> (JID -> XMPP a) -> IO (Either SomeException ()) +connectXMPP c a = case parseJID (xmppJID c) of + Nothing -> error "bad JID" + Just jid -> connectXMPP' jid c a + +{- Do a SRV lookup, but if it fails, fall back to the cached xmppHostname. -} +connectXMPP' :: JID -> XMPPCreds -> (JID -> XMPP a) -> IO (Either SomeException ()) +connectXMPP' jid c a = go =<< lookupSRV srvrecord + where + srvrecord = "_xmpp-client._tcp." ++ (T.unpack $ strDomain $ jidDomain jid) + serverjid = JID Nothing (jidDomain jid) Nothing + + go [] = run (xmppHostname c) + (PortNumber $ fromIntegral $ xmppPort c) + (a jid) + go ((h,p):rest) = do + {- Try each SRV record in turn, until one connects, + - at which point the MVar will be full. -} + mv <- newEmptyMVar + r <- run h p $ do + liftIO $ putMVar mv () + a jid + ifM (isEmptyMVar mv) (go rest, return r) + + run h p a' = do + liftIO $ debug thisThread ["XMPP trying", h] + E.try (runClientError (Server serverjid h p) jid (xmppUsername c) (xmppPassword c) (void a')) :: IO (Either SomeException ()) + +{- XMPP runClient, that throws errors rather than returning an Either -} +runClientError :: Server -> JID -> T.Text -> T.Text -> XMPP a -> IO a +runClientError s j u p x = either (error . show) return =<< runClient s j u p x + getXMPPCreds :: Annex (Maybe XMPPCreds) getXMPPCreds = do f <- xmppCredsFile diff --git a/Utility/SRV.hs b/Utility/SRV.hs index 51d4360e22..c30c8bd866 100644 --- a/Utility/SRV.hs +++ b/Utility/SRV.hs @@ -35,7 +35,7 @@ type HostPort = (HostName, PortID) {- Returns an ordered list, with highest priority hosts first. - - On error, returns an empty list. -} -lookupSRV :: String -> IO [HostPort] +lookupSRV :: HostName -> IO [HostPort] #ifdef WITH_ADNS lookupSRV srv = initResolver [] $ \resolver -> do r <- catchDefaultIO (Right []) $ @@ -45,7 +45,7 @@ lookupSRV srv = initResolver [] $ \resolver -> do lookupSRV = lookupSRVHost #endif -lookupSRVHost :: String -> IO [HostPort] +lookupSRVHost :: HostName -> IO [HostPort] lookupSRVHost srv | Build.SysConfig.host = catchDefaultIO [] $ parseSrvHost <$> readProcessEnv "host" ["-t", "SRV", "--", srv] @@ -54,16 +54,17 @@ lookupSRVHost srv | otherwise = return [] parseSrvHost :: String -> [HostPort] -parseSrvHost = map snd . reverse . sortBy priority . catMaybes . map parse . lines +parseSrvHost = map snd . reverse . sortBy cost . catMaybes . map parse . lines where - priority = compare `on` fst + cost = compare `on` fst parse l = case words l of - [_, _, _, _, priority, weight, sport, hostname] -> - case PortNumber . fromIntegral <$> readish sport of + [_, _, _, _, priority, weight, sport, hostname] -> do + let v = readish sport :: Maybe Int + case v of Nothing -> Nothing Just port -> Just ( (priority, weight) - , (hostname, port) + , (hostname, PortNumber $ fromIntegral port) ) _ -> Nothing