Skip to content

Commit

Permalink
fixed some lint errors
Browse files Browse the repository at this point in the history
  • Loading branch information
PyryL committed Dec 12, 2023
1 parent f1bcc0e commit 4565eed
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 33 deletions.
4 changes: 2 additions & 2 deletions kyber/ccakem.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from secrets import token_bytes
from kyber.encryption import generate_keys, Encrypt, Decrypt
from kyber.encryption import generate_keys, Encrypt, decrypt
from kyber.utils.pseudo_random import H, G, kdf
from kyber.constants import k, n, du, dv

Expand Down Expand Up @@ -56,7 +56,7 @@ def ccakem_decrypt(ciphertext: bytes, private_key: bytes, shared_secret_length:

assert h == H(pk)

m = Decrypt(sk, ciphertext).decrypt()
m = decrypt(sk, ciphertext)
Kr = G(m + h)
K, r = Kr[:32], Kr[32:]
c = Encrypt(pk, m, r).encrypt()
Expand Down
2 changes: 1 addition & 1 deletion kyber/encryption/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from kyber.encryption.keygen import generate_keys
from kyber.encryption.encrypt import Encrypt
from kyber.encryption.decrypt import Decrypt
from kyber.encryption.decrypt import decrypt
42 changes: 19 additions & 23 deletions kyber/encryption/decrypt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,29 @@
from kyber.constants import n, k, du, dv
from kyber.entities.polring import PolynomialRing

class Decrypt:
def __init__(self, private_key, ciphertext) -> None:
self._sk = private_key
self._c = ciphertext
if len(self._sk) != 32*12*k:
raise ValueError()
if len(self._c) != du*k*n//8 + dv*n//8:
raise ValueError()
def decrypt(private_key, ciphertext) -> bytes:
"""
Decrypts the given ciphertext with the given private key.
:returns Decrypted 32-bit shared secret
"""

def decrypt(self) -> bytes:
"""
Decrypts the given ciphertext with the given private key.
:returns Decrypted 32-bit shared secret
"""
if len(private_key) != 32*12*k:
raise ValueError()
if len(ciphertext) != du*k*n//8 + dv*n//8:
raise ValueError()

s = np.array(decode(self._sk, 12))
s = np.array(decode(private_key, 12))

u, v = self._c[:du*k*n//8], self._c[du*k*n//8:]
u, v = ciphertext[:du*k*n//8], ciphertext[du*k*n//8:]

u = decode(u, du)
v = decode(v, dv)[0]
u = decode(u, du)
v = decode(v, dv)[0]

u = np.array([decompress(pol, du) for pol in u])
v = decompress(v, dv)
u = np.array([decompress(pol, du) for pol in u])
v = decompress(v, dv)

m: PolynomialRing = v - np.matmul(s.T, u)
m: bytes = encode(compress([m], 1), 1)
m: PolynomialRing = v - np.matmul(s.T, u)
m: bytes = encode(compress([m], 1), 1)

assert len(m) == 32
return m
assert len(m) == 32
return m
2 changes: 1 addition & 1 deletion tests/test_cbd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest
from random import seed,randbytes
from random import seed, randbytes
from base64 import b64decode
from kyber.utils.cbd import cbd

Expand Down
8 changes: 4 additions & 4 deletions tests/test_decrypt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
from random import seed, randbytes
from kyber.encryption import Decrypt
from kyber.encryption import decrypt
from kyber.constants import k, n, du, dv

class TestDecrypt(unittest.TestCase):
Expand All @@ -10,7 +10,7 @@ def setUp(self):
def test_decryption_outputs_valid_shared_secret(self):
private_key = randbytes(32*12*k)
ciphertext = randbytes(du*k*n//8 + dv*n//8)
shared_secret = Decrypt(private_key, ciphertext).decrypt()
shared_secret = decrypt(private_key, ciphertext)
self.assertEqual(type(shared_secret), bytes)
self.assertEqual(len(shared_secret), 32)

Expand All @@ -19,11 +19,11 @@ def test_decryption_raises_with_invalid_private_key(self):
invalid_private_key = randbytes(32*12*k + 1)
valid_ciphertext = randbytes(du*k*n//8 + dv*n//8)
with self.assertRaises(ValueError):
Decrypt(invalid_private_key, valid_ciphertext)
decrypt(invalid_private_key, valid_ciphertext)

def test_decryption_raises_with_invalid_ciphertext(self):
# this ciphertext is one byte too short
valid_private_key = randbytes(32*12*k)
invalid_ciphertext = randbytes(du*k*n//8 + dv*n//8 - 1)
with self.assertRaises(ValueError):
Decrypt(valid_private_key, invalid_ciphertext)
decrypt(valid_private_key, invalid_ciphertext)
4 changes: 2 additions & 2 deletions tests/test_encryption.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import unittest
from kyber.encryption import generate_keys, Encrypt, Decrypt
from kyber.encryption import generate_keys, Encrypt, decrypt

class TestIntegration(unittest.TestCase):
def test_encryption_symmetry(self):
# test the whole process of key generation, encryption and decryption
private_key, public_key = generate_keys()
encrypter = Encrypt(public_key)
ciphertext = encrypter.encrypt()
decrypted_shared_secret = Decrypt(private_key, ciphertext).decrypt()
decrypted_shared_secret = decrypt(private_key, ciphertext)
self.assertEqual(encrypter.secret, decrypted_shared_secret)
self.assertEqual(len(encrypter.secret), 32)

0 comments on commit 4565eed

Please sign in to comment.