Skip to content

Commit e3b2f62

Browse files
author
Grant Wuerker
committed
ADT recursion
1 parent fd274c9 commit e3b2f62

File tree

12 files changed

+422
-61
lines changed

12 files changed

+422
-61
lines changed

Cargo.lock

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/common2/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ camino = "1.1.4"
1515
smol_str = "0.1.24"
1616
salsa = { git = "https://github.com/salsa-rs/salsa", package = "salsa-2022" }
1717
parser = { path = "../parser2", package = "fe-parser2" }
18+
rustc-hash = "1.1.0"

crates/common2/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
pub mod diagnostics;
22
pub mod input;
3+
pub mod recursion;
34

45
pub use input::{InputFile, InputIngot};
56

crates/common2/src/recursion/dsf.rs

+195
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
//! code copied from https://gist.github.com/jlikhuva/70bef29102d62054e8379a7c8f51ad87#
2+
//!
3+
//! Sometimes, we need to keep track of k disjoint groups of
4+
//! items — meaning that each item uniquely belongs to one group.
5+
//! The most common operations that we'd like to run on our collection of groups are:
6+
//! find_set(x), which tells us which of our k groups x belongs to, and union(s1, s2),
7+
//! which allows us to merge two groups. Remember in the section on connected
8+
//! components we needed to quickly find out if a node was already part of some component.
9+
10+
use std::marker::PhantomData;
11+
12+
/// A single node in the disjoint set forest.
#[derive(Debug, Eq, PartialEq)]
pub struct DSFNode<T: PartialEq> {
    /// The payload stored at this node. This is a unique
    /// identifier of a node.
    key: T,

    /// Index of this node's parent in the forest vector.
    /// A root node is its own parent.
    parent: usize,

    /// The location of this node in the forest vector.
    index: usize,

    /// An upper bound on the height of this node -- that is, on the
    /// number of edges in the longest simple path from this node down
    /// to a descendant leaf. A freshly created node has rank 0.
    rank: usize,
}

/// A handle to a node of a [`DSF`], wrapping the node's index in the
/// forest vector. Two handles compare equal when they wrap the same index.
#[derive(Debug, PartialEq, Eq)]
pub struct DSFNodeHandle<T>(usize, PhantomData<T>);

impl<T: PartialEq> DSFNode<T> {
    /// Creates a new singleton node holding `key`.
    ///
    /// `parent` must be the position this node will occupy in the forest
    /// vector: a singleton node is its own parent, so the same value is
    /// stored as both the parent and the node's own index.
    pub fn new(key: T, parent: usize) -> Self {
        DSFNode {
            key,
            parent,
            rank: 0,
            index: parent,
        }
    }
}

/// A disjoint set forest is a collection of trees.
/// Only the nodes in a single tree are linked together.
/// A user interacts with the forest through [`DSFNodeHandle`]s.
#[derive(Debug)]
pub struct DSF<T: PartialEq> {
    /// All the nodes of all the trees in the forest.
    forest: Vec<DSFNode<T>>,
}

impl<T: PartialEq> Default for DSF<T> {
    /// Equivalent to [`DSF::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl<T: PartialEq> DSF<T> {
    /// Creates a new disjoint set forest structure with no trees in it.
    pub fn new() -> Self {
        DSF { forest: Vec::new() }
    }

    /// Adds a new singleton node holding `x` to the forest.
    ///
    /// Returns a handle to the node that can be passed to
    /// [`DSF::union`] and [`DSF::find_set`].
    pub fn make_set(&mut self, x: T) -> DSFNodeHandle<T> {
        let idx = self.forest.len();
        self.forest.push(DSFNode::new(x, idx));
        DSFNodeHandle(idx, PhantomData)
    }

    /// Merges the sets containing `a` and `b`.
    ///
    /// The union operation has two cases. If the roots have unequal rank, the
    /// root with the lower rank is made to point to the root with the higher
    /// rank and no rank changes. If the roots have equal rank, `b`'s root
    /// becomes the root of the combined set and its rank is increased by 1.
    ///
    /// If `a` and `b` already belong to the same set, nothing changes apart
    /// from the path compression performed while looking up the roots.
    pub fn union(&mut self, a: &DSFNodeHandle<T>, b: &DSFNodeHandle<T>) {
        let a_root = Self::find_set_helper(&mut self.forest, a.0);
        let b_root = Self::find_set_helper(&mut self.forest, b.0);

        // Both handles already share a representative. Without this guard the
        // equal-rank branch below would compare the root with itself and bump
        // its rank even though no tree got any taller.
        if a_root == b_root {
            return;
        }

        // We make the root with the higher rank the parent of the one
        // with the lower rank. This effectively makes it the representative
        // of the combined set.
        if self.forest[a_root].rank > self.forest[b_root].rank {
            self.forest[b_root].parent = self.forest[a_root].index;
        } else {
            self.forest[a_root].parent = self.forest[b_root].index;

            // Ranks only change when we merge two trees with the same
            // rank. The choice of whose rank to increase is arbitrary.
            if self.forest[a_root].rank == self.forest[b_root].rank {
                self.forest[b_root].rank += 1;
            }
        }
    }

    /// Finds the representative of `x`. Also does path compression. It does
    /// not change the value of any rank.
    pub fn find_set(&mut self, x: &DSFNodeHandle<T>) -> DSFNodeHandle<T> {
        let idx = Self::find_set_helper(&mut self.forest, x.0);
        DSFNodeHandle(idx, PhantomData)
    }

    /// Returns the index of the root of the tree containing `x_index` and
    /// re-parents every node on the `x_index -> root` path directly to it.
    fn find_set_helper(forest: &mut [DSFNode<T>], x_index: usize) -> usize {
        // First pass: walk upward until we reach the node that is its own
        // parent, i.e. the representative of the set.
        let mut root = x_index;
        while forest[root].parent != root {
            root = forest[root].parent;
        }

        // Second pass: walk the same path again and point every node on it
        // straight at the root (path compression).
        let mut current = x_index;
        while current != root {
            let next = forest[current].parent;
            forest[current].parent = root;
            current = next;
        }

        root
    }
}
125+
126+
#[cfg(test)]
mod test {
    use super::DSF;

    /// Every freshly created node must be a singleton set: it is its own
    /// representative and its handle differs from every other handle.
    #[test]
    fn make_set() {
        let mut forest = DSF::<&str>::new();
        let handles = vec![
            forest.make_set("good"),
            forest.make_set("splendid"),
            forest.make_set("remarkable"),
            forest.make_set("nice"),
            forest.make_set("amazing"),
        ];

        for handle in &handles {
            assert_eq!(&forest.find_set(handle), handle);
        }

        for (i, first) in handles.iter().enumerate() {
            for second in &handles[i + 1..] {
                assert_ne!(first, second);
            }
        }
    }

    /// Builds two groups of synonyms and checks that unions merge exactly
    /// the intended nodes while the two groups stay disjoint.
    #[test]
    fn union_and_find_set() {
        let mut forest = DSF::<&str>::new();
        // Synonyms for good
        let t1 = forest.make_set("good");
        let t2 = forest.make_set("splendid");
        let t3 = forest.make_set("remarkable");
        let t4 = forest.make_set("nice");
        let t5 = forest.make_set("amazing");

        // Assert Singleton Trees
        assert_ne!(forest.find_set(&t1), forest.find_set(&t2));
        assert_ne!(forest.find_set(&t2), forest.find_set(&t3));
        assert_ne!(forest.find_set(&t3), forest.find_set(&t4));
        assert_ne!(forest.find_set(&t4), forest.find_set(&t5));

        // Synonyms for bad
        let t6 = forest.make_set("bad");
        let t7 = forest.make_set("schlecht");
        let t8 = forest.make_set("unpleasany");
        let t9 = forest.make_set("poor");

        // Assert Singleton Trees
        assert_ne!(forest.find_set(&t6), forest.find_set(&t7));
        assert_ne!(forest.find_set(&t7), forest.find_set(&t8));
        assert_ne!(forest.find_set(&t8), forest.find_set(&t9));

        // Union Galore
        forest.union(&t1, &t2);
        forest.union(&t1, &t3);
        forest.union(&t2, &t4);
        forest.union(&t5, &t3);

        forest.union(&t6, &t7);
        forest.union(&t8, &t9);
        forest.union(&t9, &t7);

        // Assert Only 2 disjoint sets
        //
        // First Set
        assert_eq!(forest.find_set(&t1), forest.find_set(&t2));
        assert_eq!(forest.find_set(&t2), forest.find_set(&t3));
        assert_eq!(forest.find_set(&t3), forest.find_set(&t4));
        assert_eq!(forest.find_set(&t4), forest.find_set(&t5));

        // Second Set
        assert_eq!(forest.find_set(&t6), forest.find_set(&t7));
        assert_eq!(forest.find_set(&t7), forest.find_set(&t8));
        assert_eq!(forest.find_set(&t8), forest.find_set(&t9));

        // Assert Disjointness
        assert_ne!(forest.find_set(&t6), forest.find_set(&t1));
        assert_ne!(forest.find_set(&t7), forest.find_set(&t3));
        assert_ne!(forest.find_set(&t8), forest.find_set(&t4));
    }
}

crates/common2/src/recursion/mod.rs

+125
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
use std::hash::Hash;
2+
3+
use self::dsf::{DSFNodeHandle, DSF};
4+
use rustc_hash::FxHashMap;
5+
6+
mod dsf;
7+
8+
/// `RecursionConstituent` stores information about part of a recursion. Constituents
9+
/// of a single recursion can be joined using `RecursionHelper`.
10+
///
11+
/// `T` is the recursion's identifier type and `U` carries diagnostic information.
12+
#[derive(Eq, PartialEq, Clone, Debug, Hash)]
13+
pub struct RecursionConstituent<T, U>
14+
where
15+
T: PartialEq + Copy,
16+
{
17+
/// From where the constituent originates.
18+
pub from: (T, U),
19+
/// To where the constituent goes.
20+
pub to: (T, U),
21+
}
22+
23+
impl<T, U> RecursionConstituent<T, U>
24+
where
25+
T: PartialEq + Copy,
26+
{
27+
pub fn new(from: (T, U), to: (T, U)) -> Self {
28+
Self { from, to }
29+
}
30+
}
31+
32+
pub struct RecursionHelper<T, U>
33+
where
34+
T: PartialEq + Copy,
35+
{
36+
constituents: Vec<RecursionConstituent<T, U>>,
37+
forest: DSF<T>,
38+
trees: FxHashMap<T, DSFNodeHandle<T>>,
39+
}
40+
41+
/// `RecursionHelper` uses a disjoint set forest to unify constituents of recursions.
42+
impl<T, U> RecursionHelper<T, U>
43+
where
44+
T: Eq + Hash + PartialEq + Copy,
45+
{
46+
pub fn new(constituents: Vec<RecursionConstituent<T, U>>) -> Self {
47+
let mut forest = DSF::<_>::new();
48+
let trees: FxHashMap<_, _> = constituents
49+
.iter()
50+
.map(|constituent| (constituent.from.0, forest.make_set(constituent.from.0)))
51+
.collect();
52+
53+
for constituent in constituents.iter() {
54+
forest.union(&trees[&constituent.from.0], &trees[&constituent.to.0])
55+
}
56+
57+
Self {
58+
constituents,
59+
forest,
60+
trees,
61+
}
62+
}
63+
64+
/// Removes a set of disjoint constituents from the helper and returns them.
65+
///
66+
/// This should be called until the disjoint set is empty.
67+
pub fn remove_disjoint_set(&mut self) -> Option<Vec<RecursionConstituent<T, U>>> {
68+
let mut disjoint_set = vec![];
69+
let mut remaining_set = vec![];
70+
let mut set_id = None;
71+
72+
while let Some(constituent) = self.constituents.pop() {
73+
let cur_set_id = self.forest.find_set(&self.trees[&constituent.from.0]);
74+
75+
if set_id == None {
76+
set_id = Some(cur_set_id);
77+
disjoint_set.push(constituent);
78+
} else if set_id == Some(cur_set_id) {
79+
disjoint_set.push(constituent)
80+
} else {
81+
remaining_set.push(constituent)
82+
}
83+
}
84+
85+
self.constituents = remaining_set;
86+
87+
if set_id.is_some() {
88+
Some(disjoint_set)
89+
} else {
90+
None
91+
}
92+
}
93+
}
94+
95+
/// A two-element cycle must come back as a single disjoint set, after which
/// the helper is exhausted.
#[test]
fn one_recursion() {
    let constituents = vec![
        RecursionConstituent::new((0, ()), (1, ())),
        RecursionConstituent::new((1, ()), (0, ())),
    ];

    let mut helper = RecursionHelper::new(constituents);

    let disjoint_constituents = helper
        .remove_disjoint_set()
        .expect("the two constituents form one recursion");
    // The order in which constituents are returned is not part of the
    // contract being tested, so compare the sorted origins.
    let mut origins: Vec<_> = disjoint_constituents.iter().map(|c| c.from.0).collect();
    origins.sort_unstable();
    assert_eq!(origins, [0, 1]);

    assert!(helper.remove_disjoint_set().is_none());
}

/// Two independent cycles must come back as two separate disjoint sets.
#[test]
fn two_recursions() {
    let constituents = vec![
        RecursionConstituent::new((0, ()), (1, ())),
        RecursionConstituent::new((1, ()), (0, ())),
        RecursionConstituent::new((2, ()), (3, ())),
        RecursionConstituent::new((3, ()), (4, ())),
        RecursionConstituent::new((4, ()), (2, ())),
    ];

    let mut helper = RecursionHelper::new(constituents);

    // Drain the helper and normalise ordering, both inside each set and
    // between sets, so the assertion does not depend on extraction order.
    let mut groups = Vec::new();
    while let Some(disjoint_constituents) = helper.remove_disjoint_set() {
        let mut origins: Vec<_> = disjoint_constituents.iter().map(|c| c.from.0).collect();
        origins.sort_unstable();
        groups.push(origins);
    }
    groups.sort();

    assert_eq!(groups, [vec![0, 1], vec![2, 3, 4]]);
}

crates/hir-analysis/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ pub struct Jar(
7272
ty::diagnostics::ImplTraitDefDiagAccumulator,
7373
ty::diagnostics::ImplDefDiagAccumulator,
7474
ty::diagnostics::FuncDefDiagAccumulator,
75+
ty::diagnostics::AdtRecursionConstituentAccumulator,
7576
);
7677

7778
pub trait HirAnalysisDb: salsa::DbWithJar<Jar> + HirDb {

0 commit comments

Comments
 (0)