# See sample tests or test output for examples of usage.


def parse_regexp(pattern: str) -> "RegExp | None":
    """Parse *pattern* into a regular-expression tree.

    Returns the root node built from the module's node classes (``Normal``,
    ``Any``, ``ZeroOrMore``, ``Or``, ``Str``), or ``None`` when the pattern
    is empty or syntactically invalid.
    """
    parser = RegexParser(pattern)
    try:
        result = parser.parse_regex()
    except ValueError:
        return None

    # Leftover input (e.g. a stray ")") means the pattern is malformed.
    if not parser.end():
        return None
    return result


class RegexParser:
    """Recursive-descent parser for a small regular-expression dialect.

    Grammar handled by the methods below::

        regex         := alternation
        alternation   := concatenation [ "|" concatenation ]
        concatenation := repetition*
        repetition    := atom [ "*" ]
        atom          := "(" regex ")" | "." | <any other character>

    Syntax errors are reported by raising ``ValueError``.
    """

    def __init__(self, pattern: str) -> None:
        self.pattern = pattern
        self.pos = 0  # index of the next unread character

    # ----------------------------
    # helpers
    # ----------------------------

    def peek(self) -> "str | None":
        """Return the next character without consuming it, or None at end."""
        if self.pos >= len(self.pattern):
            return None
        return self.pattern[self.pos]

    def consume(self) -> "str | None":
        """Return the next character and advance past it; None at end."""
        c = self.peek()
        if c is not None:
            self.pos += 1
        return c

    def end(self) -> bool:
        """Return True once the whole pattern has been read."""
        return self.pos >= len(self.pattern)

    # ----------------------------
    # regex grammar
    # ----------------------------

    def parse_regex(self) -> "RegExp | None":
        """Parse a full expression; returns None if nothing could be read."""
        return self.parse_alternation()

    def parse_alternation(self) -> "RegExp | None":
        """Parse ``a`` or ``a|b``; at most one ``|`` is allowed per group.

        Raises:
            ValueError: if either side of ``|`` is empty, or a second ``|``
                follows in the same group.
        """
        left = self.parse_concatenation()
        if self.peek() != "|":
            return left

        self.consume()
        right = self.parse_concatenation()

        # An empty branch ("a|", "|a", "|") must not end up as a None
        # operand inside the Or node.
        if left is None or right is None:
            raise ValueError("'|' requires an expression on both sides")
        if self.peek() == "|":
            raise ValueError("Only one '|' allowed per group")

        return Or(left, right)

    def parse_concatenation(self) -> "RegExp | None":
        """Parse a run of repetitions up to ``|``, ``)`` or end of input.

        Returns None for an empty run, the single node for a run of one,
        and ``Str(nodes)`` otherwise.
        """
        nodes = []

        while True:
            c = self.peek()
            if c is None or c in "|)":
                break
            nodes.append(self.parse_repetition())

        if not nodes:
            return None
        if len(nodes) == 1:
            return nodes[0]
        return Str(nodes)

    def parse_repetition(self) -> "RegExp | None":
        """Parse an atom optionally followed by a single ``*``.

        Raises:
            ValueError: if two ``*`` appear back to back.
        """
        node = self.parse_atom()
        if self.peek() != "*":
            return node

        self.consume()
        # Look at the input rather than the node type: a parenthesised
        # group such as "(a*)" is itself a ZeroOrMore node, yet "(a*)*"
        # contains no consecutive stars.
        if self.peek() == "*":
            raise ValueError("Consecutive '*' not allowed")

        return ZeroOrMore(node)

    def parse_atom(self) -> "RegExp | None":
        """Parse a group, ``.`` or a literal character.

        Returns None when called at end of input.

        Raises:
            ValueError: on an unmatched ``(``, an empty group ``()``, or a
                ``*`` with nothing to repeat.
        """
        c = self.peek()

        if c is None:
            return None

        if c == "(":
            self.consume()
            node = self.parse_regex()

            if self.peek() != ")":
                raise ValueError("Unmatched '('")
            self.consume()

            # "()" would otherwise inject None into the surrounding tree.
            if node is None:
                raise ValueError("Empty group '()'")
            return node

        if c == "*":
            raise ValueError("'*' cannot start an expression")

        self.consume()
        if c == ".":
            return Any()
        return Normal(c)