Skip to content

Commit bf1636b

Browse files
authored
Bigint impl for refactored metal EC backend (#21)
* feat: import mont_mul backend from https://github.com/geometryxyz/msl-secp256k1 * chore: add refactored metal_msm, will remove the previous one once this is completed * chore: migrate prev utils mod to refactored metal msm * feat: conversion between bigint and arbitrary limb size * lint * chore: ignore all metal ir and lib * test(bigint): add host test * test(bigint): adapt from https://github.com/geometryxyz/msl-secp256k1 * chore: ignore all constants file since it's been automatically generated * refactor: add overflow detection and correct suitable bigint val for each cases
1 parent a553c49 commit bf1636b

39 files changed

+2702
-1
lines changed

Cargo.lock

+175
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

mopro-msm/.gitignore

+8-1
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,11 @@ Cargo.lock
1717
src/middleware/gpu_explorations/utils/vectors
1818

1919
# GPU exploration - proptest generated files
20-
proptest-regressions
20+
proptest-regressions
21+
22+
# Metal shader intermediate files and libraries
23+
src/msm/metal_msm/shader/**/*.ir
24+
src/msm/metal_msm/shader/**/*.lib
25+
26+
# Metal shader constants file
27+
src/msm/metal_msm/shader/**/constants.metal

mopro-msm/Cargo.toml

+3
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ serde_derive = "1.0"
5858
wasmer = { git = "https://github.com/oskarth/wasmer.git", rev = "09c7070" }
5959
witness = { git = "https://github.com/philsippl/circom-witness-rs.git" }
6060

61+
[dev-dependencies]
62+
serial_test = "3.0.0"
63+
6164
# [dependencies.rayon]
6265
# version = "1"
6366
# optional=false

mopro-msm/src/msm/metal/abstraction/limbs_conversion.rs

+71
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@ use crate::msm::metal::abstraction::mont_reduction;
66
// implement to_u32_limbs and from_u32_limbs for BigInt<4>
77
pub trait ToLimbs {
88
fn to_u32_limbs(&self) -> Vec<u32>;
9+
fn to_limbs(&self, num_limbs: usize, log_limb_size: u32) -> Vec<u32>;
910
}
1011

1112
pub trait FromLimbs {
1213
fn from_u32_limbs(limbs: &[u32]) -> Self;
14+
fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self;
1315
fn from_u128(num: u128) -> Self;
1416
fn from_u32(num: u32) -> Self;
1517
}
@@ -26,6 +28,37 @@ impl ToLimbs for BigInteger256 {
2628
});
2729
limbs
2830
}
31+
32+
fn to_limbs(&self, num_limbs: usize, log_limb_size: u32) -> Vec<u32> {
33+
let mut result = vec![0u32; num_limbs];
34+
let limb_size = 1u32 << log_limb_size;
35+
let mask = limb_size - 1;
36+
37+
// Convert to little-endian representation
38+
let bytes = self.to_bytes_le();
39+
let mut val = 0u32;
40+
let mut bits = 0u32;
41+
let mut limb_idx = 0;
42+
43+
for &byte in bytes.iter() {
44+
val |= (byte as u32) << bits;
45+
bits += 8;
46+
47+
while bits >= log_limb_size && limb_idx < num_limbs {
48+
result[limb_idx] = val & mask;
49+
val >>= log_limb_size;
50+
bits -= log_limb_size;
51+
limb_idx += 1;
52+
}
53+
}
54+
55+
// Handle any remaining bits
56+
if bits > 0 && limb_idx < num_limbs {
57+
result[limb_idx] = val;
58+
}
59+
60+
result
61+
}
2962
}
3063

3164
// convert from little endian to big endian
@@ -40,6 +73,10 @@ impl ToLimbs for Fq {
4073
});
4174
limbs
4275
}
76+
77+
fn to_limbs(&self, num_limbs: usize, log_limb_size: u32) -> Vec<u32> {
78+
self.0.to_limbs(num_limbs, log_limb_size)
79+
}
4380
}
4481

4582
impl FromLimbs for BigInteger256 {
@@ -63,6 +100,35 @@ impl FromLimbs for BigInteger256 {
63100
fn from_u32(num: u32) -> Self {
64101
BigInteger256::new([num as u64, 0, 0, 0])
65102
}
103+
104+
fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self {
105+
let mut result = [0u64; 4];
106+
let limb_size = log_limb_size as usize;
107+
let mut accumulated_bits = 0;
108+
let mut current_u64 = 0u64;
109+
let mut result_idx = 0;
110+
111+
for &limb in limbs {
112+
// Add the current limb at the appropriate position
113+
current_u64 |= (limb as u64) << accumulated_bits;
114+
accumulated_bits += limb_size;
115+
116+
// If we've accumulated 64 bits or more, store the result
117+
while accumulated_bits >= 64 && result_idx < 4 {
118+
result[result_idx] = current_u64;
119+
current_u64 = limb as u64 >> (limb_size - (accumulated_bits - 64));
120+
accumulated_bits -= 64;
121+
result_idx += 1;
122+
}
123+
}
124+
125+
// Handle any remaining bits
126+
if accumulated_bits > 0 && result_idx < 4 {
127+
result[result_idx] = current_u64;
128+
}
129+
130+
BigInteger256::new(result)
131+
}
66132
}
67133

68134
impl FromLimbs for Fq {
@@ -88,4 +154,9 @@ impl FromLimbs for Fq {
88154
num as u64, 0, 0, 0,
89155
])))
90156
}
157+
158+
fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self {
159+
let bigint = BigInteger256::from_limbs(limbs, log_limb_size);
160+
Fq::new(mont_reduction::raw_reduction(bigint))
161+
}
91162
}
+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
use thiserror::Error;
2+
3+
#[derive(Debug, Error)]
4+
pub enum MetalError {
5+
#[error("Couldn't find a system default device for Metal")]
6+
DeviceNotFound(),
7+
#[error("Couldn't create a new Metal library: {0}")]
8+
LibraryError(String),
9+
#[error("Couldn't create a new Metal function object: {0}")]
10+
FunctionError(String),
11+
#[error("Couldn't create a new Metal compute pipeline: {0}")]
12+
PipelineError(String),
13+
#[error("Could not calculate {1} root of unity")]
14+
RootOfUnityError(String, u64),
15+
// #[error("Input length is {0}, which is not a power of two")]
16+
// InputError(usize),
17+
#[error("Invalid input: {0}")]
18+
InputError(String),
19+
}

0 commit comments

Comments
 (0)