Skip to content

Commit

Permalink
arith
Browse files Browse the repository at this point in the history
  • Loading branch information
eschorn1 committed Jan 1, 2024
1 parent edac79b commit c466bc4
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 142 deletions.
115 changes: 54 additions & 61 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,38 +14,33 @@ env:
RUSTFLAGS: "-Dwarnings"
RUSTDOCFLAGS: "-Dwarnings"


jobs:
#build:
# runs-on: ubuntu-latest
# strategy:
# matrix:
# rust:
# - 1.72.0 # MSRV
# - stable
# target:
# #- thumbv7em-none-eabi
# - wasm32-unknown-unknown
# steps:
# - uses: actions/checkout@v4
# - uses: dtolnay/rust-toolchain@master
# with:
# toolchain: ${{ matrix.rust }}
# targets: ${{ matrix.target }}
# - run: cargo build --target ${{ matrix.target }} --release

#benches:
# runs-on: ubuntu-latest
# strategy:
# matrix:
# rust:
# - 1.72.0 # MSRV
# - stable
# steps:
# - uses: actions/checkout@v4
# - uses: dtolnay/rust-toolchain@master
# with:
# toolchain: ${{ matrix.rust }}
# - run: cargo build --all-features --benches

build:
runs-on: ubuntu-latest
strategy:
matrix:
rust:
- 1.72.0 # MSRV
- stable
target:
- thumbv7em-none-eabi
- wasm32-unknown-unknown
- s390x-unknown-linux-gnu
- powerpc64-unknown-linux-gnu
- riscv64gc-unknown-none-elf
- x86_64-pc-windows-gnu
- x86_64-apple-darwin
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@master
with:
toolchain: ${{ matrix.rust }}
targets: ${{ matrix.target }}
- run: cargo build --target ${{ matrix.target }} --release --no-default-features --features "ml-kem-512 ml-kem-768 ml-kem-1024"


test:
runs-on: ubuntu-latest
Expand Down Expand Up @@ -76,38 +71,36 @@ jobs:
- run: cargo check --target ${{ matrix.target }} --all-features
- run: cargo test --release --target ${{ matrix.target }}

#cross:
# strategy:
# matrix:
# include:
# # ARM32
# - target: armv7-unknown-linux-gnueabihf
# rust: 1.72.0 # MSRV (cross)
# - target: armv7-unknown-linux-gnueabihf
# rust: stable

# # ARM64
# - target: aarch64-unknown-linux-gnu
# rust: 1.72.0 # MSRV (cross)
# - target: aarch64-unknown-linux-gnu
# rust: stable

# # PPC32
# - target: powerpc-unknown-linux-gnu
# rust: 1.72.0 # MSRV (cross)
# - target: powerpc-unknown-linux-gnu
# rust: stable

# runs-on: ubuntu-latest
# steps:
# - uses: actions/checkout@v4
# - run: ${{ matrix.deps }}
# - uses: dtolnay/rust-toolchain@master
# with:
# toolchain: ${{ matrix.rust }}
# targets: ${{ matrix.target }}
# - uses: RustCrypto/actions/cross-install@master
# - run: cross test --release --target ${{ matrix.target }} --all-features
cross:
strategy:
matrix:
include:
# ARM32
- target: armv7-unknown-linux-gnueabihf
rust: 1.72.0 # MSRV (cross)
- target: armv7-unknown-linux-gnueabihf
rust: stable
# ARM64
- target: aarch64-unknown-linux-gnu
rust: 1.72.0 # MSRV (cross)
- target: aarch64-unknown-linux-gnu
rust: stable
# PPC32
- target: powerpc-unknown-linux-gnu
rust: 1.72.0 # MSRV (cross)
- target: powerpc-unknown-linux-gnu
rust: stable
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- run: ${{ matrix.deps }}
- uses: dtolnay/rust-toolchain@master
with:
toolchain: ${{ matrix.rust }}
targets: ${{ matrix.target }}
- uses: RustCrypto/actions/cross-install@master
- run: cross test --release --target ${{ matrix.target }} --no-default-features --features "ml-kem-512 ml-kem-768 ml-kem-1024"

doc:
runs-on: ubuntu-latest
Expand Down
22 changes: 10 additions & 12 deletions benches/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,19 @@ criterion_main!(benches);

/*
Initial conditions
cargo bench
cargo bench # As of 1-1-24
Intel® Core™ i7-7700K CPU @ 4.20GHz × 8
ml_kem_512 KeyGen time: [64.996 µs 65.084 µs 65.246 µs]
ml_kem_768 KeyGen time: [102.50 µs 102.54 µs 102.62 µs]
ml_kem_1024 KeyGen time: [148.28 µs 148.33 µs 148.38 µs]
ml_kem_512 KeyGen time: [63.821 µs 63.830 µs 63.839 µs]
ml_kem_768 KeyGen time: [100.88 µs 100.89 µs 100.90 µs]
ml_kem_1024 KeyGen time: [146.53 µs 146.61 µs 146.70 µs]
ml_kem_512 Encaps time: [77.391 µs 77.424 µs 77.489 µs]
ml_kem_768 Encaps time: [117.76 µs 117.83 µs 117.90 µs]
ml_kem_1024 Encaps time: [167.77 µs 167.79 µs 167.82 µs]
ml_kem_512 Encaps time: [76.934 µs 76.948 µs 76.961 µs]
ml_kem_768 Encaps time: [117.93 µs 118.01 µs 118.08 µs]
ml_kem_1024 Encaps time: [168.68 µs 168.76 µs 168.85 µs]
ml_kem_512 Decaps time: [75.627 µs 75.671 µs 75.745 µs]
ml_kem_768 Decaps time: [115.20 µs 115.24 µs 115.27 µs]
ml_kem_1024 Decaps time: [164.48 µs 164.56 µs 164.67 µs]
ml_kem_512 Decaps time: [76.749 µs 76.887 µs 77.071 µs]
ml_kem_768 Decaps time: [117.05 µs 117.34 µs 117.84 µs]
ml_kem_1024 Decaps time: [167.51 µs 167.53 µs 167.57 µs]
*/
49 changes: 8 additions & 41 deletions src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub(crate) fn vec_add<const K: usize>(
let mut result = [[Z256(0); 256]; K];
for i in 0..vec_a.len() {
for j in 0..vec_a[i].len() {
result[i][j].set_u16(vec_a[i][j].get_u32() + vec_b[i][j].get_u32());
result[i][j] = vec_a[i][j].add(vec_b[i][j]); //.set_u16(vec_a[i][j].get_u32() + vec_b[i][j].get_u32());
}
}
result
Expand All @@ -33,7 +33,7 @@ pub(crate) fn mat_vec_mul<const K: usize>(
for j in 0..K {
let tmp = multiply_ntts(&a_hat[i][j], &u_hat[j]);
for k in 0..256 {
w_hat[i][k].set_u16(w_hat[i][k].get_u32() + tmp[k].get_u32());
w_hat[i][k] = w_hat[i][k].add(tmp[k]); //.set_u16(w_hat[i][k].get_u32() + tmp[k].get_u32());
}
}
}
Expand All @@ -53,7 +53,7 @@ pub(crate) fn mat_t_vec_mul<const K: usize>(
for j in 0..K {
let tmp = multiply_ntts(&a_hat[j][i], &u_hat[j]);
for k in 0..256 {
y_hat[i][k].set_u16(y_hat[i][k].get_u32() + tmp[k].get_u32());
y_hat[i][k] = y_hat[i][k].add(tmp[k]); //.set_u16(y_hat[i][k].get_u32() + tmp[k].get_u32());
}
}
}
Expand All @@ -70,7 +70,7 @@ pub(crate) fn dot_t_prod<const K: usize>(
for j in 0..K {
let tmp = multiply_ntts(&u_hat[j], &v_hat[j]);
for k in 0..256 {
result[k].set_u16(result[k].get_u32() + tmp[k].get_u32());
result[k] = result[k].add(tmp[k]); //.set_u16(result[k].get_u32() + tmp[k].get_u32());
}
}
result
Expand Down Expand Up @@ -136,43 +136,10 @@ pub(crate) fn j(bytes: &[u8]) -> [u8; 32] {
}


// REMOVED DUE TO ZETA_TABLE IN ntt.rs
// /// BitRev7(i) from page 21 line 839-840.
// /// Returns the integer represented by bit-reversing the unsigned 7-bit value that
// /// corresponds to the input integer i ∈ {0, . . . , 127}. (horrible perf)
// #[must_use]
// pub(crate) const fn bit_rev_7(a: u8) -> u8 {
// ((a >> 6) & 1)
// | ((a >> 4) & 2)
// | ((a >> 2) & 4)
// | (a & 8)
// | ((a << 2) & 16)
// | ((a << 4) & 32)
// | ((a << 6) & 64)
// }

// REMOVED DUE TO ZETA_TABLE IN ntt.rs
// /// HAC Algorithm 14.76 Right-to-left binary exponentiation mod Q.
// #[must_use]
// #[allow(dead_code)]
// pub(crate) fn pow_mod_q(g: u32, e: u8) -> u32 {
// let mut result = 1;
// let mut s = g;
// let mut e = e;
// while e != 0 {
// if e & 1 != 0 {
// result = (result * s) % Q;
// };
// e >>= 1;
// if e != 0 {
// s = (s * s) % Q;
// };
// }
// result
// }


/// Round to nearest
// BitRev7(i) from page 21 line 839-840 -- REMOVED DUE TO ZETA_TABLE IN ntt.rs


/// Round to nearest. TODO: refine/optimize
fn nearest(numerator: u32, denominator: u32) -> u16 {
let remainder = numerator % denominator;
let quotient = u16::try_from(numerator / denominator).unwrap();
Expand Down
3 changes: 1 addition & 2 deletions src/k_pke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use crate::helpers::{
compress, decompress, dot_t_prod, g, mat_t_vec_mul, mat_vec_mul, prf, vec_add, xof,
};
use crate::ntt::{ntt, ntt_inv};
use crate::Q;
use crate::sampling::{sample_ntt, sample_poly_cbd};
use crate::types::Z256;

Expand Down Expand Up @@ -297,7 +296,7 @@ pub(crate) fn k_pke_decrypt<
for _i in 0..K {
let yy = ntt_inv(&st_ntt_u);
for i in 0..256 {
w[i].set_u16((Q + v[i].get_u32() - yy[i].get_u32()) % Q);
w[i] = v[i].sub(yy[i]);
}
}

Expand Down
23 changes: 10 additions & 13 deletions src/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub fn ntt(array_f: &[Z256; 256]) -> [Z256; 256] {
for start in (0..256).step_by(2 * len) {
//
// 5: zeta ← ζ^{BitRev7 (k)} mod q
let zeta = ZETA_TABLE[k << 1];
let zeta = Z256(ZETA_TABLE[k << 1] as u16);


// 6: k ← k+1
Expand All @@ -32,13 +32,13 @@ pub fn ntt(array_f: &[Z256; 256]) -> [Z256; 256] {
for j in start..(start + len) {
//
// 8: t ← zeta · f_hat[ j + len] ▷ steps 8-10 done modulo q
let t = zeta * f_hat[j + len].get_u32() % Q;
let t = f_hat[j + len].mul(zeta);

// 9: f_hat[ j + len] ← f_hat [ j] − t
f_hat[j + len].set_u16((Q + f_hat[j].get_u32() - t) % Q);
f_hat[j + len] = f_hat[j].sub(t);

// 10: f_hat[ j] ← f_hat[ j] + t
f_hat[j].set_u16((f_hat[j].get_u32() + t) % Q);
f_hat[j] = f_hat[j].add(t);
//
} // 11: end for
} // 12: end for
Expand Down Expand Up @@ -70,7 +70,7 @@ pub fn ntt_inv(f_hat: &[Z256; 256]) -> [Z256; 256] {
for start in (0..256).step_by(2 * len) {
//
// 5: zeta ← ζ^{BitRev7(k)} mod q
let zeta = ZETA_TABLE[k << 1];
let zeta = Z256(ZETA_TABLE[k << 1] as u16);

// 6: k ← k − 1
k -= 1;
Expand All @@ -82,18 +82,18 @@ pub fn ntt_inv(f_hat: &[Z256; 256]) -> [Z256; 256] {
let t = f[j];

// 9: f [ j] ← t + f [ j + len] ▷ steps 9-10 done modulo q
f[j].set_u16((t.get_u32() + f[j + len].get_u32()) % Q);
f[j] = t.add(f[j + len]);

// 10: f [ j + len] ← zeta · ( f [ j + len] − t)
f[j + len].set_u16((zeta * (Q + f[j + len].get_u32() - t.get_u32())) % Q);
f[j + len] = zeta.mul(f[j + len].sub(t));
//
} // 11: end for
} // 12: end for
} // 13: end for

// 14: f ← f · 3303 mod q ▷ multiply every entry by 3303 ≡ 128^{−1} mod q
f.iter_mut()
.for_each(|item| item.set_u16(item.get_u32() * 3303 % Q));
.for_each(|item| *item = item.mul(Z256(3303)));

// 15: return f
f
Expand Down Expand Up @@ -136,13 +136,10 @@ pub fn base_case_multiply(a0: Z256, a1: Z256, b0: Z256, b1: Z256, gamma: Z256) -
// Input: γ ∈ Z_q ▷ the modulus is X^2 − γ
// Output: c0 , c1 ∈ Z_q ▷ the coefficients of the product of the two polynomials
// 1: c0 ← a0 · b0 + a1 · b1 · γ ▷ steps 1-2 done modulo q
let c0 = Z256(
((a0.get_u32() * b0.get_u32() + (a1.get_u32() * b1.get_u32() % Q) * gamma.get_u32()) % Q)
as u16,
);
let c0 = a0.mul(b0).add(a1.mul(b1).mul(gamma));

// 2: c1 ← a0 · b1 + a1 · b0
let c1 = Z256(((a0.get_u32() * b1.get_u32() + a1.get_u32() * b0.get_u32()) % Q) as u16);
let c1 = a0.mul(b1).add(a1.mul(b0));

// 3: return c0 , c1
(c0, c1)
Expand Down
9 changes: 5 additions & 4 deletions src/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub fn sample_ntt(mut byte_stream_b: impl XofReader) -> [Z256; 256] {
byte_stream_b.read(&mut bbb); // Draw 3 bytes

// 4: d1 ← B[i] + 256 · (B[i + 1] mod 16)
let d1 = u32::from(bbb[0]) + 256 * (u32::from(bbb[1]) % 16);
let d1 = u32::from(bbb[0]) + 256 * (u32::from(bbb[1]) & 0x0F);

// 5: d2 ← ⌊B[i + 1]/16⌋ + 16 · B[i + 2]
let d2 = u32::from(bbb[1]) / 16 + 16 * u32::from(bbb[2]);
Expand All @@ -33,7 +33,7 @@ pub fn sample_ntt(mut byte_stream_b: impl XofReader) -> [Z256; 256] {
if d1 < Q {
//
// 7: a_hat[j] ← d1 ▷ a_hat ∈ Z256
array_a_hat[j].set_u16(d1);
array_a_hat[j] = Z256(d1 as u16); //.set_u16(d1);

// 8: j ← j+1
j += 1;
Expand All @@ -44,7 +44,7 @@ pub fn sample_ntt(mut byte_stream_b: impl XofReader) -> [Z256; 256] {
if (d2 < Q) & (j < 256) {
//
// 11: a_hat[j] ← d2
array_a_hat[j].set_u16(d2);
array_a_hat[j] = Z256(d2 as u16); //.set_u16(d2);

// 12: j ← j+1
j += 1;
Expand Down Expand Up @@ -83,7 +83,8 @@ pub fn sample_poly_cbd<const ETA: usize, const ETA_512: usize>(byte_array_b: &[u
let y = (0..ETA).fold(0, |acc: u32, j| acc + u32::from(bit_array[2 * i * ETA + ETA + j]));

// 5: f [i] ← x − y mod q
array_f[i].set_u16((Q + x - y) % Q);
array_f[i] = Z256(x as u16).sub(Z256(y as u16));

//
} // 6: end for

Expand Down
Loading

0 comments on commit c466bc4

Please sign in to comment.