Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructure ACTION and GOTO tables. #194

Merged
merged 4 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 28 additions & 36 deletions compiler/front_end/lr1.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,9 @@ def __init__(self, start_symbol, productions):
def _set_productions_by_lhs(self):
# Prepopulating _productions_by_lhs speeds up _closure_of_item by about 30%,
# which is significant on medium-to-large grammars.
self._productions_by_lhs = {}
self._productions_by_lhs = collections.defaultdict(list)
for production in self.productions:
self._productions_by_lhs.setdefault(production.lhs, list()).append(
production
)
self._productions_by_lhs[production.lhs].append(production)

def _populate_item_cache(self):
# There are a relatively small number of possible Items for a grammar, and
Expand Down Expand Up @@ -456,7 +454,7 @@ def _items(self):
)
]
items = {item_list[0]: 0}
goto_table = {}
goto_table = collections.defaultdict(dict)
i = 0
# For each state, figure out what the new state when each symbol is added to
# the top of the parsing stack (see the comments in parser._parse). See
Expand All @@ -469,7 +467,7 @@ def _items(self):
if goto not in items:
items[goto] = len(item_list)
item_list.append(goto)
goto_table[i, symbol] = items[goto]
goto_table[i][symbol] = items[goto]
i += 1
return item_list, goto_table

Expand All @@ -494,7 +492,7 @@ def parser(self):
A Parser.
"""
item_sets, goto = self._items()
action = {}
action = collections.defaultdict(dict)
conflicts = set()
end_item = self._item_cache[self._seed_production, 1, END_OF_INPUT]
for i in range(len(item_sets)):
Expand All @@ -508,38 +506,31 @@ def parser(self):
new_action = Reduce(item.production)
elif item.next_symbol in self.terminals:
terminal = item.next_symbol
assert goto[i, terminal] is not None
new_action = Shift(goto[i, terminal], item_sets[goto[i, terminal]])
assert goto[i][terminal] is not None
new_action = Shift(goto[i][terminal], item_sets[goto[i][terminal]])
if new_action:
if (i, terminal) in action and action[i, terminal] != new_action:
if action[i].get(terminal, new_action) != new_action:
conflicts.add(
Conflict(
i,
terminal,
frozenset([action[i, terminal], new_action]),
frozenset([action[i][terminal], new_action]),
)
)
action[i, terminal] = new_action
action[i][terminal] = new_action
if item == end_item:
new_action = Accept()
assert (i, END_OF_INPUT) not in action or action[
i, END_OF_INPUT
] == new_action
action[i, END_OF_INPUT] = new_action
trimmed_goto = {}
assert action[i].get(END_OF_INPUT, new_action) == new_action
action[i][END_OF_INPUT] = new_action
trimmed_goto = collections.defaultdict(dict)
for k in goto:
if k[1] in self.nonterminals:
trimmed_goto[k] = goto[k]
expected = {}
for state, terminal in action:
if state not in expected:
expected[state] = set()
expected[state].add(terminal)
for l in goto[k]:
if l in self.nonterminals:
trimmed_goto[k][l] = goto[k][l]
return Parser(
item_sets,
trimmed_goto,
action,
expected,
conflicts,
self.terminals,
self.nonterminals,
Expand Down Expand Up @@ -585,7 +576,6 @@ def __init__(
item_sets,
goto,
action,
expected,
conflicts,
terminals,
nonterminals,
Expand All @@ -595,7 +585,6 @@ def __init__(
self.item_sets = item_sets
self.goto = goto
self.action = action
self.expected = expected
self.conflicts = conflicts
self.terminals = terminals
self.nonterminals = nonterminals
Expand Down Expand Up @@ -634,15 +623,15 @@ def state():
# On each iteration, look at the next symbol and the current state, and
# perform the corresponding action.
while True:
if (state(), tokens[cursor].symbol) not in self.action:
if tokens[cursor].symbol not in self.action.get(state(), {}):
# Most state/symbol entries would be Errors, so rather than exhaustively
# adding error entries, we just check here.
if state() in self.default_errors:
next_action = Error(self.default_errors[state()])
else:
next_action = Error(None)
else:
next_action = self.action[state(), tokens[cursor].symbol]
next_action = self.action[state()][tokens[cursor].symbol]

if isinstance(next_action, Shift):
# Shift means that there are no "complete" productions on the stack,
Expand Down Expand Up @@ -717,7 +706,7 @@ def state():
next_action.rule.lhs, children, next_action.rule, source_location
)
del stack[len(stack) - len(next_action.rule.rhs) :]
stack.append((self.goto[state(), next_action.rule.lhs], reduction))
stack.append((self.goto[state()][next_action.rule.lhs], reduction))
elif isinstance(next_action, Error):
# Error means that the parse is impossible. For typical grammars and
# texts, this usually happens within a few tokens after the mistake in
Expand All @@ -730,7 +719,11 @@ def state():
cursor,
tokens[cursor],
state(),
self.expected[state()],
set(
k
for k in self.action[state()].keys()
if not isinstance(self.action[state()][k], Error)
),
),
)
else:
Expand Down Expand Up @@ -801,8 +794,8 @@ def mark_error(self, tokens, error_token, error_code):
self.default_errors[result.error.state] = error_code
return None
else:
if (result.error.state, error_symbol) in self.action:
existing_error = self.action[result.error.state, error_symbol]
if error_symbol in self.action.get(result.error.state, {}):
existing_error = self.action[result.error.state][error_symbol]
assert isinstance(existing_error, Error), "Bug"
if existing_error.code == error_code:
return None
Expand All @@ -817,7 +810,7 @@ def mark_error(self, tokens, error_token, error_code):
)
)
else:
self.action[result.error.state, error_symbol] = Error(error_code)
self.action[result.error.state][error_symbol] = Error(error_code)
return None
assert False, "All other paths should lead to return."

Expand All @@ -830,5 +823,4 @@ def parse(self, tokens):
Returns:
A ParseResult.
"""
result = self._parse(tokens)
return result
return self._parse(tokens)
75 changes: 44 additions & 31 deletions compiler/front_end/lr1_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,31 +98,42 @@ def _parse_productions(text):

# ACTION table corresponding to the above grammar, ALSU (Aho, Lam, Sethi &
# Ullman, "Compilers: Principles, Techniques, and Tools," 2nd ed.) p266.
_alsu_action = {
(0, "c"): lr1.Shift(3, _alsu_items[3]),
(0, "d"): lr1.Shift(4, _alsu_items[4]),
(1, lr1.END_OF_INPUT): lr1.Accept(),
(2, "c"): lr1.Shift(6, _alsu_items[6]),
(2, "d"): lr1.Shift(7, _alsu_items[7]),
(3, "c"): lr1.Shift(3, _alsu_items[3]),
(3, "d"): lr1.Shift(4, _alsu_items[4]),
(4, "c"): lr1.Reduce(parser_types.Production("C", ("d",))),
(4, "d"): lr1.Reduce(parser_types.Production("C", ("d",))),
(5, lr1.END_OF_INPUT): lr1.Reduce(parser_types.Production("S", ("C", "C"))),
(6, "c"): lr1.Shift(6, _alsu_items[6]),
(6, "d"): lr1.Shift(7, _alsu_items[7]),
(7, lr1.END_OF_INPUT): lr1.Reduce(parser_types.Production("C", ("d",))),
(8, "c"): lr1.Reduce(parser_types.Production("C", ("c", "C"))),
(8, "d"): lr1.Reduce(parser_types.Production("C", ("c", "C"))),
(9, lr1.END_OF_INPUT): lr1.Reduce(parser_types.Production("C", ("c", "C"))),
0: {
"c": lr1.Shift(3, _alsu_items[3]),
"d": lr1.Shift(4, _alsu_items[4]),
},
1: {lr1.END_OF_INPUT: lr1.Accept()},
2: {
"c": lr1.Shift(6, _alsu_items[6]),
"d": lr1.Shift(7, _alsu_items[7]),
},
3: {
"c": lr1.Shift(3, _alsu_items[3]),
"d": lr1.Shift(4, _alsu_items[4]),
},
4: {
"c": lr1.Reduce(parser_types.Production("C", ("d",))),
"d": lr1.Reduce(parser_types.Production("C", ("d",))),
},
5: {lr1.END_OF_INPUT: lr1.Reduce(parser_types.Production("S", ("C", "C")))},
6: {
"c": lr1.Shift(6, _alsu_items[6]),
"d": lr1.Shift(7, _alsu_items[7]),
},
7: {lr1.END_OF_INPUT: lr1.Reduce(parser_types.Production("C", ("d",)))},
8: {
"c": lr1.Reduce(parser_types.Production("C", ("c", "C"))),
"d": lr1.Reduce(parser_types.Production("C", ("c", "C"))),
},
9: {lr1.END_OF_INPUT: lr1.Reduce(parser_types.Production("C", ("c", "C")))},
}

# GOTO table corresponding to the above grammar, ALSU (Aho, Lam, Sethi &
# Ullman, "Compilers: Principles, Techniques, and Tools," 2nd ed.) p266.
_alsu_goto = {
(0, "S"): 1,
(0, "C"): 2,
(2, "C"): 5,
(3, "C"): 8,
(6, "C"): 9,
0: {"S": 1, "C": 2},
2: {"C": 5},
3: {"C": 8},
6: {"C": 9},
}


Expand All @@ -137,15 +148,17 @@ def _normalize_table(items, table):
original_index_to_index[item_to_original_index[sorted_items[i]]] = i
updated_table = {}
for k in table:
new_k = original_index_to_index[k[0]], k[1]
new_value = table[k]
if isinstance(new_value, int):
new_value = original_index_to_index[new_value]
elif isinstance(new_value, lr1.Shift):
new_value = lr1.Shift(
original_index_to_index[new_value.state], new_value.items
)
updated_table[new_k] = new_value
for l in table[k]:
new_k = original_index_to_index[k]
new_value = table[k][l]
if isinstance(new_value, int):
new_value = original_index_to_index[new_value]
elif isinstance(new_value, lr1.Shift):
new_value = lr1.Shift(
original_index_to_index[new_value.state], new_value.items
)
updated_table.setdefault(new_k, {})
updated_table[new_k][l] = new_value
return sorted_items, updated_table


Expand Down Expand Up @@ -302,7 +315,7 @@ def test_mark_error(self):
# Marking an already-marked error with the same error code should succeed.
self.assertIsNone(parser.mark_error(_tokenize("d"), None, "missing last C"))
# Marking an already-marked error with a different error code should fail.
self.assertRegexpMatches(
self.assertRegex(
parser.mark_error(_tokenize("d"), None, "different message"),
r"^Attempted to overwrite existing error code 'missing last C' with "
r"new error code 'different message' for state \d+, terminal \$$",
Expand Down