-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(curve): add pADD, pMADD, pDBL algo in Jacobian Coordinates
* wip: rust host code for metal ec op * chore(mont): ignore heavy test * test(mont): change to suitable limb size * refactor(curve): use mont_cios for Jacobian Coordinates * feat(curve): impl jacobian_madd_2007_bl algorithm and corresponding tests * feat(tests): add jacobian_add_2007_b1 test and refactor jacobian_dbl_2009_l test
- Loading branch information
Showing
11 changed files
with
651 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
34 changes: 34 additions & 0 deletions
34
mopro-msm/src/msm/metal_msm/shader/curve/jacobian_madd_2007_bl.metal
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
// source: https://github.com/geometryxyz/msl-secp256k1 | ||
|
||
using namespace metal; | ||
#include <metal_stdlib> | ||
#include <metal_math> | ||
#include "jacobian.metal" | ||
|
||
kernel void run( | ||
device BigInt* prime [[ buffer(0) ]], | ||
device BigInt* a_xr [[ buffer(1) ]], | ||
device BigInt* a_yr [[ buffer(2) ]], | ||
device BigInt* a_zr [[ buffer(3) ]], | ||
device BigInt* b_xr [[ buffer(4) ]], | ||
device BigInt* b_yr [[ buffer(5) ]], | ||
device BigInt* result_xr [[ buffer(6) ]], | ||
device BigInt* result_yr [[ buffer(7) ]], | ||
device BigInt* result_zr [[ buffer(8) ]], | ||
uint gid [[ thread_position_in_grid ]] | ||
) { | ||
BigInt p; p.limbs = prime->limbs; | ||
BigInt x1; x1.limbs = a_xr->limbs; | ||
BigInt y1; y1.limbs = a_yr->limbs; | ||
BigInt z1; z1.limbs = a_zr->limbs; | ||
BigInt x2; x2.limbs = b_xr->limbs; | ||
BigInt y2; y2.limbs = b_yr->limbs; | ||
|
||
Jacobian a; a.x = x1; a.y = y1; a.z = z1; | ||
Affine b; b.x = x2; b.y = y2; | ||
|
||
Jacobian res = jacobian_madd_2007_bl(a, b, p); | ||
result_xr->limbs = res.x.limbs; | ||
result_yr->limbs = res.y.limbs; | ||
result_zr->limbs = res.z.limbs; | ||
} |
176 changes: 176 additions & 0 deletions
176
mopro-msm/src/msm/metal_msm/tests/curve/jacobian_add_2007_b1.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
// adapted from https://github.com/geometryxyz/msl-secp256k1 | ||
|
||
use ark_bn254::{Fq as BaseField, Fr as ScalarField, G1Affine as GAffine, G1Projective as G}; | ||
use ark_ec::AffineRepr; | ||
use metal::*; | ||
|
||
use crate::msm::metal_msm::host::gpu::{ | ||
create_buffer, create_empty_buffer, get_default_device, read_buffer, | ||
}; | ||
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants}; | ||
use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs}; | ||
use crate::msm::metal_msm::utils::mont_params::{calc_mont_radix, calc_nsafe, calc_rinv_and_n0}; | ||
use ark_ff::{BigInt, PrimeField}; | ||
use ark_std::{rand::thread_rng, UniformRand}; | ||
use num_bigint::BigUint; | ||
|
||
#[test] | ||
#[serial_test::serial] | ||
pub fn test_jacobian_add_2007_bl() { | ||
let log_limb_size = 16; | ||
let p: BigUint = BaseField::MODULUS.try_into().unwrap(); | ||
|
||
let modulus_bits = BaseField::MODULUS_BIT_SIZE as u32; | ||
let num_limbs = ((modulus_bits + log_limb_size - 1) / log_limb_size) as usize; | ||
|
||
let r = calc_mont_radix(num_limbs, log_limb_size); | ||
let res = calc_rinv_and_n0(&p, &r, log_limb_size); | ||
let rinv = res.0; | ||
let n0 = res.1; | ||
let nsafe = calc_nsafe(log_limb_size); | ||
|
||
// Generate 2 random affine points | ||
let (a, b) = { | ||
let mut rng = thread_rng(); | ||
let base_point = GAffine::generator().into_group(); | ||
|
||
let s1 = ScalarField::rand(&mut rng); | ||
let mut s2 = ScalarField::rand(&mut rng); | ||
|
||
// Ensure s1 and s2 are different (if s1 == s2, we use pDBL instead of pADD) | ||
while s1 == s2 { | ||
s2 = ScalarField::rand(&mut rng); | ||
} | ||
|
||
(base_point * s1, base_point * s2) | ||
}; | ||
|
||
// Compute the sum in projective form using Arkworks | ||
let expected = a + b; | ||
|
||
let ax: BigUint = a.x.into(); | ||
let ay: BigUint = a.y.into(); | ||
let az: BigUint = a.z.into(); | ||
let bx: BigUint = b.x.into(); | ||
let by: BigUint = b.y.into(); | ||
let bz: BigUint = b.z.into(); | ||
|
||
let axr = (&ax * &r) % &p; | ||
let ayr = (&ay * &r) % &p; | ||
let azr = (&az * &r) % &p; | ||
let bxr = (&bx * &r) % &p; | ||
let byr = (&by * &r) % &p; | ||
let bzr = (&bz * &r) % &p; | ||
|
||
let p_limbs = BaseField::MODULUS.to_limbs(num_limbs, log_limb_size); | ||
let axr_limbs = ark_ff::BigInt::<4>::try_from(axr.clone()) | ||
.unwrap() | ||
.to_limbs(num_limbs, log_limb_size); | ||
let ayr_limbs = ark_ff::BigInt::<4>::try_from(ayr.clone()) | ||
.unwrap() | ||
.to_limbs(num_limbs, log_limb_size); | ||
let azr_limbs = ark_ff::BigInt::<4>::try_from(azr.clone()) | ||
.unwrap() | ||
.to_limbs(num_limbs, log_limb_size); | ||
let bxr_limbs = ark_ff::BigInt::<4>::try_from(bxr.clone()) | ||
.unwrap() | ||
.to_limbs(num_limbs, log_limb_size); | ||
let byr_limbs = ark_ff::BigInt::<4>::try_from(byr.clone()) | ||
.unwrap() | ||
.to_limbs(num_limbs, log_limb_size); | ||
let bzr_limbs = ark_ff::BigInt::<4>::try_from(bzr.clone()) | ||
.unwrap() | ||
.to_limbs(num_limbs, log_limb_size); | ||
|
||
let device = get_default_device(); | ||
let prime_buf = create_buffer(&device, &p_limbs); | ||
let axr_buf = create_buffer(&device, &axr_limbs); | ||
let ayr_buf = create_buffer(&device, &ayr_limbs); | ||
let azr_buf = create_buffer(&device, &azr_limbs); | ||
let bxr_buf = create_buffer(&device, &bxr_limbs); | ||
let byr_buf = create_buffer(&device, &byr_limbs); | ||
let bzr_buf = create_buffer(&device, &bzr_limbs); | ||
let result_xr_buf = create_empty_buffer(&device, num_limbs); | ||
let result_yr_buf = create_empty_buffer(&device, num_limbs); | ||
let result_zr_buf = create_empty_buffer(&device, num_limbs); | ||
|
||
let command_queue = device.new_command_queue(); | ||
let command_buffer = command_queue.new_command_buffer(); | ||
|
||
let compute_pass_descriptor = ComputePassDescriptor::new(); | ||
let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor); | ||
|
||
write_constants( | ||
"../mopro-msm/src/msm/metal_msm/shader", | ||
num_limbs, | ||
log_limb_size, | ||
n0, | ||
nsafe, | ||
); | ||
let library_path = compile_metal( | ||
"../mopro-msm/src/msm/metal_msm/shader/curve", | ||
"jacobian_add_2007_bl.metal", | ||
); | ||
let library = device.new_library_with_file(library_path).unwrap(); | ||
let kernel = library.get_function("run", None).unwrap(); | ||
|
||
let pipeline_state_descriptor = ComputePipelineDescriptor::new(); | ||
pipeline_state_descriptor.set_compute_function(Some(&kernel)); | ||
|
||
let pipeline_state = device | ||
.new_compute_pipeline_state_with_function( | ||
pipeline_state_descriptor.compute_function().unwrap(), | ||
) | ||
.unwrap(); | ||
|
||
encoder.set_compute_pipeline_state(&pipeline_state); | ||
encoder.set_buffer(0, Some(&prime_buf), 0); | ||
encoder.set_buffer(1, Some(&axr_buf), 0); | ||
encoder.set_buffer(2, Some(&ayr_buf), 0); | ||
encoder.set_buffer(3, Some(&azr_buf), 0); | ||
encoder.set_buffer(4, Some(&bxr_buf), 0); | ||
encoder.set_buffer(5, Some(&byr_buf), 0); | ||
encoder.set_buffer(6, Some(&bzr_buf), 0); | ||
encoder.set_buffer(7, Some(&result_xr_buf), 0); | ||
encoder.set_buffer(8, Some(&result_yr_buf), 0); | ||
encoder.set_buffer(9, Some(&result_zr_buf), 0); | ||
|
||
let thread_group_count = MTLSize { | ||
width: 1, | ||
height: 1, | ||
depth: 1, | ||
}; | ||
|
||
let thread_group_size = MTLSize { | ||
width: 1, | ||
height: 1, | ||
depth: 1, | ||
}; | ||
|
||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); | ||
encoder.end_encoding(); | ||
|
||
command_buffer.commit(); | ||
command_buffer.wait_until_completed(); | ||
|
||
let result_xr_limbs: Vec<u32> = read_buffer(&result_xr_buf, num_limbs); | ||
let result_yr_limbs: Vec<u32> = read_buffer(&result_yr_buf, num_limbs); | ||
let result_zr_limbs: Vec<u32> = read_buffer(&result_zr_buf, num_limbs); | ||
|
||
let result_xr: BigUint = BigInt::from_limbs(&result_xr_limbs, log_limb_size) | ||
.try_into() | ||
.unwrap(); | ||
let result_yr: BigUint = BigInt::from_limbs(&result_yr_limbs, log_limb_size) | ||
.try_into() | ||
.unwrap(); | ||
let result_zr: BigUint = BigInt::from_limbs(&result_zr_limbs, log_limb_size) | ||
.try_into() | ||
.unwrap(); | ||
|
||
let result_x = (result_xr * &rinv) % &p; | ||
let result_y = (result_yr * &rinv) % &p; | ||
let result_z = (result_zr * &rinv) % &p; | ||
|
||
let result = G::new(result_x.into(), result_y.into(), result_z.into()); | ||
assert!(result == expected); | ||
} |
Oops, something went wrong.