diff --git a/curve25519-dalek/benches/dalek_benchmarks.rs b/curve25519-dalek/benches/dalek_benchmarks.rs index b8bdd6772..5e6e4a6d0 100644 --- a/curve25519-dalek/benches/dalek_benchmarks.rs +++ b/curve25519-dalek/benches/dalek_benchmarks.rs @@ -155,7 +155,7 @@ mod multiscalar_benches { // rerandomize the scalars for every call just in case. b.iter_batched( || construct_scalars(size), - |scalars| EdwardsPoint::multiscalar_mul(&scalars, &points), + |scalars| EdwardsPoint::multiscalar_mul_alloc(points.iter().zip(scalars)), BatchSize::SmallInput, ); }, diff --git a/curve25519-dalek/src/backend/mod.rs b/curve25519-dalek/src/backend/mod.rs index cfb8b003f..3d57944d6 100644 --- a/curve25519-dalek/src/backend/mod.rs +++ b/curve25519-dalek/src/backend/mod.rs @@ -36,6 +36,7 @@ use crate::EdwardsPoint; use crate::Scalar; +use crate::traits::MultiscalarMul; pub mod serial; @@ -191,30 +192,48 @@ impl VartimePrecomputedStraus { } } +#[allow(missing_docs)] +pub fn straus_multiscalar_mul( + points_and_scalars: &[(EdwardsPoint, Scalar); N], +) -> EdwardsPoint { + match get_selected_backend() { + #[cfg(curve25519_dalek_backend = "simd")] + BackendKind::Avx2 => { + vector::scalar_mul::straus::spec_avx2::Straus::multiscalar_mul(points_and_scalars) + } + #[cfg(all(curve25519_dalek_backend = "unstable_avx512", nightly))] + BackendKind::Avx512 => { + vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::multiscalar_mul( + scalars, points, + ) + } + BackendKind::Serial => { + serial::scalar_mul::straus::Straus::multiscalar_mul(points_and_scalars) + } + } +} + #[allow(missing_docs)] #[cfg(feature = "alloc")] -pub fn straus_multiscalar_mul(scalars: I, points: J) -> EdwardsPoint +pub fn straus_multiscalar_mul_alloc(points_and_scalars: I) -> EdwardsPoint where - I: IntoIterator, - I::Item: core::borrow::Borrow, - J: IntoIterator, - J::Item: core::borrow::Borrow, + I: IntoIterator, + P: core::borrow::Borrow, + S: core::borrow::Borrow, { - use crate::traits::MultiscalarMul; - match get_selected_backend() { #[cfg(curve25519_dalek_backend = "simd")] BackendKind::Avx2 => { - vector::scalar_mul::straus::spec_avx2::Straus::multiscalar_mul::(scalars, points) + vector::scalar_mul::straus::spec_avx2::Straus::multiscalar_mul_alloc(points_and_scalars) } #[cfg(all(curve25519_dalek_backend = "unstable_avx512", nightly))] BackendKind::Avx512 => { - vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::multiscalar_mul::( - scalars, points, + vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::multiscalar_mul_alloc( + points_and_scalars, ) } BackendKind::Serial => { - serial::scalar_mul::straus::Straus::multiscalar_mul::(scalars, points) + serial::scalar_mul::straus::Straus::multiscalar_mul_alloc(points_and_scalars) } } } diff --git a/curve25519-dalek/src/backend/serial/scalar_mul/mod.rs b/curve25519-dalek/src/backend/serial/scalar_mul/mod.rs index 7747decc3..658802c0f 100644 --- a/curve25519-dalek/src/backend/serial/scalar_mul/mod.rs +++ b/curve25519-dalek/src/backend/serial/scalar_mul/mod.rs @@ -23,7 +23,6 @@ pub mod variable_base; #[allow(missing_docs)] pub mod vartime_double_base; -#[cfg(feature = "alloc")] pub mod straus; #[cfg(feature = "alloc")] diff --git a/curve25519-dalek/src/backend/serial/scalar_mul/straus.rs b/curve25519-dalek/src/backend/serial/scalar_mul/straus.rs index 9c95b4fc6..4239aa724 100644 --- a/curve25519-dalek/src/backend/serial/scalar_mul/straus.rs +++ b/curve25519-dalek/src/backend/serial/scalar_mul/straus.rs @@ -13,15 +13,20 @@ #![allow(non_snake_case)] +#[cfg(feature = "alloc")] use alloc::vec::Vec; +#[cfg(feature = "alloc")] use core::borrow::Borrow; -use core::cmp::Ordering; +use crate::backend::serial::curve_models::ProjectiveNielsPoint; use crate::edwards::EdwardsPoint; use crate::scalar::Scalar; +use crate::traits::Identity; use crate::traits::MultiscalarMul; +#[cfg(feature = "alloc")] use crate::traits::VartimeMultiscalarMul; +use crate::window::LookupTable; /// Perform multiscalar multiplication by the interleaved window /// method, also known as Straus' method (since it was apparently @@ -49,101 +54,120 @@ pub struct Straus {} impl MultiscalarMul for Straus { type Point = EdwardsPoint; - /// Constant-time Straus using a fixed window of size \\(4\\). - /// - /// Our goal is to compute - /// \\[ - /// Q = s_1 P_1 + \cdots + s_n P_n. - /// \\] - /// - /// For each point \\( P_i \\), precompute a lookup table of - /// \\[ - /// P_i, 2P_i, 3P_i, 4P_i, 5P_i, 6P_i, 7P_i, 8P_i. - /// \\] - /// - /// For each scalar \\( s_i \\), compute its radix-\\(2^4\\) - /// signed digits \\( s_{i,j} \\), i.e., - /// \\[ - /// s_i = s_{i,0} + s_{i,1} 16^1 + ... + s_{i,63} 16^{63}, - /// \\] - /// with \\( -8 \leq s_{i,j} < 8 \\). Since \\( 0 \leq |s_{i,j}| - /// \leq 8 \\), we can retrieve \\( s_{i,j} P_i \\) from the - /// lookup table with a conditional negation: using signed - /// digits halves the required table size. - /// - /// Then as in the single-base fixed window case, we have - /// \\[ - /// \begin{aligned} - /// s_i P_i &= P_i (s_{i,0} + s_{i,1} 16^1 + \cdots + s_{i,63} 16^{63}) \\\\ - /// s_i P_i &= P_i s_{i,0} + P_i s_{i,1} 16^1 + \cdots + P_i s_{i,63} 16^{63} \\\\ - /// s_i P_i &= P_i s_{i,0} + 16(P_i s_{i,1} + 16( \cdots +16P_i s_{i,63})\cdots ) - /// \end{aligned} - /// \\] - /// so each \\( s_i P_i \\) can be computed by alternately adding - /// a precomputed multiple \\( P_i s_{i,j} \\) of \\( P_i \\) and - /// repeatedly doubling. - /// - /// Now consider the two-dimensional sum - /// \\[ - /// \begin{aligned} - /// s\_1 P\_1 &=& P\_1 s\_{1,0} &+& 16 (P\_1 s\_{1,1} &+& 16 ( \cdots &+& 16 P\_1 s\_{1,63}&) \cdots ) \\\\ - /// + & & + & & + & & & & + & \\\\ - /// s\_2 P\_2 &=& P\_2 s\_{2,0} &+& 16 (P\_2 s\_{2,1} &+& 16 ( \cdots &+& 16 P\_2 s\_{2,63}&) \cdots ) \\\\ - /// + & & + & & + & & & & + & \\\\ - /// \vdots & & \vdots & & \vdots & & & & \vdots & \\\\ - /// + & & + & & + & & & & + & \\\\ - /// s\_n P\_n &=& P\_n s\_{n,0} &+& 16 (P\_n s\_{n,1} &+& 16 ( \cdots &+& 16 P\_n s\_{n,63}&) \cdots ) - /// \end{aligned} - /// \\] - /// The sum of the left-hand column is the result \\( Q \\); by - /// computing the two-dimensional sum on the right column-wise, - /// top-to-bottom, then right-to-left, we need to multiply by \\( - /// 16\\) only once per column, sharing the doublings across all - /// of the input points. - fn multiscalar_mul(scalars: I, points: J) -> EdwardsPoint - where - I: IntoIterator, - I::Item: Borrow, - J: IntoIterator, - J::Item: Borrow, - { - use crate::backend::serial::curve_models::ProjectiveNielsPoint; - use crate::traits::Identity; - use crate::window::LookupTable; + fn multiscalar_mul( + points_and_scalars: &[(EdwardsPoint, Scalar); N], + ) -> EdwardsPoint { + let lookup_tables: [_; N] = core::array::from_fn(|index| { + LookupTable::::from(&points_and_scalars[index].0) + }); - let lookup_tables: Vec<_> = points - .into_iter() - .map(|point| LookupTable::::from(point.borrow())) - .collect(); + let scalar_digits: [_; N] = + core::array::from_fn(|index| points_and_scalars[index].1.as_radix_16()); + multiscalar_mul(&scalar_digits, &lookup_tables) + } + + #[cfg(feature = "alloc")] + fn multiscalar_mul_alloc(points_and_scalars: I) -> EdwardsPoint + where + I: IntoIterator, + P: Borrow, + S: Borrow, + { // This puts the scalar digits into a heap-allocated Vec. // To ensure that these are erased, pass ownership of the Vec into a // Zeroizing wrapper. #[cfg_attr(not(feature = "zeroize"), allow(unused_mut))] - let mut scalar_digits: Vec<_> = scalars + let (lookup_tables, mut scalar_digits): (Vec<_>, Vec<_>) = points_and_scalars .into_iter() - .map(|s| s.borrow().as_radix_16()) - .collect(); + .map(|(p, s)| { + ( + LookupTable::::from(p.borrow()), + s.borrow().as_radix_16(), + ) + }) + .unzip(); - let mut Q = EdwardsPoint::identity(); - for j in (0..64).rev() { - Q = Q.mul_by_pow_2(4); - let it = scalar_digits.iter().zip(lookup_tables.iter()); - for (s_i, lookup_table_i) in it { - // R_i = s_{i,j} * P_i - let R_i = lookup_table_i.select(s_i[j]); - // Q = Q + R_i - Q = (&Q + &R_i).as_extended(); - } - } + let Q = multiscalar_mul(&scalar_digits, &lookup_tables); #[cfg(feature = "zeroize")] - zeroize::Zeroize::zeroize(&mut scalar_digits); + zeroize::Zeroize::zeroize(&mut scalar_digits.iter_mut()); Q } } +/// Constant-time Straus using a fixed window of size \\(4\\). +/// +/// Our goal is to compute +/// \\[ +/// Q = s_1 P_1 + \cdots + s_n P_n. +/// \\] +/// +/// For each point \\( P_i \\), precompute a lookup table of +/// \\[ +/// P_i, 2P_i, 3P_i, 4P_i, 5P_i, 6P_i, 7P_i, 8P_i. +/// \\] +/// +/// For each scalar \\( s_i \\), compute its radix-\\(2^4\\) +/// signed digits \\( s_{i,j} \\), i.e., +/// \\[ +/// s_i = s_{i,0} + s_{i,1} 16^1 + ... + s_{i,63} 16^{63}, +/// \\] +/// with \\( -8 \leq s_{i,j} < 8 \\). Since \\( 0 \leq |s_{i,j}| +/// \leq 8 \\), we can retrieve \\( s_{i,j} P_i \\) from the +/// lookup table with a conditional negation: using signed +/// digits halves the required table size. +/// +/// Then as in the single-base fixed window case, we have +/// \\[ +/// \begin{aligned} +/// s_i P_i &= P_i (s_{i,0} + s_{i,1} 16^1 + \cdots + s_{i,63} 16^{63}) \\\\ +/// s_i P_i &= P_i s_{i,0} + P_i s_{i,1} 16^1 + \cdots + P_i s_{i,63} 16^{63} \\\\ +/// s_i P_i &= P_i s_{i,0} + 16(P_i s_{i,1} + 16( \cdots +16P_i s_{i,63})\cdots ) +/// \end{aligned} +/// \\] +/// so each \\( s_i P_i \\) can be computed by alternately adding +/// a precomputed multiple \\( P_i s_{i,j} \\) of \\( P_i \\) and +/// repeatedly doubling. +/// +/// Now consider the two-dimensional sum +/// \\[ +/// \begin{aligned} +/// s\_1 P\_1 &=& P\_1 s\_{1,0} &+& 16 (P\_1 s\_{1,1} &+& 16 ( \cdots &+& 16 P\_1 s\_{1,63}&) \cdots ) \\\\ +/// + & & + & & + & & & & + & \\\\ +/// s\_2 P\_2 &=& P\_2 s\_{2,0} &+& 16 (P\_2 s\_{2,1} &+& 16 ( \cdots &+& 16 P\_2 s\_{2,63}&) \cdots ) \\\\ +/// + & & + & & + & & & & + & \\\\ +/// \vdots & & \vdots & & \vdots & & & & \vdots & \\\\ +/// + & & + & & + & & & & + & \\\\ +/// s\_n P\_n &=& P\_n s\_{n,0} &+& 16 (P\_n s\_{n,1} &+& 16 ( \cdots &+& 16 P\_n s\_{n,63}&) \cdots ) +/// \end{aligned} +/// \\] +/// The sum of the left-hand column is the result \\( Q \\); by +/// computing the two-dimensional sum on the right column-wise, +/// top-to-bottom, then right-to-left, we need to multiply by \\( +/// 16\\) only once per column, sharing the doublings across all +/// of the input points. +fn multiscalar_mul( + scalar_digits: &[[i8; 64]], + lookup_tables: &[LookupTable], +) -> EdwardsPoint { + let mut Q = EdwardsPoint::identity(); + for j in (0..64).rev() { + Q = Q.mul_by_pow_2(4); + let it = scalar_digits.iter().zip(lookup_tables.iter()); + for (s_i, lookup_table_i) in it { + // R_i = s_{i,j} * P_i + let R_i = lookup_table_i.select(s_i[j]); + // Q = Q + R_i + Q = (&Q + &R_i).as_extended(); + } + } + + Q +} + +#[cfg(feature = "alloc")] impl VartimeMultiscalarMul for Straus { type Point = EdwardsPoint; @@ -167,6 +191,7 @@ impl VartimeMultiscalarMul for Straus { }; use crate::traits::Identity; use crate::window::NafLookupTable5; + use core::cmp::Ordering; let nafs: Vec<_> = scalars .into_iter() diff --git a/curve25519-dalek/src/backend/vector/scalar_mul/mod.rs b/curve25519-dalek/src/backend/vector/scalar_mul/mod.rs index fed3470e7..71c4be461 100644 --- a/curve25519-dalek/src/backend/vector/scalar_mul/mod.rs +++ b/curve25519-dalek/src/backend/vector/scalar_mul/mod.rs @@ -18,7 +18,6 @@ pub mod variable_base; pub mod vartime_double_base; #[allow(missing_docs)] -#[cfg(feature = "alloc")] pub mod straus; #[allow(missing_docs)] diff --git a/curve25519-dalek/src/backend/vector/scalar_mul/straus.rs b/curve25519-dalek/src/backend/vector/scalar_mul/straus.rs index 23516dd23..4c201db8c 100644 --- a/curve25519-dalek/src/backend/vector/scalar_mul/straus.rs +++ b/curve25519-dalek/src/backend/vector/scalar_mul/straus.rs @@ -20,13 +20,11 @@ )] pub mod spec { + #[cfg(feature = "alloc")] use alloc::vec::Vec; + #[cfg(feature = "alloc")] use core::borrow::Borrow; - use core::cmp::Ordering; - - #[cfg(feature = "zeroize")] - use zeroize::Zeroizing; #[for_target_feature("avx2")] use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint}; @@ -36,8 +34,10 @@ pub mod spec { use crate::edwards::EdwardsPoint; use crate::scalar::Scalar; - use crate::traits::{Identity, MultiscalarMul, VartimeMultiscalarMul}; - use crate::window::{LookupTable, NafLookupTable5}; + #[cfg(feature = "alloc")] + use crate::traits::VartimeMultiscalarMul; + use crate::traits::{Identity, MultiscalarMul}; + use crate::window::LookupTable; /// Multiscalar multiplication using interleaved window / Straus' /// method. See the `Straus` struct in the serial backend for more @@ -52,41 +52,65 @@ pub mod spec { impl MultiscalarMul for Straus { type Point = EdwardsPoint; - fn multiscalar_mul(scalars: I, points: J) -> EdwardsPoint + fn multiscalar_mul( + points_and_scalars: &[(EdwardsPoint, Scalar); N], + ) -> EdwardsPoint { + // Construct a lookup table of [P,2P,3P,4P,5P,6P,7P,8P] + // for each input point P + let lookup_tables: [_; N] = core::array::from_fn(|index| { + LookupTable::::from(&points_and_scalars[index].0) + }); + + let scalar_digits: [_; N] = + core::array::from_fn(|index| points_and_scalars[index].1.as_radix_16()); + + multiscalar_mul(&scalar_digits, &lookup_tables) + } + + #[cfg(feature = "alloc")] + fn multiscalar_mul_alloc(points_and_scalars: I) -> EdwardsPoint where - I: IntoIterator, - I::Item: Borrow, - J: IntoIterator, - J::Item: Borrow, + I: IntoIterator, + P: Borrow, + S: Borrow, { // Construct a lookup table of [P,2P,3P,4P,5P,6P,7P,8P] // for each input point P - let lookup_tables: Vec<_> = points + let (lookup_tables, scalar_digits_vec): (Vec<_>, Vec<_>) = points_and_scalars .into_iter() - .map(|point| LookupTable::::from(point.borrow())) - .collect(); + .map(|(p, s)| { + ( + LookupTable::::from(p.borrow()), + s.borrow().as_radix_16(), + ) + }) + .unzip(); - let scalar_digits_vec: Vec<_> = scalars - .into_iter() - .map(|s| s.borrow().as_radix_16()) - .collect(); // Pass ownership to a `Zeroizing` wrapper #[cfg(feature = "zeroize")] - let scalar_digits_vec = Zeroizing::new(scalar_digits_vec); + let scalar_digits_vec = zeroize::Zeroizing::new(scalar_digits_vec); - let mut Q = ExtendedPoint::identity(); - for j in (0..64).rev() { - Q = Q.mul_by_pow_2(4); - let it = scalar_digits_vec.iter().zip(lookup_tables.iter()); - for (s_i, lookup_table_i) in it { - // Q = Q + s_{i,j} * P_i - Q = &Q + &lookup_table_i.select(s_i[j]); - } + multiscalar_mul(&scalar_digits_vec, &lookup_tables) + } + } + + fn multiscalar_mul( + scalar_digits: &[[i8; 64]], + lookup_tables: &[LookupTable], + ) -> EdwardsPoint { + let mut Q = ExtendedPoint::identity(); + for j in (0..64).rev() { + Q = Q.mul_by_pow_2(4); + let it = scalar_digits.iter().zip(lookup_tables.iter()); + for (s_i, lookup_table_i) in it { + // Q = Q + s_{i,j} * P_i + Q = &Q + &lookup_table_i.select(s_i[j]); } - Q.into() } + Q.into() } + #[cfg(feature = "alloc")] impl VartimeMultiscalarMul for Straus { type Point = EdwardsPoint; @@ -96,6 +120,9 @@ pub mod spec { I::Item: Borrow, J: IntoIterator>, { + use crate::window::NafLookupTable5; + use core::cmp::Ordering; + let nafs: Vec<_> = scalars .into_iter() .map(|c| c.borrow().non_adjacent_form(5)) diff --git a/curve25519-dalek/src/edwards.rs b/curve25519-dalek/src/edwards.rs index f7d2e6906..bfc8a617e 100644 --- a/curve25519-dalek/src/edwards.rs +++ b/curve25519-dalek/src/edwards.rs @@ -154,7 +154,6 @@ use crate::traits::{Identity, IsIdentity}; use affine::AffinePoint; -#[cfg(feature = "alloc")] use crate::traits::MultiscalarMul; #[cfg(feature = "alloc")] use crate::traits::{VartimeMultiscalarMul, VartimePrecomputedMultiscalarMul}; @@ -951,35 +950,23 @@ impl EdwardsPoint { // These use the iterator's size hint and the target settings to // forward to a specific backend implementation. -#[cfg(feature = "alloc")] impl MultiscalarMul for EdwardsPoint { type Point = EdwardsPoint; - fn multiscalar_mul(scalars: I, points: J) -> EdwardsPoint + fn multiscalar_mul( + points_and_scalars: &[(Self::Point, Scalar); N], + ) -> Self::Point { + crate::backend::straus_multiscalar_mul(points_and_scalars) + } + + #[cfg(feature = "alloc")] + fn multiscalar_mul_alloc(points_and_scalars: I) -> EdwardsPoint where - I: IntoIterator, - I::Item: Borrow, - J: IntoIterator, - J::Item: Borrow, + I: IntoIterator, + P: Borrow, + S: Borrow, { - // Sanity-check lengths of input iterators - let mut scalars = scalars.into_iter(); - let mut points = points.into_iter(); - - // Lower and upper bounds on iterators - let (s_lo, s_hi) = scalars.by_ref().size_hint(); - let (p_lo, p_hi) = points.by_ref().size_hint(); - - // They should all be equal - assert_eq!(s_lo, p_lo); - assert_eq!(s_hi, Some(s_lo)); - assert_eq!(p_hi, Some(p_lo)); - - // Now we know there's a single size. When we do - // size-dependent algorithm dispatch, use this as the hint. - let _size = s_lo; - - crate::backend::straus_multiscalar_mul(scalars, points) + crate::backend::straus_multiscalar_mul_alloc(points_and_scalars) } } @@ -2252,7 +2239,7 @@ mod test { let Gs = xs.iter().map(EdwardsPoint::mul_base).collect::>(); // Compute H1 = (consttime) - let H1 = EdwardsPoint::multiscalar_mul(&xs, &Gs); + let H1 = EdwardsPoint::multiscalar_mul_alloc(Gs.iter().zip(&xs)); // Compute H2 = (vartime) let H2 = EdwardsPoint::vartime_multiscalar_mul(&xs, &Gs); // Compute H3 = = sum(xi^2) * B @@ -2409,9 +2396,10 @@ mod test { &[A_SCALAR, B_SCALAR], &[A, constants::ED25519_BASEPOINT_POINT], ); - let result_consttime = EdwardsPoint::multiscalar_mul( - &[A_SCALAR, B_SCALAR], - &[A, constants::ED25519_BASEPOINT_POINT], + let result_consttime = EdwardsPoint::multiscalar_mul_alloc( + [A, constants::ED25519_BASEPOINT_POINT] + .into_iter() + .zip([A_SCALAR, B_SCALAR]), ); assert_eq!(result_vartime.compress(), result_consttime.compress()); diff --git a/curve25519-dalek/src/ristretto.rs b/curve25519-dalek/src/ristretto.rs index 8b867930d..dbf7fb890 100644 --- a/curve25519-dalek/src/ristretto.rs +++ b/curve25519-dalek/src/ristretto.rs @@ -206,9 +206,9 @@ use crate::scalar::Scalar; #[cfg(feature = "precomputed-tables")] use crate::traits::BasepointTable; -use crate::traits::Identity; +use crate::traits::{Identity, MultiscalarMul}; #[cfg(feature = "alloc")] -use crate::traits::{MultiscalarMul, VartimeMultiscalarMul, VartimePrecomputedMultiscalarMul}; +use crate::traits::{VartimeMultiscalarMul, VartimePrecomputedMultiscalarMul}; // ------------------------------------------------------------------------ // Compressed points @@ -1007,19 +1007,30 @@ define_mul_variants!(LHS = Scalar, RHS = RistrettoPoint, Output = RistrettoPoint // These use iterator combinators to unwrap the underlying points and // forward to the EdwardsPoint implementations. -#[cfg(feature = "alloc")] impl MultiscalarMul for RistrettoPoint { type Point = RistrettoPoint; - fn multiscalar_mul(scalars: I, points: J) -> RistrettoPoint + fn multiscalar_mul( + points_and_scalars: &[(RistrettoPoint, Scalar); N], + ) -> RistrettoPoint { + let points_and_scalars: [_; N] = core::array::from_fn(|index| { + let (p, s) = points_and_scalars[index]; + (p.0, s) + }); + RistrettoPoint(EdwardsPoint::multiscalar_mul(&points_and_scalars)) + } + + #[cfg(feature = "alloc")] + fn multiscalar_mul_alloc(points_and_scalars: I) -> RistrettoPoint where - I: IntoIterator, - I::Item: Borrow, - J: IntoIterator, - J::Item: Borrow, + I: IntoIterator, + P: Borrow, + S: Borrow, { - let extended_points = points.into_iter().map(|P| P.borrow().0); - RistrettoPoint(EdwardsPoint::multiscalar_mul(scalars, extended_points)) + let points_and_scalars = points_and_scalars + .into_iter() + .map(|(p, s)| (p.borrow().0, s)); + RistrettoPoint(EdwardsPoint::multiscalar_mul_alloc(points_and_scalars)) } } diff --git a/curve25519-dalek/src/traits.rs b/curve25519-dalek/src/traits.rs index 834b8f27f..9a5660987 100644 --- a/curve25519-dalek/src/traits.rs +++ b/curve25519-dalek/src/traits.rs @@ -79,6 +79,49 @@ pub trait MultiscalarMul { /// The type of point being multiplied, e.g., `RistrettoPoint`. type Point; + /// Given an iterator of (possibly secret) scalars and an iterator of + /// public points, compute + /// $$ + /// Q = c\_1 P\_1 + \cdots + c\_n P\_n. + /// $$ + /// + /// It is an error to call this function with two iterators of different lengths. + /// + /// # Examples + /// + /// The trait bound aims for maximum flexibility: the inputs must be + /// convertible to iterators (`I: IntoIter`), and the iterator's items + /// must be `Borrow` (or `Borrow`), to allow + /// iterators returning either `Scalar`s or `&Scalar`s. + /// + /// ``` + /// use curve25519_dalek::constants; + /// use curve25519_dalek::traits::MultiscalarMul; + /// use curve25519_dalek::ristretto::RistrettoPoint; + /// use curve25519_dalek::scalar::Scalar; + /// + /// // Some scalars + /// let a = Scalar::from(87329482u64); + /// let b = Scalar::from(37264829u64); + /// let c = Scalar::from(98098098u64); + /// + /// // Some points + /// let P = constants::RISTRETTO_BASEPOINT_POINT; + /// let Q = P + P; + /// let R = P + Q; + /// + /// // A1 = a*P + b*Q + c*R + /// let A1 = RistrettoPoint::multiscalar_mul(&[(P, a), (Q, b), (R, c)]); + /// + /// // A2 = (-a)*P + (-b)*Q + (-c)*R + /// let A2 = RistrettoPoint::multiscalar_mul(&[(P, -a), (Q, -b), (R, -c)]); + /// + /// assert_eq!(A1.compress(), (-A2).compress()); + /// ``` + fn multiscalar_mul( + points_and_scalars: &[(Self::Point, Scalar); N], + ) -> Self::Point; + /// Given an iterator of (possibly secret) scalars and an iterator of /// public points, compute /// $$ @@ -114,23 +157,23 @@ pub trait MultiscalarMul { /// /// // A1 = a*P + b*Q + c*R /// let abc = [a,b,c]; - /// let A1 = RistrettoPoint::multiscalar_mul(&abc, &[P,Q,R]); + /// let A1 = RistrettoPoint::multiscalar_mul_alloc([P,Q,R].into_iter().zip(&abc)); /// // Note: (&abc).into_iter(): Iterator /// /// // A2 = (-a)*P + (-b)*Q + (-c)*R /// let minus_abc = abc.iter().map(|x| -x); - /// let A2 = RistrettoPoint::multiscalar_mul(minus_abc, &[P,Q,R]); + /// let A2 = RistrettoPoint::multiscalar_mul_alloc([P,Q,R].into_iter().zip(minus_abc)); /// // Note: minus_abc.into_iter(): Iterator /// /// assert_eq!(A1.compress(), (-A2).compress()); /// # } /// ``` - fn multiscalar_mul(scalars: I, points: J) -> Self::Point + #[cfg(feature = "alloc")] + fn multiscalar_mul_alloc(points_and_scalars: I) -> Self::Point where - I: IntoIterator, - I::Item: Borrow, - J: IntoIterator, - J::Item: Borrow; + I: IntoIterator, + P: Borrow, + S: Borrow; } /// A trait for variable-time multiscalar multiplication without precomputation.