Skip to content

Commit e632e55

Browse files
authored
Montgomery Multiplication for refactored metal EC backend (#23)
* feat: import mont_mul backend from https://github.com/geometryxyz/msl-secp256k1 * feat: conversion between bigint and arbitrary limb size * test(bigint): add host test * test(bigint): adapt from https://github.com/geometryxyz/msl-secp256k1 * refactor: add overflow detection and correct suitable bigint val for each case * chore: update path for constants.metal * test(mont_mul): adapted mont mul tests from https://github.com/geometryxyz/msl-secp256k1 * feat(mont_mul): adapted utils function related to mont_mul from https://github.com/geometers/multiprecision * feat: add limbs conversion from ark_ff bigint to arbitrary limbs in Vec<u32> * chore: update path for utils for limb conversion * fix(mont_mul): correct the conversion from arkworks' scalarfield to arbitrary limbs of Vec<u32> * test(mont_mul): adapted benchmark functions from https://github.com/geometryxyz/msl-secp256k1 * feat(metal): add mont_mul cios * test(mont_mul): add cios mont_mul test and benchmark
1 parent a0ae553 commit e632e55

22 files changed

+1197
-74
lines changed

Cargo.lock

+136-29
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

mopro-msm/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ witness = { git = "https://github.com/philsippl/circom-witness-rs.git" }
6060

6161
[dev-dependencies]
6262
serial_test = "3.0.0"
63+
stopwatch = "0.0.7"
6364

6465
# [dependencies.rayon]
6566
# version = "1"

mopro-msm/src/msm/metal/abstraction/limbs_conversion.rs

-39
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ pub trait ToLimbs {
1212
pub trait FromLimbs {
1313
fn from_u32_limbs(limbs: &[u32]) -> Self;
1414
fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self;
15-
fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self;
1615
fn from_u128(num: u128) -> Self;
1716
fn from_u32(num: u32) -> Self;
1817
}
@@ -78,10 +77,6 @@ impl ToLimbs for Fq {
7877
fn to_limbs(&self, num_limbs: usize, log_limb_size: u32) -> Vec<u32> {
7978
self.0.to_limbs(num_limbs, log_limb_size)
8079
}
81-
82-
fn to_limbs(&self, num_limbs: usize, log_limb_size: u32) -> Vec<u32> {
83-
self.0.to_limbs(num_limbs, log_limb_size)
84-
}
8580
}
8681

8782
impl FromLimbs for BigInteger256 {
@@ -134,35 +129,6 @@ impl FromLimbs for BigInteger256 {
134129

135130
BigInteger256::new(result)
136131
}
137-
138-
fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self {
139-
let mut result = [0u64; 4];
140-
let limb_size = log_limb_size as usize;
141-
let mut accumulated_bits = 0;
142-
let mut current_u64 = 0u64;
143-
let mut result_idx = 0;
144-
145-
for &limb in limbs {
146-
// Add the current limb at the appropriate position
147-
current_u64 |= (limb as u64) << accumulated_bits;
148-
accumulated_bits += limb_size;
149-
150-
// If we've accumulated 64 bits or more, store the result
151-
while accumulated_bits >= 64 && result_idx < 4 {
152-
result[result_idx] = current_u64;
153-
current_u64 = limb as u64 >> (limb_size - (accumulated_bits - 64));
154-
accumulated_bits -= 64;
155-
result_idx += 1;
156-
}
157-
}
158-
159-
// Handle any remaining bits
160-
if accumulated_bits > 0 && result_idx < 4 {
161-
result[result_idx] = current_u64;
162-
}
163-
164-
BigInteger256::new(result)
165-
}
166132
}
167133

168134
impl FromLimbs for Fq {
@@ -193,9 +159,4 @@ impl FromLimbs for Fq {
193159
let bigint = BigInteger256::from_limbs(limbs, log_limb_size);
194160
Fq::new(mont_reduction::raw_reduction(bigint))
195161
}
196-
197-
fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self {
198-
let bigint = BigInteger256::from_limbs(limbs, log_limb_size);
199-
Fq::new(mont_reduction::raw_reduction(bigint))
200-
}
201162
}

mopro-msm/src/msm/metal_msm/shader/mont_backend/mont.metal

+74-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// source: https://github.com/geometryxyz/msl-secp256k1
1+
// adapted from: https://github.com/geometryxyz/msl-secp256k1
22

33
using namespace metal;
44
#include <metal_stdlib>
@@ -92,3 +92,76 @@ BigInt mont_mul_modified(
9292

9393
return conditional_reduce(s, p);
9494
}
95+
96+
/// The CIOS method for Montgomery multiplication from Tolga Acar's thesis:
97+
/// High-Speed Algorithms & Architectures For Number-Theoretic Cryptosystems
98+
/// https://www.proquest.com/openview/1018972f191afe55443658b28041c118/1
///
/// Interleaves, for every limb of `y`, one multiply-accumulate pass over `x`
/// with one reduction pass over `p`, then performs a single conditional
/// subtraction of `p` before returning.
///
/// Parameters:
///   x, y - operands, read limb by limb through `.limbs[...]`.
///   p    - modulus, read limb by limb through `.limbs[...]`.
/// Returns: a BigInt whose NUM_LIMBS limbs hold the reduced accumulator.
///
/// NUM_LIMBS, LOG_LIMB_SIZE, MASK and N0 are defined outside this function.
/// Presumably MASK == (1 << LOG_LIMB_SIZE) - 1 and N0 == -p^-1 mod
/// 2^LOG_LIMB_SIZE, with operands already in Montgomery form -- TODO confirm
/// against the constants file.
///
/// NOTE(review): every intermediate `r` is a `uint`. If all limbs are below
/// 2^LOG_LIMB_SIZE, then t[j] + a * b + c is at most 2^(2 * LOG_LIMB_SIZE) - 1,
/// which only fits a 32-bit uint when LOG_LIMB_SIZE <= 16 -- confirm that the
/// configured limb size respects this.
99+
BigInt mont_mul_cios(
100+
BigInt x,
101+
BigInt y,
102+
BigInt p
103+
) {
104+
uint t[NUM_LIMBS + 2] = {0}; // Extra space for carries
105+
BigInt result;
106+
107+
// One outer iteration per limb of y.
for (uint i = 0; i < NUM_LIMBS; i++) {
108+
// Step 1: Multiply and add
// Accumulates x * y.limbs[i] into t, keeping LOG_LIMB_SIZE bits per limb
// and forwarding the remaining high bits as the carry c.
109+
uint c = 0;
110+
for (uint j = 0; j < NUM_LIMBS; j++) {
111+
uint r = t[j] + x.limbs[j] * y.limbs[i] + c;
112+
c = r >> LOG_LIMB_SIZE;
113+
t[j] = r & MASK;
114+
}
115+
// Fold the last carry into the two extra limbs at the top of t.
uint r = t[NUM_LIMBS] + c;
116+
t[NUM_LIMBS + 1] = r >> LOG_LIMB_SIZE;
117+
t[NUM_LIMBS] = r & MASK;
118+
119+
// Step 2: Reduce
// m is derived from the lowest limb of t and N0. The low bits of
// t[0] + m * p.limbs[0] are not stored; only the carry is kept
// (presumably they are zero by construction of m -- TODO confirm).
120+
uint m = (t[0] * N0) & MASK;
121+
r = t[0] + m * p.limbs[0];
122+
c = r >> LOG_LIMB_SIZE;
123+
124+
// Add m * p to the remaining limbs while writing each result one limb
// lower (t[j - 1]), i.e. shifting t down by one limb.
for (uint j = 1; j < NUM_LIMBS; j++) {
125+
r = t[j] + m * p.limbs[j] + c;
126+
c = r >> LOG_LIMB_SIZE;
127+
t[j - 1] = r & MASK;
128+
}
129+
130+
// Shift the two extra limbs down as well, absorbing the final carry.
r = t[NUM_LIMBS] + c;
131+
c = r >> LOG_LIMB_SIZE;
132+
t[NUM_LIMBS - 1] = r & MASK;
133+
t[NUM_LIMBS] = t[NUM_LIMBS + 1] + c;
134+
}
135+
136+
// Final reduction check
// Compares t with p starting at limb NUM_LIMBS - 1 and moving downwards;
// the first differing limb decides. If all limbs are equal, t_lt_p stays
// false and p is subtracted.
// NOTE(review): only limbs 0..NUM_LIMBS-1 are compared; t[NUM_LIMBS] is not
// examined here -- confirm the modulus leaves enough headroom for that limb
// to be zero at this point.
137+
bool t_lt_p = false;
138+
for (uint idx = 0; idx < NUM_LIMBS; idx++) {
139+
uint i = NUM_LIMBS - 1 - idx;
140+
if (t[i] < p.limbs[i]) {
141+
t_lt_p = true;
142+
break;
143+
} else if (t[i] > p.limbs[i]) {
144+
break;
145+
}
146+
}
147+
148+
if (t_lt_p) {
149+
// t is already below p: copy the low NUM_LIMBS limbs unchanged.
for (uint i = 0; i < NUM_LIMBS; i++) {
150+
result.limbs[i] = t[i];
151+
}
152+
} else {
153+
// t >= p on the compared limbs: compute t - p limb by limb with borrow.
uint borrow = 0;
154+
for (uint i = 0; i < NUM_LIMBS; i++) {
155+
// The unsigned subtraction may wrap; adding 2^LOG_LIMB_SIZE below
// brings the wrapped value back into range when a borrow is needed.
uint diff = t[i] - p.limbs[i] - borrow;
156+
if (t[i] < (p.limbs[i] + borrow)) {
157+
diff += (1 << LOG_LIMB_SIZE);
158+
borrow = 1;
159+
} else {
160+
borrow = 0;
161+
}
162+
result.limbs[i] = diff;
163+
}
164+
}
165+
166+
return result;
167+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// source: https://github.com/geometryxyz/msl-secp256k1
2+
3+
using namespace metal;
4+
#include <metal_stdlib>
5+
#include <metal_math>
6+
#include "mont.metal"
7+
8+
// Compute kernel that performs a single mont_mul_cios call.
//
// Buffers:
//   buffer(0) lhs    - first operand
//   buffer(1) rhs    - second operand
//   buffer(2) prime  - modulus passed as the third argument
//   buffer(3) result - receives the limbs of the product
//
// `gid` is declared but never used: every invocation reads the same inputs
// through the pointers and writes the same `result` element.
kernel void run(
9+
device BigInt* lhs [[ buffer(0) ]],
10+
device BigInt* rhs [[ buffer(1) ]],
11+
device BigInt* prime [[ buffer(2) ]],
12+
device BigInt* result [[ buffer(3) ]],
13+
uint gid [[ thread_position_in_grid ]]
14+
) {
15+
BigInt a;
16+
BigInt b;
17+
BigInt p;
18+
// Copy the limb arrays out of the device buffers into local values.
a.limbs = lhs->limbs;
19+
b.limbs = rhs->limbs;
20+
p.limbs = prime->limbs;
21+
22+
BigInt res = mont_mul_cios(a, b, p);
23+
// Write the product's limbs back to the output buffer.
result->limbs = res.limbs;
24+
25+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using namespace metal;
2+
#include <metal_stdlib>
3+
#include <metal_math>
4+
#include "mont.metal"
5+
6+
// Compute kernel that chains repeated mont_mul_cios calls (presumably used
// for benchmarking -- confirm against the host-side test).
//
// Buffers:
//   buffer(0) lhs    - operand `a`, multiplied in repeatedly
//   buffer(1) rhs    - operand `b`, multiplied in once at the end
//   buffer(2) prime  - modulus passed as the third argument
//   buffer(3) cost   - one-element array; cost[0] bounds the loop below
//   buffer(4) result - receives the limbs of the final product
//
// Number of mont_mul_cios calls: one before the loop, one per loop iteration
// (cost[0] - 1 iterations when cost[0] >= 1, none when it is 0 or 1), and one
// after the loop.
//
// `gid` is declared but never used: every invocation performs the same work
// and writes the same `result` element.
kernel void run(
7+
device BigInt* lhs [[ buffer(0) ]],
8+
device BigInt* rhs [[ buffer(1) ]],
9+
device BigInt* prime [[ buffer(2) ]],
10+
device array<uint, 1>* cost [[ buffer(3) ]],
11+
device BigInt* result [[ buffer(4) ]],
12+
uint gid [[ thread_position_in_grid ]]
13+
) {
14+
BigInt a;
15+
BigInt b;
16+
BigInt p;
17+
// Copy the limb arrays out of the device buffers into local values.
a.limbs = lhs->limbs;
18+
b.limbs = rhs->limbs;
19+
p.limbs = prime->limbs;
20+
// Local copy of the iteration count.
array<uint, 1> cost_arr = *cost;
21+
22+
// Start with a * a, then keep multiplying the running value by a.
BigInt c = mont_mul_cios(a, a, p);
23+
for (uint i = 1; i < cost_arr[0]; i ++) {
24+
c = mont_mul_cios(c, a, p);
25+
}
26+
// Final multiplication by b produces the value that is written out.
BigInt res = mont_mul_cios(c, b, p);
27+
result->limbs = res.limbs;
28+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using namespace metal;
2+
#include <metal_stdlib>
3+
#include <metal_math>
4+
#include "mont.metal"
5+
6+
// Compute kernel that chains repeated mont_mul_modified calls (presumably
// used for benchmarking -- confirm against the host-side test).
//
// Buffers:
//   buffer(0) lhs    - operand `a`, multiplied in repeatedly
//   buffer(1) rhs    - operand `b`, multiplied in once at the end
//   buffer(2) prime  - modulus passed as the third argument
//   buffer(3) cost   - one-element array; cost[0] bounds the loop below
//   buffer(4) result - receives the limbs of the final product
//
// Number of mont_mul_modified calls: one before the loop, one per loop
// iteration (cost[0] - 1 iterations when cost[0] >= 1, none when it is 0 or
// 1), and one after the loop.
//
// `gid` is declared but never used: every invocation performs the same work
// and writes the same `result` element.
kernel void run(
7+
device BigInt* lhs [[ buffer(0) ]],
8+
device BigInt* rhs [[ buffer(1) ]],
9+
device BigInt* prime [[ buffer(2) ]],
10+
device array<uint, 1>* cost [[ buffer(3) ]],
11+
device BigInt* result [[ buffer(4) ]],
12+
uint gid [[ thread_position_in_grid ]]
13+
) {
14+
BigInt a;
15+
BigInt b;
16+
BigInt p;
17+
// Copy the limb arrays out of the device buffers into local values.
a.limbs = lhs->limbs;
18+
b.limbs = rhs->limbs;
19+
p.limbs = prime->limbs;
20+
// Local copy of the iteration count.
array<uint, 1> cost_arr = *cost;
21+
22+
// Start with a * a, then keep multiplying the running value by a.
BigInt c = mont_mul_modified(a, a, p);
23+
for (uint i = 1; i < cost_arr[0]; i ++) {
24+
c = mont_mul_modified(c, a, p);
25+
}
26+
// Final multiplication by b produces the value that is written out.
BigInt res = mont_mul_modified(c, b, p);
27+
result->limbs = res.limbs;
28+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using namespace metal;
2+
#include <metal_stdlib>
3+
#include <metal_math>
4+
#include "mont.metal"
5+
6+
// Compute kernel that chains repeated mont_mul_optimised calls (presumably
// used for benchmarking -- confirm against the host-side test).
// mont_mul_optimised is not defined in this file; presumably it comes from
// the included mont.metal.
//
// Buffers:
//   buffer(0) lhs    - operand `a`, multiplied in repeatedly
//   buffer(1) rhs    - operand `b`, multiplied in once at the end
//   buffer(2) prime  - modulus passed as the third argument
//   buffer(3) cost   - one-element array; cost[0] bounds the loop below
//   buffer(4) result - receives the limbs of the final product
//
// Number of mont_mul_optimised calls: one before the loop, one per loop
// iteration (cost[0] - 1 iterations when cost[0] >= 1, none when it is 0 or
// 1), and one after the loop.
//
// `gid` is declared but never used: every invocation performs the same work
// and writes the same `result` element.
kernel void run(
7+
device BigInt* lhs [[ buffer(0) ]],
8+
device BigInt* rhs [[ buffer(1) ]],
9+
device BigInt* prime [[ buffer(2) ]],
10+
device array<uint, 1>* cost [[ buffer(3) ]],
11+
device BigInt* result [[ buffer(4) ]],
12+
uint gid [[ thread_position_in_grid ]]
13+
) {
14+
BigInt a;
15+
BigInt b;
16+
BigInt p;
17+
// Copy the limb arrays out of the device buffers into local values.
a.limbs = lhs->limbs;
18+
b.limbs = rhs->limbs;
19+
p.limbs = prime->limbs;
20+
// Local copy of the iteration count.
array<uint, 1> cost_arr = *cost;
21+
22+
// Start with a * a, then keep multiplying the running value by a.
BigInt c = mont_mul_optimised(a, a, p);
23+
for (uint i = 1; i < cost_arr[0]; i ++) {
24+
c = mont_mul_optimised(c, a, p);
25+
}
26+
// Final multiplication by b produces the value that is written out.
BigInt res = mont_mul_optimised(c, b, p);
27+
result->limbs = res.limbs;
28+
}

mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_unsafe.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// adapted from: https://github.com/geometryxyz/msl-secp256k1
22

3-
use crate::msm::metal::abstraction::limbs_conversion::{FromLimbs, ToLimbs};
43
use crate::msm::metal_msm::host::gpu::{
54
create_buffer, create_empty_buffer, get_default_device, read_buffer,
65
};
76
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants};
7+
use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs};
88
use ark_ff::{BigInt, BigInteger};
99
use metal::*;
1010

mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_wide.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// adapted from: https://github.com/geometryxyz/msl-secp256k1
22

3-
use crate::msm::metal::abstraction::limbs_conversion::{FromLimbs, ToLimbs};
43
use crate::msm::metal_msm::host::gpu::{
54
create_buffer, create_empty_buffer, get_default_device, read_buffer,
65
};
76
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants};
7+
use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs};
88
use ark_ff::{BigInt, BigInteger};
99
use metal::*;
1010

mopro-msm/src/msm/metal_msm/tests/bigint/bigint_sub.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
use core::borrow;
44

5-
use crate::msm::metal::abstraction::limbs_conversion::{FromLimbs, ToLimbs};
65
use crate::msm::metal_msm::host::gpu::{
76
create_buffer, create_empty_buffer, get_default_device, read_buffer,
87
};
98
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants};
9+
use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs};
1010
use ark_ff::{BigInt, BigInteger};
1111
use metal::*;
1212

mopro-msm/src/msm/metal_msm/tests/field/ff_add.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// adapted from: https://github.com/geometryxyz/msl-secp256k1
22

3-
use crate::msm::metal::abstraction::limbs_conversion::{FromLimbs, ToLimbs};
43
use crate::msm::metal_msm::host::gpu::{
54
create_buffer, create_empty_buffer, get_default_device, read_buffer,
65
};
76
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants};
7+
use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs};
88
use ark_bn254::Fr as ScalarField;
99
use ark_ff::{BigInt, BigInteger, PrimeField};
1010
use metal::*;

mopro-msm/src/msm/metal_msm/tests/field/ff_sub.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// adapted from: https://github.com/geometryxyz/msl-secp256k1
22

3-
use crate::msm::metal::abstraction::limbs_conversion::{FromLimbs, ToLimbs};
43
use crate::msm::metal_msm::host::gpu::{
54
create_buffer, create_empty_buffer, get_default_device, read_buffer,
65
};
76
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants};
7+
use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs};
88
use ark_bn254::Fr as ScalarField;
99
use ark_ff::{BigInt, BigInteger, PrimeField};
1010
use metal::*;

mopro-msm/src/msm/metal_msm/tests/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,5 @@
22
pub mod bigint;
33
#[cfg(test)]
44
pub mod field;
5+
#[cfg(test)]
6+
pub mod mont_backend;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#[cfg(test)]
2+
pub mod mont_benchmarks;
3+
#[cfg(test)]
4+
pub mod mont_mul_cios;
5+
#[cfg(test)]
6+
pub mod mont_mul_modified;
7+
#[cfg(test)]
8+
pub mod mont_mul_optimised;

0 commit comments

Comments
 (0)