Skip to content

Commit a0ae553

Browse files
authored
Field impl for refactored metal EC backend (#22)
* feat: conversion between bigint and arbitrary limb size * lint * test(bigint): adapt from https://github.com/geometryxyz/msl-secp256k1 * refactor: add overflow detection and correct suitable bigint val for each cases * test(field): adapt the ff tests from https://github.com/geometryxyz/msl-secp256k1 * test(field): add check to scalarfield modulus match * chore: correct the docs, use new refactored code for correct implementation * chore: update path for contants.metal
1 parent bf1636b commit a0ae553

File tree

12 files changed

+285
-16
lines changed

12 files changed

+285
-16
lines changed

mopro-msm/src/msm/metal/abstraction/limbs_conversion.rs

+39
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ pub trait ToLimbs {
1212
pub trait FromLimbs {
1313
fn from_u32_limbs(limbs: &[u32]) -> Self;
1414
fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self;
15+
fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self;
1516
fn from_u128(num: u128) -> Self;
1617
fn from_u32(num: u32) -> Self;
1718
}
@@ -77,6 +78,10 @@ impl ToLimbs for Fq {
7778
fn to_limbs(&self, num_limbs: usize, log_limb_size: u32) -> Vec<u32> {
7879
self.0.to_limbs(num_limbs, log_limb_size)
7980
}
81+
82+
fn to_limbs(&self, num_limbs: usize, log_limb_size: u32) -> Vec<u32> {
83+
self.0.to_limbs(num_limbs, log_limb_size)
84+
}
8085
}
8186

8287
impl FromLimbs for BigInteger256 {
@@ -129,6 +134,35 @@ impl FromLimbs for BigInteger256 {
129134

130135
BigInteger256::new(result)
131136
}
137+
138+
fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self {
139+
let mut result = [0u64; 4];
140+
let limb_size = log_limb_size as usize;
141+
let mut accumulated_bits = 0;
142+
let mut current_u64 = 0u64;
143+
let mut result_idx = 0;
144+
145+
for &limb in limbs {
146+
// Add the current limb at the appropriate position
147+
current_u64 |= (limb as u64) << accumulated_bits;
148+
accumulated_bits += limb_size;
149+
150+
// If we've accumulated 64 bits or more, store the result
151+
while accumulated_bits >= 64 && result_idx < 4 {
152+
result[result_idx] = current_u64;
153+
current_u64 = limb as u64 >> (limb_size - (accumulated_bits - 64));
154+
accumulated_bits -= 64;
155+
result_idx += 1;
156+
}
157+
}
158+
159+
// Handle any remaining bits
160+
if accumulated_bits > 0 && result_idx < 4 {
161+
result[result_idx] = current_u64;
162+
}
163+
164+
BigInteger256::new(result)
165+
}
132166
}
133167

134168
impl FromLimbs for Fq {
@@ -159,4 +193,9 @@ impl FromLimbs for Fq {
159193
let bigint = BigInteger256::from_limbs(limbs, log_limb_size);
160194
Fq::new(mont_reduction::raw_reduction(bigint))
161195
}
196+
197+
fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self {
198+
let bigint = BigInteger256::from_limbs(limbs, log_limb_size);
199+
Fq::new(mont_reduction::raw_reduction(bigint))
200+
}
162201
}

mopro-msm/src/msm/metal/shader/fields/fp_bn254.h.metal

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ namespace {
99
}
1010

1111
/* Constants for bn254 field operations
12-
* N: base field modulus
12+
* N: scalar field modulus
1313
* R_SQUARED: R^2 mod N
1414
* R_SUB_N: R - N
1515
* MU: Montgomery Multiplication Constant = -N^{-1} mod (2^32)
1616
*
17-
* For bn254, the modulus is "21888242871839275222246405745257275088696311157297823662689037894645226208583" [1, 2]
17+
* For bn254, the modulus is "21888242871839275222246405745257275088548364400416034343698204186575808495617" [1, 2]
1818
* We use 8 limbs of 32 bits unsigned integers to represent the constanst
1919
*
2020
* References:

mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// source: https://github.com/geometryxyz/msl-secp256k1
22

33
using namespace metal;
4-
#include "constants.metal"
4+
#include "../constants.metal"
55

66
struct BigInt {
77
array<uint, NUM_LIMBS> limbs;

mopro-msm/src/msm/metal_msm/shader/bigint/constants.metal

-8
This file was deleted.

mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_unsafe.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ pub fn test_bigint_add_unsafe() {
4141
let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor);
4242

4343
write_constants(
44-
"../mopro-msm/src/msm/metal_msm/shader/bigint",
44+
"../mopro-msm/src/msm/metal_msm/shader",
4545
num_limbs,
4646
log_limb_size,
4747
0,

mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_wide.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ pub fn test_bigint_add() {
4545
let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor);
4646

4747
write_constants(
48-
"../mopro-msm/src/msm/metal_msm/shader/bigint",
48+
"../mopro-msm/src/msm/metal_msm/shader",
4949
num_limbs,
5050
log_limb_size,
5151
0,
@@ -133,7 +133,7 @@ pub fn test_bigint_add_no_overflow() {
133133
let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor);
134134

135135
write_constants(
136-
"../mopro-msm/src/msm/metal_msm/shader/bigint",
136+
"../mopro-msm/src/msm/metal_msm/shader",
137137
num_limbs,
138138
log_limb_size,
139139
0,

mopro-msm/src/msm/metal_msm/tests/bigint/bigint_sub.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ pub fn test_bigint_sub() {
3535
let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor);
3636

3737
write_constants(
38-
"../mopro-msm/src/msm/metal_msm/shader/bigint",
38+
"../mopro-msm/src/msm/metal_msm/shader",
3939
num_limbs,
4040
log_limb_size,
4141
0,
@@ -129,7 +129,7 @@ fn test_bigint_sub_underflow() {
129129
let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor);
130130

131131
write_constants(
132-
"../mopro-msm/src/msm/metal_msm/shader/bigint",
132+
"../mopro-msm/src/msm/metal_msm/shader",
133133
num_limbs,
134134
log_limb_size,
135135
0,
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
#[cfg(test)]
12
pub mod bigint_add_unsafe;
3+
#[cfg(test)]
24
pub mod bigint_add_wide;
5+
#[cfg(test)]
36
pub mod bigint_sub;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
// adapted from: https://github.com/geometryxyz/msl-secp256k1
2+
3+
use crate::msm::metal::abstraction::limbs_conversion::{FromLimbs, ToLimbs};
4+
use crate::msm::metal_msm::host::gpu::{
5+
create_buffer, create_empty_buffer, get_default_device, read_buffer,
6+
};
7+
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants};
8+
use ark_bn254::Fr as ScalarField;
9+
use ark_ff::{BigInt, BigInteger, PrimeField};
10+
use metal::*;
11+
12+
#[test]
13+
#[serial_test::serial]
14+
pub fn test_ff_add() {
15+
let log_limb_size = 13;
16+
let num_limbs = 20;
17+
18+
// Scalar field modulus for bn254
19+
let p = BigInt::new([
20+
0x43E1F593F0000001,
21+
0x2833E84879B97091,
22+
0xB85045B68181585D,
23+
0x30644E72E131A029,
24+
]);
25+
assert!(p == ScalarField::MODULUS);
26+
27+
let a = BigInt::new([
28+
0x43E1F593F0000001,
29+
0x2833E84879B97091,
30+
0xB85045B68181585D,
31+
0x30644E72E131A028,
32+
]);
33+
let b = BigInt::new([
34+
0x43E1F593F0000001,
35+
0x2833E84879B97091,
36+
0xB85045B68181585D,
37+
0x30644E7200000000,
38+
]);
39+
40+
let device = get_default_device();
41+
let a_buf = create_buffer(&device, &a.to_limbs(num_limbs, log_limb_size));
42+
let b_buf = create_buffer(&device, &b.to_limbs(num_limbs, log_limb_size));
43+
let p_buf = create_buffer(&device, &p.to_limbs(num_limbs, log_limb_size));
44+
let result_buf = create_empty_buffer(&device, num_limbs);
45+
46+
// Perform (a + b) % p
47+
let mut expected = a.clone();
48+
expected.add_with_carry(&b);
49+
50+
// While result >= p, subtract p
51+
while expected >= p {
52+
expected.sub_with_borrow(&p);
53+
}
54+
let expected_limbs = expected.to_limbs(num_limbs, log_limb_size);
55+
56+
let command_queue = device.new_command_queue();
57+
let command_buffer = command_queue.new_command_buffer();
58+
59+
let compute_pass_descriptor = ComputePassDescriptor::new();
60+
let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor);
61+
62+
write_constants(
63+
"../mopro-msm/src/msm/metal_msm/shader",
64+
num_limbs,
65+
log_limb_size,
66+
0,
67+
0,
68+
);
69+
let library_path = compile_metal(
70+
"../mopro-msm/src/msm/metal_msm/shader/field",
71+
"ff_add.metal",
72+
);
73+
let library = device.new_library_with_file(library_path).unwrap();
74+
let kernel = library.get_function("run", None).unwrap();
75+
76+
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
77+
pipeline_state_descriptor.set_compute_function(Some(&kernel));
78+
79+
let pipeline_state = device
80+
.new_compute_pipeline_state_with_function(
81+
pipeline_state_descriptor.compute_function().unwrap(),
82+
)
83+
.unwrap();
84+
85+
encoder.set_compute_pipeline_state(&pipeline_state);
86+
encoder.set_buffer(0, Some(&a_buf), 0);
87+
encoder.set_buffer(1, Some(&b_buf), 0);
88+
encoder.set_buffer(2, Some(&p_buf), 0);
89+
encoder.set_buffer(3, Some(&result_buf), 0);
90+
91+
let thread_group_count = MTLSize {
92+
width: 1,
93+
height: 1,
94+
depth: 1,
95+
};
96+
97+
let thread_group_size = MTLSize {
98+
width: 1,
99+
height: 1,
100+
depth: 1,
101+
};
102+
103+
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
104+
encoder.end_encoding();
105+
106+
command_buffer.commit();
107+
command_buffer.wait_until_completed();
108+
109+
let result_limbs: Vec<u32> = read_buffer(&result_buf, num_limbs);
110+
let result = BigInt::from_limbs(&result_limbs, log_limb_size);
111+
112+
assert!(result == expected);
113+
assert!(result_limbs == expected_limbs);
114+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
// adapted from: https://github.com/geometryxyz/msl-secp256k1
2+
3+
use crate::msm::metal::abstraction::limbs_conversion::{FromLimbs, ToLimbs};
4+
use crate::msm::metal_msm::host::gpu::{
5+
create_buffer, create_empty_buffer, get_default_device, read_buffer,
6+
};
7+
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants};
8+
use ark_bn254::Fr as ScalarField;
9+
use ark_ff::{BigInt, BigInteger, PrimeField};
10+
use metal::*;
11+
12+
#[test]
13+
#[serial_test::serial]
14+
pub fn test_ff_sub() {
15+
let log_limb_size = 13;
16+
let num_limbs = 20;
17+
18+
// Scalar field modulus for bn254
19+
let p = BigInt::new([
20+
0x43E1F593F0000001,
21+
0x2833E84879B97091,
22+
0xB85045B68181585D,
23+
0x30644E72E131A029,
24+
]);
25+
assert!(p == ScalarField::MODULUS);
26+
27+
let a = BigInt::new([
28+
0x43E1F593F0000001,
29+
0x2833E84879B97091,
30+
0xB85045B68181585D,
31+
0x30644E72E131A028,
32+
]);
33+
let b = BigInt::new([
34+
0xAAAAAAAAF0000001,
35+
0x2833E84879B97091,
36+
0xB85045B68181585D,
37+
0x30644E7200000000,
38+
]);
39+
40+
let device = get_default_device();
41+
let a_buf = create_buffer(&device, &a.to_limbs(num_limbs, log_limb_size));
42+
let b_buf = create_buffer(&device, &b.to_limbs(num_limbs, log_limb_size));
43+
let p_buf = create_buffer(&device, &p.to_limbs(num_limbs, log_limb_size));
44+
let result_buf = create_empty_buffer(&device, num_limbs);
45+
46+
// Perform (a - b) % p
47+
let mut expected = a.clone();
48+
expected.sub_with_borrow(&b);
49+
50+
// If result is negative, add p until it's positive
51+
while expected < BigInt::zero() {
52+
expected.add_with_carry(&p);
53+
}
54+
let expected_limbs = expected.to_limbs(num_limbs, log_limb_size);
55+
56+
let command_queue = device.new_command_queue();
57+
let command_buffer = command_queue.new_command_buffer();
58+
59+
let compute_pass_descriptor = ComputePassDescriptor::new();
60+
let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor);
61+
62+
write_constants(
63+
"../mopro-msm/src/msm/metal_msm/shader",
64+
num_limbs,
65+
log_limb_size,
66+
0,
67+
0,
68+
);
69+
let library_path = compile_metal(
70+
"../mopro-msm/src/msm/metal_msm/shader/field",
71+
"ff_sub.metal",
72+
);
73+
let library = device.new_library_with_file(library_path).unwrap();
74+
let kernel = library.get_function("run", None).unwrap();
75+
76+
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
77+
pipeline_state_descriptor.set_compute_function(Some(&kernel));
78+
79+
let pipeline_state = device
80+
.new_compute_pipeline_state_with_function(
81+
pipeline_state_descriptor.compute_function().unwrap(),
82+
)
83+
.unwrap();
84+
85+
encoder.set_compute_pipeline_state(&pipeline_state);
86+
encoder.set_buffer(0, Some(&a_buf), 0);
87+
encoder.set_buffer(1, Some(&b_buf), 0);
88+
encoder.set_buffer(2, Some(&p_buf), 0);
89+
encoder.set_buffer(3, Some(&result_buf), 0);
90+
91+
let thread_group_count = MTLSize {
92+
width: 1,
93+
height: 1,
94+
depth: 1,
95+
};
96+
97+
let thread_group_size = MTLSize {
98+
width: 1,
99+
height: 1,
100+
depth: 1,
101+
};
102+
103+
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
104+
encoder.end_encoding();
105+
106+
command_buffer.commit();
107+
command_buffer.wait_until_completed();
108+
109+
let result_limbs: Vec<u32> = read_buffer(&result_buf, num_limbs);
110+
let result = BigInt::from_limbs(&result_limbs, log_limb_size);
111+
112+
assert!(result == expected);
113+
assert!(result_limbs == expected_limbs);
114+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#[cfg(test)]
2+
pub mod ff_add;
3+
#[cfg(test)]
4+
pub mod ff_sub;
+3
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1+
#[cfg(test)]
12
pub mod bigint;
3+
#[cfg(test)]
4+
pub mod field;

0 commit comments

Comments
 (0)