Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: I/O interaction trace evaluation #130

Merged
merged 5 commits into from
Dec 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 194 additions & 15 deletions crates/brainfuck_prover/src/components/io/table.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
use crate::components::{IoClaim, TraceColumn, TraceEval};
use crate::components::{
io::component::InteractionClaim, IoClaim, TraceColumn, TraceError, TraceEval,
};
use brainfuck_vm::{instruction::InstructionType, registers::Registers};
use num_traits::{One, Zero};
use stwo_prover::{
constraint_framework::{logup::LookupElements, Relation, RelationEFTraitBound},
constraint_framework::{
logup::{LogupTraceGenerator, LookupElements},
Relation, RelationEFTraitBound,
},
core::{
backend::{
simd::{column::BaseColumn, m31::LOG_N_LANES},
simd::{column::BaseColumn, m31::LOG_N_LANES, qm31::PackedSecureField},
Column,
},
channel::Channel,
Expand Down Expand Up @@ -248,6 +254,51 @@ impl<F: Clone, EF: RelationEFTraitBound<F>> Relation<F, EF> for IoElements {
}
}

/// Creates the interaction trace from the main trace evaluation
/// and the interaction elements for the I/O components.
///
/// The Processor component uses the other components:
/// The Processor component multiplicities are then positive,
/// and the I/O components' multiplicities are negative
/// in the logUp protocol.
///
/// # Returns
/// - Interaction trace evaluation, to be committed.
/// - Interaction claim: the total sum from the logUp protocol,
/// to be mixed into the Fiat-Shamir [`Channel`].
#[allow(clippy::similar_names)]
pub fn interaction_trace_evaluation(
main_trace_eval: &TraceEval,
lookup_elements: &IoElements,
) -> Result<(TraceEval, InteractionClaim), TraceError> {
// If the main trace of the I/O components is empty, then we claimed that it's log size is zero.
let log_size =
if main_trace_eval.is_empty() { 0 } else { main_trace_eval[0].domain.log_size() };

let mut logup_gen = LogupTraceGenerator::new(log_size);
let mut col_gen = logup_gen.new_col();

let clk_col = &main_trace_eval[IoColumn::Clk.index()].data;
let ci_col = &main_trace_eval[IoColumn::Ci.index()].data;
let mv_col = &main_trace_eval[IoColumn::Mv.index()].data;
for vec_row in 0..1 << (log_size - LOG_N_LANES) {
let clk = clk_col[vec_row];
let ci = ci_col[vec_row];
let mv = mv_col[vec_row];
// We want to prove that the I/O table is a sublist (ordered set inclusion)
// of the Processor table.
let num = if ci.is_zero() { PackedSecureField::zero() } else { -PackedSecureField::one() };
let denom: PackedSecureField = lookup_elements.combine(&[clk, ci, mv]);
col_gen.write_frac(vec_row, num, denom);
}

col_gen.finalize_col();

let (trace, claimed_sum) = logup_gen.finalize_last();

Ok((trace, InteractionClaim { claimed_sum }))
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -257,9 +308,16 @@ mod tests {

#[test]
fn test_io_row_new() {
let row = IOTableRow::new(BaseField::zero(), BaseField::from(46), BaseField::from(91));
let expected_row =
IOTableRow { clk: BaseField::zero(), ci: BaseField::from(46), mv: BaseField::from(91) };
let row = IOTableRow::new(
BaseField::zero(),
InstructionType::PutChar.to_base_field(),
BaseField::from(91),
);
let expected_row = IOTableRow {
clk: BaseField::zero(),
ci: InstructionType::PutChar.to_base_field(),
mv: BaseField::from(91),
};
assert_eq!(row, expected_row);
}

Expand All @@ -273,9 +331,17 @@ mod tests {
fn test_table_add_row_from_register() {
let mut io_table = TestIOTable::new();
// Create a row to add to the table
let row = IOTableRow::new(BaseField::zero(), BaseField::from(46), BaseField::from(91));
let row = IOTableRow::new(
BaseField::zero(),
InstructionType::PutChar.to_base_field(),
BaseField::from(91),
);
// Add the row to the table
io_table.add_row_from_register(BaseField::zero(), BaseField::from(46), BaseField::from(91));
io_table.add_row_from_register(
BaseField::zero(),
InstructionType::PutChar.to_base_field(),
BaseField::from(91),
);
// Check that the table contains the added row
assert_eq!(io_table.table, vec![row], "Added row should match the expected row.");
}
Expand All @@ -284,7 +350,11 @@ mod tests {
fn test_table_add_row() {
let mut io_table = TestIOTable::new();
// Create a row to add to the table
let row = IOTableRow::new(BaseField::zero(), BaseField::from(46), BaseField::from(91));
let row = IOTableRow::new(
BaseField::zero(),
InstructionType::PutChar.to_base_field(),
BaseField::from(91),
);
// Add the row to the table
io_table.add_row(row.clone());
// Check that the table contains the added row
Expand All @@ -296,9 +366,21 @@ mod tests {
let mut io_table = TestIOTable::new();
// Create a vector of rows to add to the table
let rows = vec![
IOTableRow::new(BaseField::zero(), BaseField::from(46), BaseField::from(91)),
IOTableRow::new(BaseField::one(), BaseField::from(46), BaseField::from(9)),
IOTableRow::new(BaseField::from(4), BaseField::from(46), BaseField::from(43)),
IOTableRow::new(
BaseField::zero(),
InstructionType::PutChar.to_base_field(),
BaseField::from(91),
),
IOTableRow::new(
BaseField::one(),
InstructionType::PutChar.to_base_field(),
BaseField::from(9),
),
IOTableRow::new(
BaseField::from(4),
InstructionType::PutChar.to_base_field(),
BaseField::from(43),
),
];
// Add the rows to the table
io_table.add_rows(rows.clone());
Expand Down Expand Up @@ -345,7 +427,11 @@ mod tests {
};
let registers: Vec<Registers> = vec![reg3, reg1, reg2];

let row = IOTableRow::new(BaseField::zero(), BaseField::from(46), BaseField::from(5));
let row = IOTableRow::new(
BaseField::zero(),
InstructionType::PutChar.to_base_field(),
BaseField::from(5),
);

let mut expected_io_table: OutputTable = IOTable::new();
expected_io_table.add_row(row);
Expand Down Expand Up @@ -473,8 +559,21 @@ mod tests {
fn test_trace_evaluation_circle_domain() {
let mut io_table = TestIOTable::new();
io_table.add_rows(vec![
IOTableRow::new(BaseField::zero(), BaseField::from(44), BaseField::one()),
IOTableRow::new(BaseField::one(), BaseField::from(44), BaseField::from(2)),
IOTableRow::new(
BaseField::zero(),
InstructionType::ReadChar.to_base_field(),
BaseField::one(),
),
IOTableRow::new(
BaseField::one(),
InstructionType::ReadChar.to_base_field(),
BaseField::from(2),
),
IOTableRow::new(
BaseField::from(3),
InstructionType::ReadChar.to_base_field(),
BaseField::from(7),
),
]);

let (trace, claim) = io_table.trace_evaluation();
Expand All @@ -488,4 +587,84 @@ mod tests {
);
}
}

#[test]
fn test_interaction_trace_evaluation() {
let mut io_table = TestIOTable::new();
// Trace rows are:
// - Real row
// - Real row
// - Real row
// - Dummy row (padding to the power of 2)
let rows = vec![
IOTableRow::new(
BaseField::zero(),
InstructionType::ReadChar.to_base_field(),
BaseField::one(),
),
IOTableRow::new(
BaseField::one(),
InstructionType::ReadChar.to_base_field(),
BaseField::from(2),
),
IOTableRow::new(
BaseField::from(3),
InstructionType::ReadChar.to_base_field(),
BaseField::from(7),
),
IOTableRow::new(BaseField::zero(), BaseField::zero(), BaseField::zero()),
];
io_table.add_rows(rows);

let (trace_eval, claim) = io_table.trace_evaluation();

let lookup_elements = IoElements::dummy();

let (interaction_trace_eval, interaction_claim) =
interaction_trace_evaluation(&trace_eval, &lookup_elements).unwrap();

let log_size = trace_eval[0].domain.log_size();
let mut logup_gen = LogupTraceGenerator::new(log_size);
let mut col_gen = logup_gen.new_col();

let mut denoms = [PackedSecureField::zero(); 4];
let clk_col = &trace_eval[IoColumn::Clk.index()].data;
let ci_col = &trace_eval[IoColumn::Ci.index()].data;
let mv_col = &trace_eval[IoColumn::Mv.index()].data;
// Construct the denominator for each row of the logUp column, from the main trace
// evaluation.
for vec_row in 0..1 << (log_size - LOG_N_LANES) {
let clk = clk_col[vec_row];
let ci = ci_col[vec_row];
let mv = mv_col[vec_row];
let denom: PackedSecureField = lookup_elements.combine(&[clk, ci, mv]);
denoms[vec_row] = denom;
}

let num_0 = -PackedSecureField::one();
let num_1 = -PackedSecureField::one();
let num_2 = -PackedSecureField::one();
let num_3 = PackedSecureField::zero();

col_gen.write_frac(0, num_0, denoms[0]);
col_gen.write_frac(1, num_1, denoms[1]);
col_gen.write_frac(2, num_2, denoms[2]);
col_gen.write_frac(3, num_3, denoms[3]);

col_gen.finalize_col();
let (expected_interaction_trace_eval, expected_claimed_sum) = logup_gen.finalize_last();

assert_eq!(claim.log_size, log_size,);
for col_index in 0..expected_interaction_trace_eval.len() {
assert_eq!(
interaction_trace_eval[col_index].domain,
expected_interaction_trace_eval[col_index].domain
);
assert_eq!(
interaction_trace_eval[col_index].to_cpu().values,
expected_interaction_trace_eval[col_index].to_cpu().values
);
}
assert_eq!(interaction_claim.claimed_sum, expected_claimed_sum);
}
}
Loading