diff --git a/compiler/front_end/lr1.py b/compiler/front_end/lr1.py
index 579d729..610820a 100644
--- a/compiler/front_end/lr1.py
+++ b/compiler/front_end/lr1.py
@@ -198,11 +198,9 @@ def __init__(self, start_symbol, productions):
     def _set_productions_by_lhs(self):
         # Prepopulating _productions_by_lhs speeds up _closure_of_item by about 30%,
         # which is significant on medium-to-large grammars.
-        self._productions_by_lhs = {}
+        self._productions_by_lhs = collections.defaultdict(list)
         for production in self.productions:
-            self._productions_by_lhs.setdefault(production.lhs, list()).append(
-                production
-            )
+            self._productions_by_lhs[production.lhs].append(production)
 
     def _populate_item_cache(self):
         # There are a relatively small number of possible Items for a grammar, and
@@ -456,7 +454,7 @@ def _items(self):
             )
         ]
         items = {item_list[0]: 0}
-        goto_table = {}
+        goto_table = collections.defaultdict(dict)
         i = 0
         # For each state, figure out what the new state when each symbol is added to
         # the top of the parsing stack (see the comments in parser._parse). See
@@ -469,7 +467,7 @@ def _items(self):
                 if goto not in items:
                     items[goto] = len(item_list)
                     item_list.append(goto)
-                goto_table[i, symbol] = items[goto]
+                goto_table[i][symbol] = items[goto]
             i += 1
         return item_list, goto_table
 
@@ -494,7 +492,7 @@ def parser(self):
           A Parser.
         """
         item_sets, goto = self._items()
-        action = {}
+        action = collections.defaultdict(dict)
         conflicts = set()
         end_item = self._item_cache[self._seed_production, 1, END_OF_INPUT]
         for i in range(len(item_sets)):
@@ -508,38 +506,31 @@ def parser(self):
                     new_action = Reduce(item.production)
                 elif item.next_symbol in self.terminals:
                     terminal = item.next_symbol
-                    assert goto[i, terminal] is not None
-                    new_action = Shift(goto[i, terminal], item_sets[goto[i, terminal]])
+                    assert goto[i][terminal] is not None
+                    new_action = Shift(goto[i][terminal], item_sets[goto[i][terminal]])
                 if new_action:
-                    if (i, terminal) in action and action[i, terminal] != new_action:
+                    if action[i].get(terminal, new_action) != new_action:
                         conflicts.add(
                             Conflict(
                                 i,
                                 terminal,
-                                frozenset([action[i, terminal], new_action]),
+                                frozenset([action[i][terminal], new_action]),
                             )
                         )
-                    action[i, terminal] = new_action
+                    action[i][terminal] = new_action
                 if item == end_item:
                     new_action = Accept()
-                    assert (i, END_OF_INPUT) not in action or action[
-                        i, END_OF_INPUT
-                    ] == new_action
-                    action[i, END_OF_INPUT] = new_action
-        trimmed_goto = {}
+                    assert action[i].get(END_OF_INPUT, new_action) == new_action
+                    action[i][END_OF_INPUT] = new_action
+        trimmed_goto = collections.defaultdict(dict)
         for k in goto:
-            if k[1] in self.nonterminals:
-                trimmed_goto[k] = goto[k]
-        expected = {}
-        for state, terminal in action:
-            if state not in expected:
-                expected[state] = set()
-            expected[state].add(terminal)
+            for l in goto[k]:
+                if l in self.nonterminals:
+                    trimmed_goto[k][l] = goto[k][l]
         return Parser(
             item_sets,
             trimmed_goto,
             action,
-            expected,
             conflicts,
             self.terminals,
             self.nonterminals,
@@ -585,7 +576,6 @@ def __init__(
         item_sets,
         goto,
         action,
-        expected,
         conflicts,
         terminals,
         nonterminals,
@@ -595,7 +585,6 @@ def __init__(
         self.item_sets = item_sets
         self.goto = goto
         self.action = action
-        self.expected = expected
         self.conflicts = conflicts
         self.terminals = terminals
         self.nonterminals = nonterminals
@@ -634,7 +623,7 @@ def state():
         # On each iteration, look at the next symbol and the current state, and
         # perform the corresponding action.
         while True:
-            if (state(), tokens[cursor].symbol) not in self.action:
+            if tokens[cursor].symbol not in self.action.get(state(), {}):
                 # Most state/symbol entries would be Errors, so rather than exhaustively
                 # adding error entries, we just check here.
                 if state() in self.default_errors:
@@ -642,7 +631,7 @@ def state():
                 else:
                     next_action = Error(None)
             else:
-                next_action = self.action[state(), tokens[cursor].symbol]
+                next_action = self.action[state()][tokens[cursor].symbol]
 
             if isinstance(next_action, Shift):
                 # Shift means that there are no "complete" productions on the stack,
@@ -717,7 +706,7 @@ def state():
                     next_action.rule.lhs, children, next_action.rule, source_location
                 )
                 del stack[len(stack) - len(next_action.rule.rhs) :]
-                stack.append((self.goto[state(), next_action.rule.lhs], reduction))
+                stack.append((self.goto[state()][next_action.rule.lhs], reduction))
             elif isinstance(next_action, Error):
                 # Error means that the parse is impossible. For typical grammars and
                 # texts, this usually happens within a few tokens after the mistake in
@@ -730,7 +719,11 @@ def state():
                             cursor,
                             tokens[cursor],
                             state(),
-                            self.expected[state()],
+                            set(
+                                k
+                                for k in self.action[state()].keys()
+                                if not isinstance(self.action[state()][k], Error)
+                            ),
                         ),
                     )
                 else:
@@ -801,8 +794,8 @@ def mark_error(self, tokens, error_token, error_code):
                 self.default_errors[result.error.state] = error_code
                 return None
         else:
-            if (result.error.state, error_symbol) in self.action:
-                existing_error = self.action[result.error.state, error_symbol]
+            if error_symbol in self.action.get(result.error.state, {}):
+                existing_error = self.action[result.error.state][error_symbol]
                 assert isinstance(existing_error, Error), "Bug"
                 if existing_error.code == error_code:
                     return None
@@ -817,7 +810,7 @@ def mark_error(self, tokens, error_token, error_code):
                         )
                     )
             else:
-                self.action[result.error.state, error_symbol] = Error(error_code)
+                self.action[result.error.state][error_symbol] = Error(error_code)
                 return None
         assert False, "All other paths should lead to return."
 
@@ -830,5 +823,4 @@ def parse(self, tokens):
         Returns:
           A ParseResult.
         """
-        result = self._parse(tokens)
-        return result
+        return self._parse(tokens)
diff --git a/compiler/front_end/lr1_test.py b/compiler/front_end/lr1_test.py
index ae03e2d..6ca6e67 100644
--- a/compiler/front_end/lr1_test.py
+++ b/compiler/front_end/lr1_test.py
@@ -98,31 +98,42 @@ def _parse_productions(text):
 
 # ACTION table corresponding to the above grammar, ASLU p266.
 _alsu_action = {
-    (0, "c"): lr1.Shift(3, _alsu_items[3]),
-    (0, "d"): lr1.Shift(4, _alsu_items[4]),
-    (1, lr1.END_OF_INPUT): lr1.Accept(),
-    (2, "c"): lr1.Shift(6, _alsu_items[6]),
-    (2, "d"): lr1.Shift(7, _alsu_items[7]),
-    (3, "c"): lr1.Shift(3, _alsu_items[3]),
-    (3, "d"): lr1.Shift(4, _alsu_items[4]),
-    (4, "c"): lr1.Reduce(parser_types.Production("C", ("d",))),
-    (4, "d"): lr1.Reduce(parser_types.Production("C", ("d",))),
-    (5, lr1.END_OF_INPUT): lr1.Reduce(parser_types.Production("S", ("C", "C"))),
-    (6, "c"): lr1.Shift(6, _alsu_items[6]),
-    (6, "d"): lr1.Shift(7, _alsu_items[7]),
-    (7, lr1.END_OF_INPUT): lr1.Reduce(parser_types.Production("C", ("d",))),
-    (8, "c"): lr1.Reduce(parser_types.Production("C", ("c", "C"))),
-    (8, "d"): lr1.Reduce(parser_types.Production("C", ("c", "C"))),
-    (9, lr1.END_OF_INPUT): lr1.Reduce(parser_types.Production("C", ("c", "C"))),
+    0: {
+        "c": lr1.Shift(3, _alsu_items[3]),
+        "d": lr1.Shift(4, _alsu_items[4]),
+    },
+    1: {lr1.END_OF_INPUT: lr1.Accept()},
+    2: {
+        "c": lr1.Shift(6, _alsu_items[6]),
+        "d": lr1.Shift(7, _alsu_items[7]),
+    },
+    3: {
+        "c": lr1.Shift(3, _alsu_items[3]),
+        "d": lr1.Shift(4, _alsu_items[4]),
+    },
+    4: {
+        "c": lr1.Reduce(parser_types.Production("C", ("d",))),
+        "d": lr1.Reduce(parser_types.Production("C", ("d",))),
+    },
+    5: {lr1.END_OF_INPUT: lr1.Reduce(parser_types.Production("S", ("C", "C")))},
+    6: {
+        "c": lr1.Shift(6, _alsu_items[6]),
+        "d": lr1.Shift(7, _alsu_items[7]),
+    },
+    7: {lr1.END_OF_INPUT: lr1.Reduce(parser_types.Production("C", ("d",)))},
+    8: {
+        "c": lr1.Reduce(parser_types.Production("C", ("c", "C"))),
+        "d": lr1.Reduce(parser_types.Production("C", ("c", "C"))),
+    },
+    9: {lr1.END_OF_INPUT: lr1.Reduce(parser_types.Production("C", ("c", "C")))},
 }
 
 # GOTO table corresponding to the above grammar, ASLU p266.
 _alsu_goto = {
-    (0, "S"): 1,
-    (0, "C"): 2,
-    (2, "C"): 5,
-    (3, "C"): 8,
-    (6, "C"): 9,
+    0: {"S": 1, "C": 2},
+    2: {"C": 5},
+    3: {"C": 8},
+    6: {"C": 9},
 }
 
 
@@ -137,15 +148,17 @@ def _normalize_table(items, table):
         original_index_to_index[item_to_original_index[sorted_items[i]]] = i
     updated_table = {}
     for k in table:
-        new_k = original_index_to_index[k[0]], k[1]
-        new_value = table[k]
-        if isinstance(new_value, int):
-            new_value = original_index_to_index[new_value]
-        elif isinstance(new_value, lr1.Shift):
-            new_value = lr1.Shift(
-                original_index_to_index[new_value.state], new_value.items
-            )
-        updated_table[new_k] = new_value
+        for l in table[k]:
+            new_k = original_index_to_index[k]
+            new_value = table[k][l]
+            if isinstance(new_value, int):
+                new_value = original_index_to_index[new_value]
+            elif isinstance(new_value, lr1.Shift):
+                new_value = lr1.Shift(
+                    original_index_to_index[new_value.state], new_value.items
+                )
+            updated_table.setdefault(new_k, {})
+            updated_table[new_k][l] = new_value
     return sorted_items, updated_table
 
 
@@ -302,7 +315,7 @@ def test_mark_error(self):
         # Marking an already-marked error with the same error code should succeed.
         self.assertIsNone(parser.mark_error(_tokenize("d"), None, "missing last C"))
         # Marking an already-marked error with a different error code should fail.
-        self.assertRegexpMatches(
+        self.assertRegex(
             parser.mark_error(_tokenize("d"), None, "different message"),
             r"^Attempted to overwrite existing error code 'missing last C' with "
             r"new error code 'different message' for state \d+, terminal \$$",