From d20d03cd7f22bc38e4564e8e89feb2868cb859a1 Mon Sep 17 00:00:00 2001 From: Sam Kim Date: Thu, 30 Sep 2021 11:54:14 -0400 Subject: [PATCH] clean up ElGamal decryption --- zk-token-sdk/src/encryption/dlog.rs | 227 +++++-------------------- zk-token-sdk/src/encryption/elgamal.rs | 32 +++- zk-token-sdk/src/encryption/mod.rs | 2 +- 3 files changed, 74 insertions(+), 187 deletions(-) diff --git a/zk-token-sdk/src/encryption/dlog.rs b/zk-token-sdk/src/encryption/dlog.rs index 4ee83982be..96bc1f3992 100644 --- a/zk-token-sdk/src/encryption/dlog.rs +++ b/zk-token-sdk/src/encryption/dlog.rs @@ -1,17 +1,12 @@ -use core::ops::{Add, Neg, Sub}; - -use curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT as G; -use curve25519_dalek::ristretto::RistrettoPoint; -use curve25519_dalek::scalar::Scalar; -use curve25519_dalek::traits::Identity; - -use std::collections::HashMap; -use std::hash::{Hash, Hasher}; - -use serde::{Deserialize, Serialize}; +use { + curve25519_dalek::{ristretto::RistrettoPoint, scalar::Scalar, traits::Identity}, + serde::{Deserialize, Serialize}, + std::collections::HashMap, +}; +const TWO15: u32 = 32768; const TWO14: u32 = 16384; // 2^14 -const TWO16: u32 = 65536; // 2^16 + // const TWO16: u32 = 65536; // 2^16 const TWO18: u32 = 262144; // 2^18 /// Type that captures a discrete log challenge. @@ -25,93 +20,52 @@ pub struct DiscreteLogInstance { pub target: RistrettoPoint, } -/// Solves the discrete log instance using a 16/16 bit offline/online split -impl DiscreteLogInstance { - /// Solves the discrete log problem under the assumption that the solution - /// is a 32-bit number. - pub fn decode_u32(self) -> Option { - let hashmap = DiscreteLogInstance::decode_u32_precomputation(self.generator); - self.decode_u32_online(&hashmap) - } +/// Builds a HashMap of 2^18 elements +pub fn decode_u32_precomputation(generator: RistrettoPoint) -> HashMap<[u8; 32], u32> { + let mut hashmap = HashMap::new(); - /// Builds a HashMap of 2^16 elements - pub fn decode_u32_precomputation(generator: RistrettoPoint) -> HashMap { - let mut hashmap = HashMap::new(); + let two12_scalar = Scalar::from(TWO14); + let identity = RistrettoPoint::identity(); // 0 * G + let generator = two12_scalar * generator; // 2^12 * G - let two16_scalar = Scalar::from(TWO16); - let identity = HashableRistretto(RistrettoPoint::identity()); // 0 * G - let generator = HashableRistretto(two16_scalar * generator); // 2^16 * G + // iterator for 2^12*0G , 2^12*1G, 2^12*2G, ... + let ristretto_iter = RistrettoIterator::new(identity, generator); + let mut steps_for_breakpoint = 0; + ristretto_iter.zip(0..TWO18).for_each(|(elem, x_hi)| { + let key = elem.compress().to_bytes(); + hashmap.insert(key, x_hi); - // iterator for 2^16*0G , 2^16*1G, 2^16*2G, ... - let ristretto_iter = RistrettoIterator::new(identity, generator); - ristretto_iter.zip(0..TWO16).for_each(|(elem, x_hi)| { - hashmap.insert(elem, x_hi); - }); + // unclean way to print status update; will clean up later + if x_hi % TWO15 == 0 { + println!(" [{:?}/8] completed", steps_for_breakpoint); + steps_for_breakpoint += 1; + } + }); + println!(" [8/8] completed"); - hashmap - } - - /// Solves the discrete log instance using the pre-computed HashMap by enumerating through 2^16 - /// possible solutions - pub fn decode_u32_online(self, hashmap: &HashMap) -> Option { - // iterator for 0G, -1G, -2G, ... - let ristretto_iter = RistrettoIterator::new( - HashableRistretto(self.target), - HashableRistretto(-self.generator), - ); - - let mut decoded = None; - ristretto_iter.zip(0..TWO16).for_each(|(elem, x_lo)| { - if hashmap.contains_key(&elem) { - let x_hi = hashmap[&elem]; - decoded = Some(x_lo + TWO16 * x_hi); - } - }); - decoded - } + hashmap } /// Solves the discrete log instance using a 18/14 bit offline/online split impl DiscreteLogInstance { /// Solves the discrete log problem under the assumption that the solution /// is a 32-bit number. - pub fn decode_u32_alt(self) -> Option { - let hashmap = DiscreteLogInstance::decode_u32_precomputation_alt(self.generator); - self.decode_u32_online_alt(&hashmap) - } - - /// Builds a HashMap of 2^18 elements - pub fn decode_u32_precomputation_alt( - generator: RistrettoPoint, - ) -> HashMap { - let mut hashmap = HashMap::new(); - - let two12_scalar = Scalar::from(TWO14); - let identity = HashableRistretto(RistrettoPoint::identity()); // 0 * G - let generator = HashableRistretto(two12_scalar * generator); // 2^12 * G - - // iterator for 2^12*0G , 2^12*1G, 2^12*2G, ... - let ristretto_iter = RistrettoIterator::new(identity, generator); - ristretto_iter.zip(0..TWO18).for_each(|(elem, x_hi)| { - hashmap.insert(elem, x_hi); - }); - - hashmap + pub fn decode_u32(self) -> Option { + let hashmap = decode_u32_precomputation(self.generator); + self.decode_u32_online(&hashmap) } /// Solves the discrete log instance using the pre-computed HashMap by enumerating through 2^14 /// possible solutions - pub fn decode_u32_online_alt(self, hashmap: &HashMap) -> Option { + pub fn decode_u32_online(self, hashmap: &HashMap<[u8; 32], u32>) -> Option { // iterator for 0G, -1G, -2G, ... - let ristretto_iter = RistrettoIterator::new( - HashableRistretto(self.target), - HashableRistretto(-self.generator), - ); + let ristretto_iter = RistrettoIterator::new(self.target, -self.generator); let mut decoded = None; ristretto_iter.zip(0..TWO14).for_each(|(elem, x_lo)| { - if hashmap.contains_key(&elem) { - let x_hi = hashmap[&elem]; + let key = elem.compress().to_bytes(); + if hashmap.contains_key(&key) { + let x_hi = hashmap[&key]; decoded = Some(x_lo + TWO14 * x_hi); } }); @@ -119,92 +73,34 @@ impl DiscreteLogInstance { } } -/// Type wrapper for RistrettoPoint that implements the Hash trait -#[derive(Serialize, Deserialize, Copy, Clone, Debug, Eq)] -pub struct HashableRistretto(pub RistrettoPoint); - -impl HashableRistretto { - pub fn encode>(amount: T) -> Self { - HashableRistretto(amount.into() * G) - } -} - -impl Hash for HashableRistretto { - fn hash(&self, state: &mut H) { - bincode::serialize(self).unwrap().hash(state); - } -} - -impl PartialEq for HashableRistretto { - fn eq(&self, other: &Self) -> bool { - self == other - } -} - /// HashableRistretto iterator. /// /// Given an initial point X and a stepping point P, the iterator iterates through /// X + 0*P, X + 1*P, X + 2*P, X + 3*P, ... struct RistrettoIterator { - pub curr: HashableRistretto, - pub step: HashableRistretto, + pub curr: RistrettoPoint, + pub step: RistrettoPoint, } impl RistrettoIterator { - fn new(curr: HashableRistretto, step: HashableRistretto) -> Self { + fn new(curr: RistrettoPoint, step: RistrettoPoint) -> Self { RistrettoIterator { curr, step } } } impl Iterator for RistrettoIterator { - type Item = HashableRistretto; + type Item = RistrettoPoint; fn next(&mut self) -> Option { let r = self.curr; - self.curr = self.curr + self.step; + self.curr += self.step; Some(r) } } -impl<'a, 'b> Add<&'b HashableRistretto> for &'a HashableRistretto { - type Output = HashableRistretto; - - fn add(self, other: &HashableRistretto) -> HashableRistretto { - HashableRistretto(self.0 + other.0) - } -} - -define_add_variants!( - LHS = HashableRistretto, - RHS = HashableRistretto, - Output = HashableRistretto -); - -impl<'a, 'b> Sub<&'b HashableRistretto> for &'a HashableRistretto { - type Output = HashableRistretto; - - fn sub(self, other: &HashableRistretto) -> HashableRistretto { - HashableRistretto(self.0 - other.0) - } -} - -define_sub_variants!( - LHS = HashableRistretto, - RHS = HashableRistretto, - Output = HashableRistretto -); - -impl Neg for HashableRistretto { - type Output = HashableRistretto; - - fn neg(self) -> HashableRistretto { - HashableRistretto(-self.0) - } -} - #[cfg(test)] mod tests { - use super::*; + use {super::*, curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT as G}; /// Discrete log test for 16/16 split /// @@ -212,7 +108,6 @@ mod tests { /// - 8 sec for precomputation /// - 3 sec for online computation #[test] - #[ignore] fn test_decode_correctness() { let amount: u32 = 65545; @@ -223,7 +118,7 @@ mod tests { // Very informal measurements for now let start_precomputation = time::precise_time_s(); - let precomputed_hashmap = DiscreteLogInstance::decode_u32_precomputation(G); + let precomputed_hashmap = decode_u32_precomputation(G); let end_precomputation = time::precise_time_s(); let start_online = time::precise_time_s(); @@ -241,42 +136,4 @@ mod tests { end_online - start_online ); } - - /// Discrete log test for 18/14 split - /// - /// Very informal measurements on my machine: - /// - 33 sec for precomputation - /// - 0.8 sec for online computation - #[test] - #[ignore] - fn test_decode_alt_correctness() { - let amount: u32 = 65545; - - let instance = DiscreteLogInstance { - generator: G, - target: Scalar::from(amount) * G, - }; - - // Very informal measurements for now - let start_precomputation = time::precise_time_s(); - let precomputed_hashmap = DiscreteLogInstance::decode_u32_precomputation_alt(G); - let end_precomputation = time::precise_time_s(); - - let start_online = time::precise_time_s(); - let computed_amount = instance - .decode_u32_online_alt(&precomputed_hashmap) - .unwrap(); - let end_online = time::precise_time_s(); - - assert_eq!(amount, computed_amount); - - println!( - "18/14 Split precomputation: {:?} sec", - end_precomputation - start_precomputation - ); - println!( - "18/14 Split online computation: {:?} sec", - end_online - start_online - ); - } } diff --git a/zk-token-sdk/src/encryption/elgamal.rs b/zk-token-sdk/src/encryption/elgamal.rs index b4476aafa3..c476403f6e 100644 --- a/zk-token-sdk/src/encryption/elgamal.rs +++ b/zk-token-sdk/src/encryption/elgamal.rs @@ -2,7 +2,7 @@ use rand::{rngs::OsRng, CryptoRng, RngCore}; use { crate::encryption::{ - encode::DiscreteLogInstance, + dlog::DiscreteLogInstance, pedersen::{Pedersen, PedersenBase, PedersenComm, PedersenDecHandle, PedersenOpen}, }, arrayref::{array_ref, array_refs}, @@ -12,6 +12,7 @@ use { scalar::Scalar, }, serde::{Deserialize, Serialize}, + std::collections::HashMap, std::convert::TryInto, subtle::{Choice, ConstantTimeEq}, zeroize::Zeroize, @@ -100,6 +101,17 @@ impl ElGamal { let discrete_log_instance = ElGamal::decrypt(sk, ct); discrete_log_instance.decode_u32() } + + /// On input a secret key, ciphertext, and hashmap, the function decrypts the + /// ciphertext for a u32 value. + pub fn decrypt_u32_online( + sk: &ElGamalSK, + ct: &ElGamalCiphertext, + hashmap: &HashMap<[u8; 32], u32>, + ) -> Option { + let discrete_log_instance = ElGamal::decrypt(sk, ct); + discrete_log_instance.decode_u32_online(hashmap) + } } /// Public key for the ElGamal encryption scheme. @@ -164,6 +176,15 @@ impl ElGamalSK { ElGamal::decrypt_u32(self, ct) } + /// Utility method for code ergonomics. + pub fn decrypt_u32_online( + &self, + ct: &ElGamalCiphertext, + hashmap: &HashMap<[u8; 32], u32>, + ) -> Option { + ElGamal::decrypt_u32_online(self, ct, hashmap) + } + pub fn to_bytes(&self) -> [u8; 32] { self.0.to_bytes() } @@ -249,6 +270,15 @@ impl ElGamalCiphertext { pub fn decrypt_u32(&self, sk: &ElGamalSK) -> Option { ElGamal::decrypt_u32(sk, self) } + + /// Utility method for code ergonomics. + pub fn decrypt_u32_online( + &self, + sk: &ElGamalSK, + hashmap: &HashMap<[u8; 32], u32>, + ) -> Option { + ElGamal::decrypt_u32_online(sk, self, hashmap) + } } impl<'a, 'b> Add<&'b ElGamalCiphertext> for &'a ElGamalCiphertext { diff --git a/zk-token-sdk/src/encryption/mod.rs b/zk-token-sdk/src/encryption/mod.rs index 72985c4ed4..40e70f9861 100644 --- a/zk-token-sdk/src/encryption/mod.rs +++ b/zk-token-sdk/src/encryption/mod.rs @@ -1,3 +1,3 @@ -pub mod elgamal; pub mod dlog; +pub mod elgamal; pub mod pedersen;