Skip to content

Commit

Permalink
refactor: integrate clk and ci registers to IO interaction trace eval…
Browse files Browse the repository at this point in the history
…uation
  • Loading branch information
zmalatrax committed Dec 16, 2024
1 parent 8e586c2 commit 0e602b3
Showing 1 changed file with 97 additions and 30 deletions.
127 changes: 97 additions & 30 deletions crates/brainfuck_prover/src/components/io/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::components::{
io::component::InteractionClaim, IoClaim, TraceColumn, TraceError, TraceEval,
};
use brainfuck_vm::{instruction::InstructionType, registers::Registers};
use num_traits::One;
use num_traits::{One, Zero};
use stwo_prover::{
constraint_framework::{
logup::{LogupTraceGenerator, LookupElements},
Expand Down Expand Up @@ -278,12 +278,17 @@ pub fn interaction_trace_evaluation(
let mut logup_gen = LogupTraceGenerator::new(log_size);
let mut col_gen = logup_gen.new_col();

let mv_col = &main_trace_eval[IoColumn::Io.index()].data;
for (vec_row, mv) in mv_col.iter().enumerate().take(1 << (log_size - LOG_N_LANES)) {
// We want to prove that the I/O table is a sublist (ordered set inclusion) of the Processor
// table.
let num = -PackedSecureField::one();
let denom: PackedSecureField = lookup_elements.combine(&[*mv]);
let clk_col = &main_trace_eval[IoColumn::Clk.index()].data;
let ci_col = &main_trace_eval[IoColumn::Ci.index()].data;
let mv_col = &main_trace_eval[IoColumn::Mv.index()].data;
for vec_row in 0..1 << (log_size - LOG_N_LANES) {
let clk = clk_col[vec_row];
let ci = ci_col[vec_row];
let mv = mv_col[vec_row];
// We want to prove that the I/O table is a sublist (ordered set inclusion)
// of the Processor table.
let num = if ci.is_zero() { PackedSecureField::zero() } else { -PackedSecureField::one() };
let denom: PackedSecureField = lookup_elements.combine(&[clk, ci, mv]);
col_gen.write_frac(vec_row, num, denom);
}

Expand All @@ -298,15 +303,21 @@ pub fn interaction_trace_evaluation(
mod tests {
use super::*;
use num_traits::{One, Zero};
use stwo_prover::core::channel::Blake2sChannel;

type TestIOTable = IOTable<10>;

#[test]
fn test_io_row_new() {
let row = IOTableRow::new(BaseField::zero(), BaseField::from(46), BaseField::from(91));
let expected_row =
IOTableRow { clk: BaseField::zero(), ci: BaseField::from(46), mv: BaseField::from(91) };
let row = IOTableRow::new(
BaseField::zero(),
InstructionType::PutChar.to_base_field(),
BaseField::from(91),
);
let expected_row = IOTableRow {
clk: BaseField::zero(),
ci: InstructionType::PutChar.to_base_field(),
mv: BaseField::from(91),
};
assert_eq!(row, expected_row);
}

Expand All @@ -320,9 +331,17 @@ mod tests {
fn test_table_add_row_from_register() {
let mut io_table = TestIOTable::new();
// Create a row to add to the table
let row = IOTableRow::new(BaseField::zero(), BaseField::from(46), BaseField::from(91));
let row = IOTableRow::new(
BaseField::zero(),
InstructionType::PutChar.to_base_field(),
BaseField::from(91),
);
// Add the row to the table
io_table.add_row_from_register(BaseField::zero(), BaseField::from(46), BaseField::from(91));
io_table.add_row_from_register(
BaseField::zero(),
InstructionType::PutChar.to_base_field(),
BaseField::from(91),
);
// Check that the table contains the added row
assert_eq!(io_table.table, vec![row], "Added row should match the expected row.");
}
Expand All @@ -331,7 +350,11 @@ mod tests {
fn test_table_add_row() {
let mut io_table = TestIOTable::new();
// Create a row to add to the table
let row = IOTableRow::new(BaseField::zero(), BaseField::from(46), BaseField::from(91));
let row = IOTableRow::new(
BaseField::zero(),
InstructionType::PutChar.to_base_field(),
BaseField::from(91),
);
// Add the row to the table
io_table.add_row(row.clone());
// Check that the table contains the added row
Expand All @@ -343,9 +366,21 @@ mod tests {
let mut io_table = TestIOTable::new();
// Create a vector of rows to add to the table
let rows = vec![
IOTableRow::new(BaseField::zero(), BaseField::from(46), BaseField::from(91)),
IOTableRow::new(BaseField::one(), BaseField::from(46), BaseField::from(9)),
IOTableRow::new(BaseField::from(4), BaseField::from(46), BaseField::from(43)),
IOTableRow::new(
BaseField::zero(),
InstructionType::PutChar.to_base_field(),
BaseField::from(91),
),
IOTableRow::new(
BaseField::one(),
InstructionType::PutChar.to_base_field(),
BaseField::from(9),
),
IOTableRow::new(
BaseField::from(4),
InstructionType::PutChar.to_base_field(),
BaseField::from(43),
),
];
// Add the rows to the table
io_table.add_rows(rows.clone());
Expand Down Expand Up @@ -392,7 +427,11 @@ mod tests {
};
let registers: Vec<Registers> = vec![reg3, reg1, reg2];

let row = IOTableRow::new(BaseField::zero(), BaseField::from(46), BaseField::from(5));
let row = IOTableRow::new(
BaseField::zero(),
InstructionType::PutChar.to_base_field(),
BaseField::from(5),
);

let mut expected_io_table: OutputTable = IOTable::new();
expected_io_table.add_row(row);
Expand Down Expand Up @@ -520,8 +559,21 @@ mod tests {
fn test_trace_evaluation_circle_domain() {
let mut io_table = TestIOTable::new();
io_table.add_rows(vec![
IOTableRow::new(BaseField::zero(), BaseField::from(44), BaseField::one()),
IOTableRow::new(BaseField::one(), BaseField::from(44), BaseField::from(2)),
IOTableRow::new(
BaseField::zero(),
InstructionType::ReadChar.to_base_field(),
BaseField::one(),
),
IOTableRow::new(
BaseField::one(),
InstructionType::ReadChar.to_base_field(),
BaseField::from(2),
),
IOTableRow::new(
BaseField::from(3),
InstructionType::ReadChar.to_base_field(),
BaseField::from(7),
),
]);

let (trace, claim) = io_table.trace_evaluation();
Expand All @@ -542,20 +594,31 @@ mod tests {
// Trace rows are:
// - Real row
// - Real row
// - Dummy row (padding to the power of 2)
// - Real row
// - Dummy row (padding to the power of 2)
let rows = vec![
IOTableRow::new(BaseField::one()),
IOTableRow::new(BaseField::from(2)),
IOTableRow::new(BaseField::zero()),
IOTableRow::new(BaseField::zero()),
IOTableRow::new(
BaseField::zero(),
InstructionType::ReadChar.to_base_field(),
BaseField::one(),
),
IOTableRow::new(
BaseField::one(),
InstructionType::ReadChar.to_base_field(),
BaseField::from(2),
),
IOTableRow::new(
BaseField::from(3),
InstructionType::ReadChar.to_base_field(),
BaseField::from(7),
),
IOTableRow::new(BaseField::zero(), BaseField::zero(), BaseField::zero()),
];
io_table.add_rows(rows);

let (trace_eval, claim) = io_table.trace_evaluation();

let channel = &mut Blake2sChannel::default();
let lookup_elements = IoElements::draw(channel);
let lookup_elements = IoElements::dummy();

let (interaction_trace_eval, interaction_claim) =
interaction_trace_evaluation(&trace_eval, &lookup_elements).unwrap();
Expand All @@ -565,19 +628,23 @@ mod tests {
let mut col_gen = logup_gen.new_col();

let mut denoms = [PackedSecureField::zero(); 4];
let mv_col = &trace_eval[IoColumn::Io.index()].data;
let clk_col = &trace_eval[IoColumn::Clk.index()].data;
let ci_col = &trace_eval[IoColumn::Ci.index()].data;
let mv_col = &trace_eval[IoColumn::Mv.index()].data;
// Construct the denominator for each row of the logUp column, from the main trace
// evaluation.
for vec_row in 0..1 << (log_size - LOG_N_LANES) {
let clk = clk_col[vec_row];
let ci = ci_col[vec_row];
let mv = mv_col[vec_row];
let denom: PackedSecureField = lookup_elements.combine(&[mv]);
let denom: PackedSecureField = lookup_elements.combine(&[clk, ci, mv]);
denoms[vec_row] = denom;
}

let num_0 = -PackedSecureField::one();
let num_1 = -PackedSecureField::one();
let num_2 = -PackedSecureField::one();
let num_3 = -PackedSecureField::one();
let num_3 = PackedSecureField::zero();

col_gen.write_frac(0, num_0, denoms[0]);
col_gen.write_frac(1, num_1, denoms[1]);
Expand Down

0 comments on commit 0e602b3

Please sign in to comment.