-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathchacha20.py
78 lines (67 loc) · 2.59 KB
/
chacha20.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import struct
import binascii
from typing import List
# Unsigned 32-bit word with arithmetic modulo 2^32 (a ring, not a finite field)
class F2_32:
    """Unsigned 32-bit word with the operators needed by the round function.

    ``+`` adds modulo 2**32, ``^`` is bitwise XOR, and ``<<`` is a left
    *rotation* within 32 bits (not a plain shift).  The wrapped value
    ``val`` is always kept in ``range(2**32)``.
    """

    _MASK = 0xffffffff

    def __init__(self, val: int):
        """Wrap ``val``, reduced modulo 2**32.

        Raises:
            TypeError: if ``val`` is not an ``int``.
        """
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if not isinstance(val, int):
            raise TypeError(
                "F2_32 expects an int, got {}".format(type(val).__name__)
            )
        # Reduce on construction so int()/repr() never expose an
        # out-of-range value (e.g. a counter that ran past 2**32 - 1).
        self.val = val & self._MASK

    def __add__(self, other):
        """Return the sum of both words modulo 2**32."""
        return F2_32(self.val + other.val)

    def __xor__(self, other):
        """Return the bitwise XOR of both words."""
        return F2_32(self.val ^ other.val)

    def __lshift__(self, nbit: int):
        """Return this word rotated left by ``nbit`` bits (taken modulo 32)."""
        shift = nbit % 32
        # When shift == 0 the right-hand term is val >> 32 == 0, so the word
        # is returned unchanged.  Bits pushed past bit 31 by the left shift
        # are discarded by the reduction in __init__.
        return F2_32((self.val << shift) | (self.val >> (32 - shift)))

    def __repr__(self):
        """Return the value formatted by ``hex``."""
        return hex(self.val)

    def __int__(self):
        """Return the value as a plain ``int``."""
        return int(self.val)
def quarter_round(a: F2_32, b: F2_32, c: F2_32, d: F2_32):
    """Mix four words with one quarter round and return the new ``(a, b, c, d)``.

    The inputs are not modified; every step builds a new word.  Each step is
    an addition into one word followed by an XOR into another, whose result
    is then passed through ``<<`` with the amounts 16, 12, 8 and 7 in turn.
    """
    a = a + b
    d = (d ^ a) << 16

    c = c + d
    b = (b ^ c) << 12

    a = a + b
    d = (d ^ a) << 8

    c = c + d
    b = (b ^ c) << 7

    return a, b, c, d
def Qround(state: List[F2_32], idx1, idx2, idx3, idx4):
    """Run ``quarter_round`` on four entries of ``state`` and store the results.

    The four words are read first, then written back to the same positions
    in the order ``idx1`` .. ``idx4``.  ``state`` is updated in place and
    nothing is returned.
    """
    positions = (idx1, idx2, idx3, idx4)
    words = [state[pos] for pos in positions]
    for pos, mixed in zip(positions, quarter_round(*words)):
        state[pos] = mixed
def inner_block(state: List[F2_32]):
    """Apply eight quarter rounds to the 16-word ``state`` and return it.

    Viewing ``state`` as a 4x4 matrix stored row by row, the first four
    quarter rounds work on the columns and the next four on the wrapped
    diagonals.  ``state`` is modified in place; the same list is returned.
    """
    # Columns: (0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15)
    for col in range(4):
        Qround(state, col, col + 4, col + 8, col + 12)

    # Diagonals: (0, 5, 10, 15), (1, 6, 11, 12), (2, 7, 8, 13), (3, 4, 9, 14)
    for col in range(4):
        Qround(
            state,
            col,
            4 + (col + 1) % 4,
            8 + (col + 2) % 4,
            12 + (col + 3) % 4,
        )

    return state
def serialize(state: List[F2_32]) -> bytes:
    """Concatenate the words of ``state`` as little-endian unsigned 32-bit values.

    Args:
        state: words to encode; each one is converted with ``int()``.

    Returns:
        A single ``bytes`` object holding 4 bytes per word, in list order.

    Raises:
        struct.error: if a converted word does not fit the ``'<I'`` format.
    """
    # A generator is enough for join(); no intermediate list is needed.
    return b''.join(struct.pack('<I', int(word)) for word in state)
def chacha20_block(key: bytes, counter: int, nonce: bytes) -> bytes:
    """Compute one 64-byte key-stream block.

    Args:
        key: exactly 32 bytes, read as eight little-endian 32-bit words.
        counter: block counter, placed in a single state word.
        nonce: exactly 12 bytes, read as three little-endian 32-bit words.

    Returns:
        The serialized 16-word state after ten passes of ``inner_block``
        and a word-wise addition of the initial state.

    Raises:
        struct.error: if ``key`` or ``nonce`` has the wrong length.
    """
    def as_words(fmt, raw):
        # Decode little-endian 32-bit integers and wrap each one.
        return [F2_32(value) for value in struct.unpack(fmt, raw)]

    # State layout: 4 constant words | 8 key words | 1 counter word | 3 nonce words
    initial_state = (
        as_words('<IIII', b'expand 32-byte k')
        + as_words('<IIIIIIII', key)
        + [F2_32(counter)]
        + as_words('<III', nonce)
    )

    # Work on a copy: inner_block updates its list in place, and the initial
    # words are needed again for the final addition.
    working = list(initial_state)
    for _ in range(10):
        working = inner_block(working)

    mixed = [word + start for word, start in zip(working, initial_state)]
    return serialize(mixed)
def xor(x: bytes, y: bytes):
    """Return the byte-wise XOR of ``x`` and ``y`` as ``bytes``.

    The result is as long as the shorter input; surplus bytes of the longer
    one are ignored.
    """
    mixed = bytearray()
    for left, right in zip(x, y):
        mixed.append(left ^ right)
    return bytes(mixed)
def chacha20_encrypt(key: bytes, counter: int, nonce: bytes, plaintext: bytes):
    """XOR ``plaintext`` with the key stream and return the result.

    The input is processed in 64-byte chunks; chunk number ``i`` uses the
    block produced by ``chacha20_block(key, counter + i, nonce)``.  A final
    chunk shorter than 64 bytes uses only the leading bytes of its block.

    Args:
        key: key passed through to ``chacha20_block``.
        counter: counter value used for the first chunk.
        nonce: nonce passed through to ``chacha20_block``.
        plaintext: data to process; may be empty.

    Returns:
        A ``bytearray`` of the same length as ``plaintext``.

    Note:
        ``counter + i`` is not range-checked here.
    """
    encrypted_message = bytearray()
    for block_no, start in enumerate(range(0, len(plaintext), 64)):
        key_stream = chacha20_block(key, counter + block_no, nonce)
        chunk = plaintext[start:start + 64]
        encrypted_message += xor(chunk, key_stream)
    return encrypted_message