diff --git a/Database/MongoDB/Connection.hs b/Database/MongoDB/Connection.hs index bd7f330..d366de6 100644 --- a/Database/MongoDB/Connection.hs +++ b/Database/MongoDB/Connection.hs @@ -20,11 +20,11 @@ import Database.MongoDB.Internal.Protocol (Pipe, newPipe) import System.IO.Pipeline (IOE, close, isClosed) import Control.Exception as E (try) import Network (HostName, PortID(..), connectTo) -import Text.ParserCombinators.Parsec as T (parse, many1, letter, digit, char, eof, spaces, try, (<|>)) +import Text.ParserCombinators.Parsec as T (ParseError, parse, many, many1, letter, digit, hexDigit, char, string, eof, spaces, try, (<|>)) import Control.Monad.Identity (runIdentity) import Control.Monad.Error (ErrorT(..), lift, throwError) import Control.Concurrent.MVar.Lifted -import Control.Monad (forM_) +import Control.Monad (forM_, liftM, liftM2) import Control.Applicative ((<$>)) import Data.UString (UString, unpack) import Data.Bson as D (Document, lookup, at, (=:)) @@ -71,15 +71,49 @@ readHostPortM :: (Monad m) => String -> m Host -- TODO: handle Service and UnixSocket port readHostPortM = either (fail . show) return . parse parser "readHostPort" where hostname = many1 (letter <|> digit <|> char '-' <|> char '.') - parser = do - spaces + parser = spaces >> (T.try simpleParser <|> ipv6Parser) + simpleParser = do h <- hostname - T.try (spaces >> eof >> return (host h)) <|> do - _ <- char ':' - port :: Int <- read <$> many1 digit - spaces >> eof - return $ Host h (PortNumber $ fromIntegral port) - + T.try (spaces >> eof >> return (host h)) <|> do + _ <- char ':' + port :: Int <- read <$> many1 digit + spaces >> eof + return $ Host h (PortNumber $ fromIntegral port) + breakLast :: (a -> Bool) -> [a] -> ([a], [a]) + breakLast f = either (\l -> (l, [])) id . foldr (\e s -> either (\l -> if f e then Right ([], l) else Left (e:l)) (\(a, b) -> Right (e:a, b)) s) (Left []) + ipv6Parser = do + fullHost <- liftM2 (++) (many hexDigit) (liftM2 (:) (char ':') (many1 (hexDigit <|> (char ':')))) + eof + let + (splitHost, splitPort) = breakLast (==':') fullHost + splitPortParsed :: Either ParseError Int + splitPortParsed = parse (liftM read (many1 digit)) "hostPort port" splitPort + splitHostParsed = parse ipv6HostTest "hostPort host" splitHost + fullHostParsed = parse ipv6HostTest "hostPort host" fullHost + case (splitHostParsed, splitPortParsed, fullHostParsed) of + -- Resolve the ambiguous cases (e.g. ::1:1234) as host:port + (Right _, Right p, _) -> return $ Host splitHost (PortNumber $ fromIntegral p) + (_, _, Right _) -> return $ host fullHost + _ -> fail "HostPort specification contains more than one :, but is invalid as an IPv6 address or IPv6 address:port" + ipv6HostTest = ((T.try (string "::") >> ipv6HostTest' True True 7) <|> + ipv6HostTest' False False 8) >> eof + + ipv6HostTest' hadDouble True v = eof <|> ipv6HostTest' hadDouble False v + ipv6HostTest' hadDouble False v = + (do + many1 hexDigit + if v == 1 + then eof + else + (if hadDouble then eof else do + T.try (string "::") + ipv6HostTest' True True (v-1) + ) <|> + (do + char ':' + ipv6HostTest' hadDouble False (v-1) + ) + ) readHostPort :: String -> Host -- ^ Read string \"hostname:port\" as @Host hostname (PortNumber port)@ or \"hostname\" as @host hostname@ (default port). Error if string does not match either syntax. readHostPort = runIdentity . readHostPortM diff --git a/tests/TestReadHostPort.hs b/tests/TestReadHostPort.hs new file mode 100644 index 0000000..dbe3ae7 --- /dev/null +++ b/tests/TestReadHostPort.hs @@ -0,0 +1,24 @@ +import Database.MongoDB.Connection +import System.Exit +import Control.Monad + +testList = [ + ("Simple host", readHostPort "host" == Host "host" (PortNumber 27017)), + ("Simple host with port", readHostPort "host:123" == Host "host" (PortNumber 123)), + ("Pathological ::1:1234 case", readHostPort "::1:1234" == Host "::1" (PortNumber 1234)), + ("Full IPv6 with port", readHostPort "1:2:3:4:5:6:7:8:1234" == Host "1:2:3:4:5:6:7:8" (PortNumber 1234)), + ("Full IPv6 without port", readHostPort "1:2:3:4:5:6:7:8" == Host "1:2:3:4:5:6:7:8" (PortNumber 27017)), + ("Partial IPv6 with port", readHostPort "1:2:3::4:12" == Host "1:2:3::4" (PortNumber 12)), + ("Partial IPv6 with hex at end", readHostPort "1:2:3::4:a" == Host "1:2:3::4:a" (PortNumber 27017)) + ] +main = + let + failedTests = filter (not . snd) testList + in + if null failedTests then + putStrLn "All tests passed" + else + do + putStrLn "The following tests failed:" + forM_ failedTests $ \(descr, _) -> putStrLn $ " * " ++ descr + exitWith (ExitFailure 1)