Skip to content

Commit

Permalink
Allow compilation without panic circuits
Browse files Browse the repository at this point in the history
  • Loading branch information
fkettelhoit committed Dec 2, 2024
1 parent 0522b1c commit 1bf95b7
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 3 deletions.
11 changes: 10 additions & 1 deletion src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ pub(crate) struct CircuitBuilder {
gate_counter: usize,
panic_gates: PanicResult,
consts: HashMap<String, usize>,
panic_enabled: bool,
}

pub(crate) const USIZE_BITS: usize = 32;
Expand Down Expand Up @@ -388,7 +389,11 @@ impl PanicReason {
}

impl CircuitBuilder {
pub fn new(input_gates: Vec<usize>, consts: HashMap<String, usize>) -> Self {
pub fn new(
input_gates: Vec<usize>,
consts: HashMap<String, usize>,
panic_enabled: bool,
) -> Self {
let mut gate_counter = 2; // for const true and false
for input_gates_of_party in input_gates.iter() {
gate_counter += input_gates_of_party;
Expand All @@ -403,6 +408,7 @@ impl CircuitBuilder {
gate_counter,
panic_gates: PanicResult::ok(),
consts,
panic_enabled,
}
}

Expand Down Expand Up @@ -578,6 +584,9 @@ impl CircuitBuilder {
}

pub fn push_panic_if(&mut self, cond: GateIndex, reason: PanicReason, meta: MetaInfo) {
if !self.panic_enabled {
return;
}
let already_panicked = self.panic_gates.has_panicked;
self.panic_gates.has_panicked = self.push_or(self.panic_gates.has_panicked, cond);
let current = PanicResult {
Expand Down
35 changes: 34 additions & 1 deletion src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ impl TypedProgram {
.map(|(c, f, _)| (c, f))
}

/// Compiles the (type-checked) program, _silently ignoring panics_.
///
/// Assumes that the input program has been correctly type-checked and **panics** if
/// incompatible types are found that should have been caught by the type-checker.
pub fn compile_ignore_panic(
&self,
fn_name: &str,
) -> Result<(Circuit, &TypedFnDef), Vec<CompilerError>> {
self.compile_with_constants_ignore_panic(fn_name, HashMap::new())
.map(|(c, f, _)| (c, f))
}

/// Compiles the (type-checked) program with provided constants, producing a circuit of gates.
///
/// Assumes that the input program has been correctly type-checked and **panics** if
Expand All @@ -94,6 +106,27 @@ impl TypedProgram {
&self,
fn_name: &str,
consts: HashMap<String, HashMap<String, Literal>>,
) -> Result<CompiledProgram, Vec<CompilerError>> {
self.comp_with_constants(fn_name, consts, true)
}

/// Compiles the (type-checked) program with provided constants, _silently ignoring panics_.
///
/// Assumes that the input program has been correctly type-checked and **panics** if
/// incompatible types are found that should have been caught by the type-checker.
pub fn compile_with_constants_ignore_panic(
&self,
fn_name: &str,
consts: HashMap<String, HashMap<String, Literal>>,
) -> Result<CompiledProgram, Vec<CompilerError>> {
self.comp_with_constants(fn_name, consts, false)
}

fn comp_with_constants(
&self,
fn_name: &str,
consts: HashMap<String, HashMap<String, Literal>>,
panic_enabled: bool,
) -> Result<CompiledProgram, Vec<CompilerError>> {
let mut env = Env::new();
let mut const_sizes = HashMap::new();
Expand Down Expand Up @@ -249,7 +282,7 @@ impl TypedProgram {
input_gates.push(type_size);
env.let_in_current_scope(param.name.clone(), wires);
}
let mut circuit = CircuitBuilder::new(input_gates, const_sizes.clone());
let mut circuit = CircuitBuilder::new(input_gates, const_sizes.clone(), panic_enabled);
for (const_name, const_def) in self.const_defs.iter() {
let ConstExpr(expr, _) = &const_def.value;
match expr {
Expand Down
32 changes: 32 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,20 @@ pub fn compile(prg: &str) -> Result<GarbleProgram, Error> {
})
}

/// Scans, parses, type-checks and then compiles the `"main"` fn of a program to a boolean circuit.
pub fn compile_ignore_panic(prg: &str) -> Result<GarbleProgram, Error> {
let program = check(prg)?;
let (circuit, main) = program.compile_ignore_panic("main")?;
let main = main.clone();
Ok(GarbleProgram {
program,
main,
circuit,
consts: HashMap::new(),
const_sizes: HashMap::new(),
})
}

/// Scans, parses, type-checks and then compiles the `"main"` fn of a program to a boolean circuit.
pub fn compile_with_constants(
prg: &str,
Expand All @@ -127,6 +141,24 @@ pub fn compile_with_constants(
})
}

/// Scans, parses, type-checks and then compiles the `"main"` fn of a program to a boolean circuit.
pub fn compile_with_constants_ignore_panic(
prg: &str,
consts: HashMap<String, HashMap<String, Literal>>,
) -> Result<GarbleProgram, Error> {
let program = check(prg)?;
let (circuit, main, const_sizes) =
program.compile_with_constants_ignore_panic("main", consts.clone())?;
let main = main.clone();
Ok(GarbleProgram {
program,
main,
circuit,
consts,
const_sizes,
})
}

/// The result of type-checking and compiling a Garble program.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
Expand Down
64 changes: 63 additions & 1 deletion tests/circuit.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use garble_lang::compile;
use garble_lang::{compile, compile_ignore_panic};

#[test]
fn optimize_or() -> Result<(), String> {
Expand Down Expand Up @@ -66,6 +66,68 @@ pub fn main(b: bool, x: i32) -> bool {
Ok(())
}

#[test]
fn optimize_same_expr2() -> Result<(), String> {
let unoptimized = "
pub fn main(b: bool, x: i32) -> i32 {
if b { x * x } else { x * x }
}
";
let optimized = "
pub fn main(b: bool, x: i32) -> i32 {
let y = x * x;
if b { y } else { y }
}
";
let unoptimized = compile_ignore_panic(unoptimized).map_err(|e| e.prettify(unoptimized))?;
let optimized = compile_ignore_panic(optimized).map_err(|e| e.prettify(optimized))?;
assert_eq!(
unoptimized.circuit.gates.len(),
optimized.circuit.gates.len()
);
Ok(())
}

#[test]
fn optimize_same_expr3() -> Result<(), String> {
let unoptimized = "
pub fn main(input1: i8, input2: i8) -> bool {
let _unused = add(input1, input2);
square(input1) < input2 || square(input1) > input2
}
fn square(num: i8) -> i8 {
num * num
}
fn add(a: i8, b: i8) -> i8 {
a + b
}
";
let optimized = "
pub fn main(input1: i8, input2: i8) -> bool {
let _unused = add(input1, input2);
let squared = square(input1);
squared < input2 || squared > input2
}
fn square(num: i8) -> i8 {
num * num
}
fn add(a: i8, b: i8) -> i8 {
a + b
}
";
let unoptimized = compile_ignore_panic(unoptimized).map_err(|e| e.prettify(unoptimized))?;
let optimized = compile_ignore_panic(optimized).map_err(|e| e.prettify(optimized))?;
assert_eq!(
unoptimized.circuit.gates.len(),
optimized.circuit.gates.len()
);
Ok(())
}

#[test]
fn optimize_not_equivalence() -> Result<(), String> {
let unoptimized = "
Expand Down

0 comments on commit 1bf95b7

Please sign in to comment.