module Parser.Util where

import Control.Monad (mfilter)
import Data.Functor (($>))
import Data.Text (Text)
import Data.Text qualified as Text
import ParserMonad (Parser)
import Text.Parsec (State (..))
import Text.Parsec qualified as Parsec
import Text.Parsec.Pos (updatePosChar)

{- | Consume characters from the input up to and including the given pattern.
Return everything consumed except for the end pattern itself.

A backslash that is not itself part of a pattern match escapes the following
character: that character is consumed unconditionally and can never take part
in matching the end pattern.

The parser fails if the consumed text does not end with the pattern, or if
nothing precedes the pattern (including when the pattern is empty).
-}
takeUntil :: Text -> Parser Text
takeUntil end_ =
    Text.dropEnd (Text.length end_)
        <$> requireEnd (scan p (False, []))
        >>= gotSome
  where
    end = Text.unpack end_

    -- Scanner state: whether the previous character was an escaping
    -- backslash, and the not-yet-matched remainder of every occurrence of
    -- the pattern that is currently in progress. Tracking all of them (rather
    -- than a single one) is what lets a terminator be found when it overlaps
    -- a false start, e.g. "*/" inside "**/".
    p :: (Bool, [String]) -> Char -> Maybe (Bool, [String])
    p (escaped, partials) c
        -- An escaped character is always taken and breaks any match.
        | escaped = Just (False, [])
        -- Some occurrence has been matched completely (or the pattern is
        -- empty): stop before consuming this character.
        | any null (end : partials) = Nothing
        -- A backslash only acts as an escape when it does not advance a match.
        | otherwise = Just (c == '\\' && null advanced, advanced)
      where
        -- Advance every in-progress match, plus a fresh one starting at 'c'.
        advanced = [rest | x : rest <- end : partials, x == c]

    -- Reject the result unless the consumed text ends with the pattern.
    requireEnd = mfilter (Text.isSuffixOf end_)

    gotSome xs
        | Text.null xs = fail "didn't get any content"
        | otherwise = return xs

{- | Like 'takeWhile', but unconditionally take escaped characters.

A character directly following a consumed backslash is taken without being
tested. The backslash itself must satisfy the predicate to be consumed.
-}
takeWhile_ :: (Char -> Bool) -> Parser Text
takeWhile_ p = scan p_ False
  where
    -- The state records whether the previous character was a backslash.
    p_ escaped c
        | escaped = Just False
        | not $ p c = Nothing
        | otherwise = Just (c == '\\')

-- | Like 'takeWhile1', but unconditionally take escaped characters.
-- Fails if no characters were consumed.
takeWhile1_ :: (Char -> Bool) -> Parser Text
takeWhile1_ = mfilter (not . Text.null) . takeWhile_

{- | Scan the input text, accumulating characters as long as the scanning
function returns 'Just'.

The scanning function receives the current state and the next character.
@Just s@ consumes the character and continues with state @s@; 'Nothing' stops
the scan before that character. The scan also stops at the end of the input.

Returns the consumed prefix, and advances the parser's input and source
position past it.
-}
scan ::
    -- | scan function
    (state -> Char -> Maybe state) ->
    -- | initial state
    state ->
    Parser Text
scan f initState = do
    parserState@State{stateInput = input, statePos = pos} <- Parsec.getParserState
    (remaining, finalPos, ct) <- go input initState pos 0
    let newState = parserState{stateInput = remaining, statePos = finalPos}
    -- The result is the first 'ct' characters of the original input.
    Parsec.setParserState newState $> Text.take ct input
  where
    -- Walk the input one character at a time, strictly accumulating the
    -- remaining input, scanner state, source position and consumed count.
    go !input' !st !posAccum !count' = case Text.uncons input' of
        Nothing -> pure (input', posAccum, count')
        Just (char', input'') -> case f st char' of
            Nothing -> pure (input', posAccum, count')
            Just st' -> go input'' st' (updatePosChar posAccum char') (count' + 1)