diff --git a/src/codewars/RegExParser.py b/src/codewars/RegExParser.py index d3920ba..553fa00 100644 --- a/src/codewars/RegExParser.py +++ b/src/codewars/RegExParser.py @@ -1,20 +1,39 @@ class RegExp: def __init__(self, *args): self.args = args + def __repr__(self): args = ", ".join(map(repr, self.args)) return f"{self.__class__.__name__}({args})" + def __eq__(self, other): return type(self) is type(other) and self.args == other.args -class Any(RegExp): pass -class Normal(RegExp): pass -class Or(RegExp): pass -class Str(RegExp): pass -class ZeroOrMore(RegExp): pass + + +class Any(RegExp): + pass + + +class Normal(RegExp): + pass + + +class Or(RegExp): + pass + + +class Str(RegExp): + pass + + +class ZeroOrMore(RegExp): + pass + # Your task is to build an AST using those nodes. # See sample tests or test output for examples of usage. + def parse_regexp(indata): return RegExParser(indata).compile() @@ -46,6 +65,8 @@ class RegExParser: if any([not string or string.count("|") > 1 for string in sequences]): return False + elif isinstance(sequences, str) and sequences.count("|") > 1: + return False if self.pattern.find("*") == 0: return False @@ -57,33 +78,47 @@ class RegExParser: @staticmethod def parse_sequences(pattern): - left_paren = pattern.find('(') - right_paren = pattern.rfind(')') - - if left_paren == -1 and right_paren == -1: - return [pattern] - - if left_paren == -1 or right_paren == -1: - return None - - if left_paren > right_paren: - return None - - left_side = pattern[:left_paren] - middle = pattern[left_paren + 1: right_paren] - right_side = pattern[right_paren + 1:] - result = [] - if left_side: - result.append(left_side) + buffer = [] + depth = 0 - if middle: - result.append(RegExParser.parse_sequences(middle)) + left_paren_count = pattern.count("(") + right_paren_count = pattern.count(")") + if left_paren_count == -1 and right_paren_count == -1: + return [pattern] + if left_paren_count == -1 or right_paren_count == -1: + return None + if left_paren_count > right_paren_count: + return None - if right_side: - result.append(right_side) + for c in pattern: + if c == "(": + if depth == 0: + if buffer: + result.append("".join(buffer)) + buffer = [] + else: + buffer.append(c) + depth += 1 - return result + elif c == ")": + depth -= 1 + if depth < 0: + return None + elif depth == 0: + group = "".join(buffer) + result.append(RegExParser.parse_sequences(group)) + buffer = [] + else: + buffer.append(c) + + else: + buffer.append(c) + + if buffer: + result.append("".join(buffer)) + + return result if len(result) > 1 else result[0] @staticmethod def parse_regex(sequences) -> RegExp | None: @@ -106,15 +141,50 @@ class RegExParser: else: result: list[RegExp] = [] - for sequence_idx in range(len(sequences)): + sequence_idx = 0 + while sequence_idx < len(sequences): current_sequence = sequences[sequence_idx] if sequence_idx < len(sequences) - 1: next_sequence = sequences[sequence_idx + 1] - if isinstance(next_sequence, str) and next_sequence.startswith("*"): + if isinstance(next_sequence, str) and next_sequence.startswith( + "*|" + ): + next_sequence = next_sequence[2:] + if next_sequence: + sequence_idx += 1 + else: + sequence_idx += 2 + next_sequence = sequences[sequence_idx] + regex = Or( + ZeroOrMore(RegExParser.parse_regex(current_sequence)), + RegExParser.parse_regex(next_sequence), + ) + if regex: + result.append(regex) + elif isinstance(next_sequence, str) and next_sequence.startswith( + "*" + ): + next_sequence = next_sequence[1:] regex = ZeroOrMore(RegExParser.parse_regex(current_sequence)) if regex: result.append(regex) + elif isinstance(next_sequence, str) and next_sequence.startswith( + "|" + ): + next_sequence = next_sequence[1:] + if next_sequence: + sequence_idx += 1 + else: + sequence_idx += 2 + next_sequence = sequences[sequence_idx] + + regex = Or( + RegExParser.parse_regex(current_sequence), + RegExParser.parse_regex(next_sequence), + ) + if regex: + result.append(regex) else: regex = RegExParser.parse_regex(current_sequence) if regex: @@ -124,8 +194,9 @@ class RegExParser: if regex: result.append(regex) - return Str(result) if len(result) > 1 else result[0] + sequence_idx += 1 + return Str(result) if len(result) > 1 else result[0] @staticmethod def get_type(sequence) -> RegExp: @@ -154,18 +225,10 @@ class RegExParser: return regex + if __name__ == "__main__": - test_cases = [ - "", - "(", - "(hi!", - ")(", - "a|t|y", - "a**", - "]K\nBYg