Skip to content

Commit

Permalink
Fix caching of panic results for if/match branches
Browse files Browse the repository at this point in the history
  • Loading branch information
fkettelhoit committed Dec 4, 2024
1 parent 7dbf606 commit b41a2e8
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 55 deletions.
146 changes: 97 additions & 49 deletions src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,7 @@ pub(crate) struct CircuitBuilder {
negated: HashMap<GateIndex, GateIndex>,
gates_optimized: usize,
gate_counter: usize,
panic_gates: PanicResult,
panic_wires: HashMap<usize, PanicResult>,
panic_gates: CachedPanicResult,
consts: HashMap<String, usize>,
}

Expand All @@ -287,6 +286,12 @@ pub struct PanicResult {
pub end_column: [GateIndex; USIZE_BITS],
}

#[derive(Debug, Clone)]
pub(crate) struct CachedPanicResult {
result: PanicResult,
cache: HashMap<usize, PanicResult>,
}

impl PanicResult {
/// Returns a `PanicResult` indicating that no panic has occurred.
pub fn ok() -> Self {
Expand Down Expand Up @@ -402,8 +407,10 @@ impl CircuitBuilder {
negated: HashMap::new(),
gates_optimized: 0,
gate_counter,
panic_gates: PanicResult::ok(),
panic_wires: HashMap::new(),
panic_gates: CachedPanicResult {
result: PanicResult::ok(),
cache: HashMap::new(),
},
consts,
}
}
Expand All @@ -418,12 +425,12 @@ impl CircuitBuilder {
// inputs (and their inputs, etc.) as used:
let shift = self.shift;
let mut output_gate_stack = output_gates.clone();
output_gate_stack.push(self.panic_gates.has_panicked);
output_gate_stack.extend(self.panic_gates.panic_type.iter());
output_gate_stack.extend(self.panic_gates.start_line.iter());
output_gate_stack.extend(self.panic_gates.start_column.iter());
output_gate_stack.extend(self.panic_gates.end_line.iter());
output_gate_stack.extend(self.panic_gates.end_column.iter());
output_gate_stack.push(self.panic_gates.result.has_panicked);
output_gate_stack.extend(self.panic_gates.result.panic_type.iter());
output_gate_stack.extend(self.panic_gates.result.start_line.iter());
output_gate_stack.extend(self.panic_gates.result.start_column.iter());
output_gate_stack.extend(self.panic_gates.result.end_line.iter());
output_gate_stack.extend(self.panic_gates.result.end_column.iter());
let mut used_gates = vec![false; self.gates.len()];
while let Some(gate_index) = output_gate_stack.pop() {
if gate_index >= shift {
Expand Down Expand Up @@ -466,21 +473,21 @@ impl CircuitBuilder {
*x = shift_gate_index_if_necessary(*x);
*y = shift_gate_index_if_necessary(*y);
}
self.panic_gates.has_panicked =
shift_gate_index_if_necessary(self.panic_gates.has_panicked);
for w in self.panic_gates.panic_type.iter_mut() {
self.panic_gates.result.has_panicked =
shift_gate_index_if_necessary(self.panic_gates.result.has_panicked);
for w in self.panic_gates.result.panic_type.iter_mut() {
*w = shift_gate_index_if_necessary(*w);
}
for w in self.panic_gates.start_line.iter_mut() {
for w in self.panic_gates.result.start_line.iter_mut() {
*w = shift_gate_index_if_necessary(*w);
}
for w in self.panic_gates.start_column.iter_mut() {
for w in self.panic_gates.result.start_column.iter_mut() {
*w = shift_gate_index_if_necessary(*w);
}
for w in self.panic_gates.end_line.iter_mut() {
for w in self.panic_gates.result.end_line.iter_mut() {
*w = shift_gate_index_if_necessary(*w);
}
for w in self.panic_gates.end_column.iter_mut() {
for w in self.panic_gates.result.end_column.iter_mut() {
*w = shift_gate_index_if_necessary(*w);
}
let mut without_unused_gates = Vec::with_capacity(self.gates.len() - unused_gates);
Expand Down Expand Up @@ -563,12 +570,22 @@ impl CircuitBuilder {
}
indexes
};
panic_and_output.push(shift_gate_index_if_necessary(self.panic_gates.has_panicked));
panic_and_output.extend(shift_indexes_if_necessary(self.panic_gates.panic_type));
panic_and_output.extend(shift_indexes_if_necessary(self.panic_gates.start_line));
panic_and_output.extend(shift_indexes_if_necessary(self.panic_gates.start_column));
panic_and_output.extend(shift_indexes_if_necessary(self.panic_gates.end_line));
panic_and_output.extend(shift_indexes_if_necessary(self.panic_gates.end_column));
panic_and_output.push(shift_gate_index_if_necessary(
self.panic_gates.result.has_panicked,
));
panic_and_output.extend(shift_indexes_if_necessary(
self.panic_gates.result.panic_type,
));
panic_and_output.extend(shift_indexes_if_necessary(
self.panic_gates.result.start_line,
));
panic_and_output.extend(shift_indexes_if_necessary(
self.panic_gates.result.start_column,
));
panic_and_output.extend(shift_indexes_if_necessary(self.panic_gates.result.end_line));
panic_and_output.extend(shift_indexes_if_necessary(
self.panic_gates.result.end_column,
));

panic_and_output.extend(output_gates.into_iter().map(shift_gate_index_if_necessary));

Expand All @@ -580,12 +597,13 @@ impl CircuitBuilder {
}

pub fn push_panic_if(&mut self, cond: GateIndex, reason: PanicReason, meta: MetaInfo) {
if let Some(existing_panic) = self.panic_wires.get(&cond) {
self.panic_gates = existing_panic.clone();
if let Some(existing_panic) = self.panic_gates.cache.get(&cond) {
self.panic_gates.result = existing_panic.clone();
return;
}
let already_panicked = self.panic_gates.has_panicked;
self.panic_gates.has_panicked = self.push_or(self.panic_gates.has_panicked, cond);
let already_panicked = self.panic_gates.result.has_panicked;
self.panic_gates.result.has_panicked =
self.push_or(self.panic_gates.result.has_panicked, cond);
let current = PanicResult {
has_panicked: 1,
panic_type: reason.as_bits(),
Expand All @@ -594,72 +612,102 @@ impl CircuitBuilder {
end_line: unsigned_as_usize_bits(meta.end.0 as u64),
end_column: unsigned_as_usize_bits(meta.end.1 as u64),
};
for i in 0..self.panic_gates.start_line.len() {
self.panic_gates.start_line[i] = self.push_mux(
for i in 0..self.panic_gates.result.start_line.len() {
self.panic_gates.result.start_line[i] = self.push_mux(
already_panicked,
self.panic_gates.start_line[i],
self.panic_gates.result.start_line[i],
current.start_line[i],
);
self.panic_gates.start_column[i] = self.push_mux(
self.panic_gates.result.start_column[i] = self.push_mux(
already_panicked,
self.panic_gates.start_column[i],
self.panic_gates.result.start_column[i],
current.start_column[i],
);
self.panic_gates.end_line[i] = self.push_mux(
self.panic_gates.result.end_line[i] = self.push_mux(
already_panicked,
self.panic_gates.end_line[i],
self.panic_gates.result.end_line[i],
current.end_line[i],
);
self.panic_gates.end_column[i] = self.push_mux(
self.panic_gates.result.end_column[i] = self.push_mux(
already_panicked,
self.panic_gates.end_column[i],
self.panic_gates.result.end_column[i],
current.end_column[i],
);
}
for i in 0..current.panic_type.len() {
self.panic_gates.panic_type[i] = self.push_mux(
self.panic_gates.result.panic_type[i] = self.push_mux(
already_panicked,
self.panic_gates.panic_type[i],
self.panic_gates.result.panic_type[i],
current.panic_type[i],
);
}
self.panic_wires.insert(cond, self.panic_gates.clone());
self.panic_gates
.cache
.insert(cond, self.panic_gates.result.clone());
}

pub fn peek_panic(&self) -> &PanicResult {
pub fn peek_panic(&self) -> &CachedPanicResult {
&self.panic_gates
}

pub fn replace_panic_with(&mut self, p: PanicResult) -> PanicResult {
pub fn replace_panic_with(&mut self, p: CachedPanicResult) -> CachedPanicResult {
std::mem::replace(&mut self.panic_gates, p)
}

pub fn mux_panic(
&mut self,
condition: GateIndex,
CachedPanicResult {
result: t,
cache: cache_t,
}: &CachedPanicResult,
CachedPanicResult {
result: f,
cache: cache_f,
}: &CachedPanicResult,
) -> CachedPanicResult {
let result = self.mux_uncached_panic(condition, t, f);
let mut cache = HashMap::new();
for k in cache_t.keys().chain(cache_f.keys()) {
match (cache_t.get(k), cache_f.get(k)) {
(None, None) => {}
(None, Some(result)) | (Some(result), None) => {
cache.insert(*k, result.clone());
}
(Some(t), Some(f)) => {
cache.insert(*k, self.mux_uncached_panic(condition, t, f));
}
}
}
CachedPanicResult { result, cache }
}

fn mux_uncached_panic(
&mut self,
condition: GateIndex,
t: &PanicResult,
f: &PanicResult,
) -> PanicResult {
let mut panic_gates = PanicResult::ok();
panic_gates.has_panicked = self.push_mux(condition, t.has_panicked, f.has_panicked);
let mut result = PanicResult::ok();
result.has_panicked = self.push_mux(condition, t.has_panicked, f.has_panicked);
for (i, (&if_true, &if_false)) in t.panic_type.iter().zip(f.panic_type.iter()).enumerate() {
panic_gates.panic_type[i] = self.push_mux(condition, if_true, if_false);
result.panic_type[i] = self.push_mux(condition, if_true, if_false);
}
for (i, (&if_true, &if_false)) in t.start_line.iter().zip(f.start_line.iter()).enumerate() {
panic_gates.start_line[i] = self.push_mux(condition, if_true, if_false);
result.start_line[i] = self.push_mux(condition, if_true, if_false);
}
for (i, (&if_true, &if_false)) in
t.start_column.iter().zip(f.start_column.iter()).enumerate()
{
panic_gates.start_column[i] = self.push_mux(condition, if_true, if_false);
result.start_column[i] = self.push_mux(condition, if_true, if_false);
}
for (i, (&if_true, &if_false)) in t.end_line.iter().zip(f.end_line.iter()).enumerate() {
panic_gates.end_line[i] = self.push_mux(condition, if_true, if_false);
result.end_line[i] = self.push_mux(condition, if_true, if_false);
}
for (i, (&if_true, &if_false)) in t.end_column.iter().zip(f.end_column.iter()).enumerate() {
panic_gates.end_column[i] = self.push_mux(condition, if_true, if_false);
result.end_column[i] = self.push_mux(condition, if_true, if_false);
}
panic_gates
result
}

pub fn mux_envs(
Expand Down
5 changes: 3 additions & 2 deletions src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
ConstExpr, ConstExprEnum, EnumDef, ExprEnum, Op, Pattern, PatternEnum, StmtEnum, StructDef,
Type, UnaryOp, VariantExprEnum,
},
circuit::{Circuit, CircuitBuilder, GateIndex, PanicReason, PanicResult, USIZE_BITS},
circuit::{Circuit, CircuitBuilder, GateIndex, PanicReason, USIZE_BITS},
env::Env,
literal::Literal,
token::{MetaInfo, SignedNumType, UnsignedNumType},
Expand Down Expand Up @@ -1057,12 +1057,13 @@ impl TypedExpr {
let mut muxed_ret_expr = vec![0; bits];
let mut muxed_panic = circuit.peek_panic().clone();
let mut muxed_env = env.clone();
let panic_before_match = circuit.peek_panic().clone();

for (pattern, ret_expr) in clauses {
let mut env = env.clone();
env.push();

circuit.replace_panic_with(PanicResult::ok());
circuit.replace_panic_with(panic_before_match.clone());

let is_match = pattern.compile(&expr, prg, &mut env, circuit);
let ret_expr = ret_expr.compile(prg, &mut env, circuit);
Expand Down
8 changes: 4 additions & 4 deletions tests/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,14 @@ pub fn main(b: bool, x: i32) -> bool {
#[test]
fn optimize_same_expr2() -> Result<(), String> {
let unoptimized = "
pub fn main(b: bool, x: i32) -> i32 {
if b { x * x } else { x * x }
pub fn main(x: i32) -> i32 {
(x * x) + (x * x)
}
";
let optimized = "
pub fn main(b: bool, x: i32) -> i32 {
pub fn main(x: i32) -> i32 {
let y = x * x;
if b { y } else { y }
y + y
}
";
let unoptimized = compile(unoptimized).map_err(|e| e.prettify(unoptimized))?;
Expand Down

0 comments on commit b41a2e8

Please sign in to comment.