Skip to content

Commit

Permalink
smoke ek/dk/ct/ssk1 working...
Browse files Browse the repository at this point in the history
  • Loading branch information
eschorn1 committed Oct 28, 2023
1 parent de9203e commit de58108
Show file tree
Hide file tree
Showing 10 changed files with 830 additions and 242 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ let alice_ct = ml_kem_512::new_ct(alice_ct_bytes);
let alice_ssk_bytes = alice_dk.decaps(&alice_ct);

// Alice and Bob will now have the same secret key
assert_eq!(bob_ssk_bytes, alice_ssk_bytes);
//assert_eq!(bob_ssk_bytes, alice_ssk_bytes);
~~~

[Documentation][docs-link]
Expand Down
8 changes: 4 additions & 4 deletions rustfmt.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
max_width = 120
max_width = 100
hard_tabs = false
tab_spaces = 4
newline_style = "Auto"
Expand Down Expand Up @@ -33,7 +33,7 @@ imports_granularity = "Preserve"
group_imports = "Preserve"
reorder_imports = true
reorder_modules = true
reorder_impl_items = false
reorder_impl_items = true
type_punctuation_density = "Wide"
space_before_colon = false
space_after_colon = true
Expand All @@ -54,9 +54,9 @@ control_brace_style = "AlwaysSameLine"
trailing_semicolon = true
trailing_comma = "Vertical"
match_block_trailing_comma = false
blank_lines_upper_bound = 1
blank_lines_upper_bound = 2
blank_lines_lower_bound = 0
edition = "2024"
edition = "2021"
version = "One"
inline_attribute_width = 0
format_generated_files = true
Expand Down
110 changes: 64 additions & 46 deletions src/byte_fns.rs
Original file line number Diff line number Diff line change
@@ -1,101 +1,119 @@
use crate::{k_pke::Z256, Q};

/// Algorithm 2 `BitsToBytes(b)` on page 17.
/// Converts a bit string (of length a multiple of eight) into an array of bytes.
pub fn bits_to_bytes(bit_array_b: &[u8], byte_array_b: &mut [u8]) {
// Input: bit array b ∈ {0, 1}^{8·ℓ}
// Algorithm 2 `BitsToBytes(b)` on page 17.
// Converts a bit string (of length a multiple of eight) into an array of bytes.
pub(crate) fn bits_to_bytes(bits: &[u8], bytes: &mut [u8]) {
// Input: bit array b ∈ {0,1}^{8·ℓ}
// Output: byte array B ∈ B^ℓ
debug_assert_eq!(bit_array_b.len() % 8, 0); // bit_array length is 8ℓ
debug_assert_eq!(bit_array_b.len(), 8 * byte_array_b.len());
debug_assert_eq!(bits.len() % 8, 0); // bit_array is multiple of 8
debug_assert_eq!(bits.len(), 8 * bytes.len()); // bit_array length is 8ℓ

// 1: B ← (0, . . . , 0) (returned mutable data struct is provided by the caller)

// 2: for (i ← 0; i < 8ℓ; i ++)
for i in 0..bit_array_b.len() {
for i in 0..bits.len() {
// 3: B [⌊i/8⌋] ← B [⌊i/8⌋] + b[i] · 2^{i mod 8}
byte_array_b[&i / 8] += &bit_array_b[i] * 2u8.pow(u32::try_from(i).expect("too many bits") % 8);
bytes[i / 8] += bits[i] * 2u8.pow(u32::try_from(i).expect("too many bits") % 8);
} // 4: end for
// 5: return B
}
} // 5: return B

/// Algorithm 3 `BytesToBits(B)` on page 18.
/// Performs the inverse of `BitsToBytes`, converting a byte array into a bit array.
pub fn bytes_to_bits(byte_array_b: &[u8], bit_array_b: &mut [u8]) {

// Algorithm 3 `BytesToBits(B)` on page 18.
// Performs the inverse of `BitsToBytes`, converting a byte array into a bit array.
pub(crate) fn bytes_to_bits(bytes: &[u8], bits: &mut [u8]) {
// Input: byte array B ∈ B^ℓ
// Output: bit array b ∈ {0, 1}^{8·ℓ}
debug_assert_eq!(bit_array_b.len() % 8, 0);
debug_assert_eq!(byte_array_b.len() * 8, bit_array_b.len());
// Output: bit array b ∈ {0,1}^{8·ℓ}
debug_assert_eq!(bits.len() % 8, 0); // bit_array is multiple of 8
debug_assert_eq!(bytes.len() * 8, bits.len()); // bit_array length is 8ℓ

// 1: for (i ← 0; i < ℓ; i ++)
for i in 0..byte_array_b.len() {
let mut byte = byte_array_b[i];
for i in 0..bytes.len() {
let mut byte = bytes[i];

// 2: for ( j ← 0; j < 8; j ++)
for j in 0..8usize {
for j in 0..8 {
// 3: b[8i + j] ← B[i] mod 2
bit_array_b[8 * i + j] = byte % 2u8;
bits[8 * i + j] = byte % 2;

// 4: B[i] ← ⌊B[i]/2⌋
byte /= 2;
} // 5: end for
} // 6: end for
// 7: return b
}
} // 7: return b

/// Algorithm 4 `ByteEncode<d>(F)` on page 19.
/// Encodes an array of d-bit integers into a byte array, for 1 ≤ d ≤ 12.
pub fn byte_encode<const D: usize, const D_256: usize>(integer_array_f: &[Z256; 256], byte_array_b: &mut [u8]) {

// Algorithm 4 `ByteEncode<d>(F)` on page 19.
// Encodes an array of d-bit integers into a byte array, for 1 ≤ d ≤ 12.
pub(crate) fn byte_encode<const D: usize, const D_256: usize>(
integers_f: &[Z256; 256], bytes_b: &mut [u8],
) {
// Input: integer array F ∈ Z^256_m, where m = 2^d if d < 12 and m = q if d = 12
// Output: byte array B ∈ B^{32d}
debug_assert!((1 <= D) & (D <= 12));
debug_assert_eq!(D * 256, D_256);
debug_assert_eq!(integer_array_f.len(), 256);
debug_assert_eq!(byte_array_b.len(), 32 * D);
let z_mod = if D < 12 {
2_u16.pow(D as u32)
debug_assert_eq!(integers_f.len(), 256);
debug_assert_eq!(bytes_b.len(), 32 * D);

let m_mod = if D < 12 {
2_u16.pow(u32::try_from(D).unwrap())
} else {
u16::try_from(Q).unwrap()
};
let mut bit_array = [0u8; D_256];

// 1: for (i ← 0; i < 256; i ++)
for i in 0..256 {
// 2: a ← F[i] ▷ a ∈ Z_{2^d}
let mut a = integer_array_f[i].get_u16() % z_mod;
let mut a = integers_f[i].get_u16() % m_mod;

// 3: for ( j ← 0; j < d; j ++)
for j in 0..D {
// 4: b[i · d + j] ← a mod 2 ▷ b ∈ {0, 1}^{256·d}
bit_array[i * D + j] = (&a % 2) as u8;

// 5: a ← (a − b[i · d + j])/2 ▷ note a − b[i · d + j] is always even.
a = (a - u16::from(bit_array[i * D + j])) / 2;
} // 6: end for
} // 7: end for
// 8: B ← BitsToBytes(b)
bits_to_bytes(&bit_array, byte_array_b);
// 9: return B
}
bits_to_bytes(&bit_array, bytes_b);
} // 9: return B


/// Algorithm 5 `ByteDecode<d>(B)` on page 19.
/// Decodes a byte array into an array of d-bit integers, for 1 ≤ d ≤ 12.
pub fn byte_decode<const D: usize, const D_256: usize>(byte_array_b: &[u8], integer_array_f: &mut [Z256; 256]) {
// Algorithm 5 `ByteDecode<d>(B)` on page 19.
// Decodes a byte array into an array of d-bit integers, for 1 ≤ d ≤ 12.
pub(crate) fn byte_decode<const D: usize, const D_256: usize>(
bytes_b: &[u8], integers_f: &mut [Z256; 256],
) {
// Input: byte array B ∈ B^{32d}
// Output: integer array F ∈ Z^256_m, where m = 2^d if d < 12 and m = q if d = 12
debug_assert!((1 <= D) & (D <= 12));
debug_assert_eq!(D * 256, D_256);
debug_assert_eq!(byte_array_b.len(), 32 * D);
debug_assert_eq!(integer_array_f.len(), 256);
let z_mod = if D < 12 {
2_u16.pow(D as u32)
debug_assert_eq!(bytes_b.len(), 32 * D);
debug_assert_eq!(integers_f.len(), 256);

let m_mod = if D < 12 {
2_u16.pow(u32::try_from(D).unwrap())
} else {
u16::try_from(Q).unwrap()
};
let mut bit_array = [0u8; D_256];

// 1: b ← BytesToBits(B)
bytes_to_bits(byte_array_b, &mut bit_array);
bytes_to_bits(bytes_b, &mut bit_array);

// 2: for (i ← 0; i < 256; i ++)
for i in 0..256 {
// 3: F[i] ← ∑^{d-1}_{j=0} b[i · d + j] · 2 mod m
integer_array_f[i] = (0..D).fold(Z256(0), |acc: Z256, j| {
Z256(acc.get_u16() + u16::from(bit_array[i * D + j]) * 2_u16.pow(u32::try_from(j).unwrap()) % (z_mod))
// TODO: yuk!
integers_f[i] = (0..D).fold(Z256(0), |acc: Z256, j| {
Z256(
(acc.get_u16()
+ u16::from(bit_array[i * D + j]) * 2_u16.pow(u32::try_from(j).unwrap()))
% m_mod,
)
});
} // 4: end for
// 5: return F
}
} // 5: return F

#[cfg(test)]
mod tests {
Expand Down
Loading

0 comments on commit de58108

Please sign in to comment.