Raphaël G. Git Repositories - youtubedl/blobdiff - youtube_dl/aes.py
Update upstream source from tag 'upstream/2020.06.16'
[youtubedl] / youtube_dl / aes.py
index e9c5e21521d66baa177986e8ca878e3fc1a75461..461bb6d413a91bde8f408667d838088c5f8e11be 100644 (file)
@@ -1,16 +1,17 @@
-__all__ = ['aes_encrypt', 'key_expansion', 'aes_ctr_decrypt', 'aes_cbc_decrypt', 'aes_decrypt_text']
+from __future__ import unicode_literals
 
-import base64
 from math import ceil
 
+from .compat import compat_b64decode
 from .utils import bytes_to_intlist, intlist_to_bytes
 
 BLOCK_SIZE_BYTES = 16
 
+
 def aes_ctr_decrypt(data, key, counter):
     """
     Decrypt with aes in counter mode
-    
+
     @param {int[]} data        cipher
     @param {int[]} key         16/24/32-Byte cipher key
     @param {instance} counter  Instance whose next_value function (@returns {int[]}  16-Byte block)
@@ -19,23 +20,24 @@ def aes_ctr_decrypt(data, key, counter):
     """
     expanded_key = key_expansion(key)
     block_count = int(ceil(float(len(data)) / BLOCK_SIZE_BYTES))
-    
-    decrypted_data=[]
+
+    decrypted_data = []
     for i in range(block_count):
         counter_block = counter.next_value()
-        block = data[i*BLOCK_SIZE_BYTES : (i+1)*BLOCK_SIZE_BYTES]
-        block += [0]*(BLOCK_SIZE_BYTES - len(block))
-        
+        block = data[i * BLOCK_SIZE_BYTES: (i + 1) * BLOCK_SIZE_BYTES]
+        block += [0] * (BLOCK_SIZE_BYTES - len(block))
+
         cipher_counter_block = aes_encrypt(counter_block, expanded_key)
         decrypted_data += xor(block, cipher_counter_block)
     decrypted_data = decrypted_data[:len(data)]
-    
+
     return decrypted_data
 
+
 def aes_cbc_decrypt(data, key, iv):
     """
     Decrypt with aes in CBC mode
-    
+
     @param {int[]} data        cipher
     @param {int[]} key         16/24/32-Byte cipher key
     @param {int[]} iv          16-Byte IV
@@ -43,94 +45,126 @@ def aes_cbc_decrypt(data, key, iv):
     """
     expanded_key = key_expansion(key)
     block_count = int(ceil(float(len(data)) / BLOCK_SIZE_BYTES))
-    
-    decrypted_data=[]
+
+    decrypted_data = []
     previous_cipher_block = iv
     for i in range(block_count):
-        block = data[i*BLOCK_SIZE_BYTES : (i+1)*BLOCK_SIZE_BYTES]
-        block += [0]*(BLOCK_SIZE_BYTES - len(block))
-        
+        block = data[i * BLOCK_SIZE_BYTES: (i + 1) * BLOCK_SIZE_BYTES]
+        block += [0] * (BLOCK_SIZE_BYTES - len(block))
+
         decrypted_block = aes_decrypt(block, expanded_key)
         decrypted_data += xor(decrypted_block, previous_cipher_block)
         previous_cipher_block = block
     decrypted_data = decrypted_data[:len(data)]
-    
+
     return decrypted_data
 
+
+def aes_cbc_encrypt(data, key, iv):
+    """
+    Encrypt with aes in CBC mode. Using PKCS#7 padding
+
+    @param {int[]} data        cleartext
+    @param {int[]} key         16/24/32-Byte cipher key
+    @param {int[]} iv          16-Byte IV
+    @returns {int[]}           encrypted data
+    """
+    expanded_key = key_expansion(key)
+    block_count = int(ceil(float(len(data)) / BLOCK_SIZE_BYTES))
+
+    encrypted_data = []
+    previous_cipher_block = iv
+    for i in range(block_count):
+        block = data[i * BLOCK_SIZE_BYTES: (i + 1) * BLOCK_SIZE_BYTES]
+        remaining_length = BLOCK_SIZE_BYTES - len(block)
+        block += [remaining_length] * remaining_length
+        mixed_block = xor(block, previous_cipher_block)
+
+        encrypted_block = aes_encrypt(mixed_block, expanded_key)
+        encrypted_data += encrypted_block
+
+        previous_cipher_block = encrypted_block
+
+    return encrypted_data
+
+
 def key_expansion(data):
     """
     Generate key schedule
-    
+
     @param {int[]} data  16/24/32-Byte cipher key
-    @returns {int[]}     176/208/240-Byte expanded key 
+    @returns {int[]}     176/208/240-Byte expanded key
     """
-    data = data[:] # copy
+    data = data[:]  # copy
     rcon_iteration = 1
     key_size_bytes = len(data)
     expanded_key_size_bytes = (key_size_bytes // 4 + 7) * BLOCK_SIZE_BYTES
-    
+
     while len(data) < expanded_key_size_bytes:
         temp = data[-4:]
         temp = key_schedule_core(temp, rcon_iteration)
         rcon_iteration += 1
-        data += xor(temp, data[-key_size_bytes : 4-key_size_bytes])
-        
+        data += xor(temp, data[-key_size_bytes: 4 - key_size_bytes])
+
         for _ in range(3):
             temp = data[-4:]
-            data += xor(temp, data[-key_size_bytes : 4-key_size_bytes])
-        
+            data += xor(temp, data[-key_size_bytes: 4 - key_size_bytes])
+
         if key_size_bytes == 32:
             temp = data[-4:]
             temp = sub_bytes(temp)
-            data += xor(temp, data[-key_size_bytes : 4-key_size_bytes])
-        
-        for _ in range(3 if key_size_bytes == 32  else 2 if key_size_bytes == 24 else 0):
+            data += xor(temp, data[-key_size_bytes: 4 - key_size_bytes])
+
+        for _ in range(3 if key_size_bytes == 32 else 2 if key_size_bytes == 24 else 0):
             temp = data[-4:]
-            data += xor(temp, data[-key_size_bytes : 4-key_size_bytes])
+            data += xor(temp, data[-key_size_bytes: 4 - key_size_bytes])
     data = data[:expanded_key_size_bytes]
-    
+
     return data
 
+
 def aes_encrypt(data, expanded_key):
     """
     Encrypt one block with aes
-    
+
     @param {int[]} data          16-Byte state
-    @param {int[]} expanded_key  176/208/240-Byte expanded key 
+    @param {int[]} expanded_key  176/208/240-Byte expanded key
     @returns {int[]}             16-Byte cipher
     """
     rounds = len(expanded_key) // BLOCK_SIZE_BYTES - 1
 
     data = xor(data, expanded_key[:BLOCK_SIZE_BYTES])
-    for i in range(1, rounds+1):
+    for i in range(1, rounds + 1):
         data = sub_bytes(data)
         data = shift_rows(data)
         if i != rounds:
             data = mix_columns(data)
-        data = xor(data, expanded_key[i*BLOCK_SIZE_BYTES : (i+1)*BLOCK_SIZE_BYTES])
+        data = xor(data, expanded_key[i * BLOCK_SIZE_BYTES: (i + 1) * BLOCK_SIZE_BYTES])
 
     return data
 
+
 def aes_decrypt(data, expanded_key):
     """
     Decrypt one block with aes
-    
+
     @param {int[]} data          16-Byte cipher
     @param {int[]} expanded_key  176/208/240-Byte expanded key
     @returns {int[]}             16-Byte state
     """
     rounds = len(expanded_key) // BLOCK_SIZE_BYTES - 1
-    
+
     for i in range(rounds, 0, -1):
-        data = xor(data, expanded_key[i*BLOCK_SIZE_BYTES : (i+1)*BLOCK_SIZE_BYTES])
+        data = xor(data, expanded_key[i * BLOCK_SIZE_BYTES: (i + 1) * BLOCK_SIZE_BYTES])
         if i != rounds:
             data = mix_columns_inv(data)
         data = shift_rows_inv(data)
         data = sub_bytes_inv(data)
     data = xor(data, expanded_key[:BLOCK_SIZE_BYTES])
-    
+
     return data
 
+
 def aes_decrypt_text(data, password, key_size_bytes):
     """
     Decrypt text
@@ -138,35 +172,37 @@ def aes_decrypt_text(data, password, key_size_bytes):
     - The cipher key is retrieved by encrypting the first 16 Byte of 'password'
       with the first 'key_size_bytes' Bytes from 'password' (if necessary filled with 0's)
     - Mode of operation is 'counter'
-    
+
     @param {str} data                    Base64 encoded string
     @param {str,unicode} password        Password (will be encoded with utf-8)
     @param {int} key_size_bytes          Possible values: 16 for 128-Bit, 24 for 192-Bit or 32 for 256-Bit
     @returns {str}                       Decrypted data
     """
     NONCE_LENGTH_BYTES = 8
-    
-    data = bytes_to_intlist(base64.b64decode(data))
+
+    data = bytes_to_intlist(compat_b64decode(data))
     password = bytes_to_intlist(password.encode('utf-8'))
-    
-    key = password[:key_size_bytes] + [0]*(key_size_bytes - len(password))
+
+    key = password[:key_size_bytes] + [0] * (key_size_bytes - len(password))
     key = aes_encrypt(key[:BLOCK_SIZE_BYTES], key_expansion(key)) * (key_size_bytes // BLOCK_SIZE_BYTES)
-    
+
     nonce = data[:NONCE_LENGTH_BYTES]
     cipher = data[NONCE_LENGTH_BYTES:]
-    
-    class Counter:
-        __value = nonce + [0]*(BLOCK_SIZE_BYTES - NONCE_LENGTH_BYTES)
+
+    class Counter(object):
+        __value = nonce + [0] * (BLOCK_SIZE_BYTES - NONCE_LENGTH_BYTES)
+
         def next_value(self):
             temp = self.__value
             self.__value = inc(self.__value)
             return temp
-    
+
     decrypted_data = aes_ctr_decrypt(cipher, key, Counter())
     plaintext = intlist_to_bytes(decrypted_data)
-    
+
     return plaintext
 
+
 RCON = (0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36)
 SBOX = (0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
         0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
@@ -200,14 +236,14 @@ SBOX_INV = (0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x
             0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
             0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
             0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d)
-MIX_COLUMN_MATRIX = ((0x2,0x3,0x1,0x1),
-                     (0x1,0x2,0x3,0x1),
-                     (0x1,0x1,0x2,0x3),
-                     (0x3,0x1,0x1,0x2))
-MIX_COLUMN_MATRIX_INV = ((0xE,0xB,0xD,0x9),
-                         (0x9,0xE,0xB,0xD),
-                         (0xD,0x9,0xE,0xB),
-                         (0xB,0xD,0x9,0xE))
+MIX_COLUMN_MATRIX = ((0x2, 0x3, 0x1, 0x1),
+                     (0x1, 0x2, 0x3, 0x1),
+                     (0x1, 0x1, 0x2, 0x3),
+                     (0x3, 0x1, 0x1, 0x2))
+MIX_COLUMN_MATRIX_INV = ((0xE, 0xB, 0xD, 0x9),
+                         (0x9, 0xE, 0xB, 0xD),
+                         (0xD, 0x9, 0xE, 0xB),
+                         (0xB, 0xD, 0x9, 0xE))
 RIJNDAEL_EXP_TABLE = (0x01, 0x03, 0x05, 0x0F, 0x11, 0x33, 0x55, 0xFF, 0x1A, 0x2E, 0x72, 0x96, 0xA1, 0xF8, 0x13, 0x35,
                       0x5F, 0xE1, 0x38, 0x48, 0xD8, 0x73, 0x95, 0xA4, 0xF7, 0x02, 0x06, 0x0A, 0x1E, 0x22, 0x66, 0xAA,
                       0xE5, 0x34, 0x5C, 0xE4, 0x37, 0x59, 0xEB, 0x26, 0x6A, 0xBE, 0xD9, 0x70, 0x90, 0xAB, 0xE6, 0x31,
@@ -241,30 +277,37 @@ RIJNDAEL_LOG_TABLE = (0x00, 0x00, 0x19, 0x01, 0x32, 0x02, 0x1a, 0xc6, 0x4b, 0xc7
                       0x44, 0x11, 0x92, 0xd9, 0x23, 0x20, 0x2e, 0x89, 0xb4, 0x7c, 0xb8, 0x26, 0x77, 0x99, 0xe3, 0xa5,
                       0x67, 0x4a, 0xed, 0xde, 0xc5, 0x31, 0xfe, 0x18, 0x0d, 0x63, 0x8c, 0x80, 0xc0, 0xf7, 0x70, 0x07)
 
+
 def sub_bytes(data):
     return [SBOX[x] for x in data]
 
+
 def sub_bytes_inv(data):
     return [SBOX_INV[x] for x in data]
 
+
 def rotate(data):
     return data[1:] + [data[0]]
 
+
 def key_schedule_core(data, rcon_iteration):
     data = rotate(data)
     data = sub_bytes(data)
     data[0] = data[0] ^ RCON[rcon_iteration]
-    
+
     return data
 
+
 def xor(data1, data2):
-    return [x^y for x, y in zip(data1, data2)]
+    return [x ^ y for x, y in zip(data1, data2)]
+
 
 def rijndael_mul(a, b):
-    if(a==0 or b==0):
+    if(a == 0 or b == 0):
         return 0
     return RIJNDAEL_EXP_TABLE[(RIJNDAEL_LOG_TABLE[a] + RIJNDAEL_LOG_TABLE[b]) % 0xFF]
 
+
 def mix_column(data, matrix):
     data_mixed = []
     for row in range(4):
@@ -275,36 +318,44 @@ def mix_column(data, matrix):
         data_mixed.append(mixed)
     return data_mixed
 
+
 def mix_columns(data, matrix=MIX_COLUMN_MATRIX):
     data_mixed = []
     for i in range(4):
-        column = data[i*4 : (i+1)*4]
+        column = data[i * 4: (i + 1) * 4]
         data_mixed += mix_column(column, matrix)
     return data_mixed
 
+
 def mix_columns_inv(data):
     return mix_columns(data, MIX_COLUMN_MATRIX_INV)
 
+
 def shift_rows(data):
     data_shifted = []
     for column in range(4):
         for row in range(4):
-            data_shifted.append( data[((column + row) & 0b11) * 4 + row] )
+            data_shifted.append(data[((column + row) & 0b11) * 4 + row])
     return data_shifted
 
+
 def shift_rows_inv(data):
     data_shifted = []
     for column in range(4):
         for row in range(4):
-            data_shifted.append( data[((column - row) & 0b11) * 4 + row] )
+            data_shifted.append(data[((column - row) & 0b11) * 4 + row])
     return data_shifted
 
+
 def inc(data):
-    data = data[:] # copy
-    for i in range(len(data)-1,-1,-1):
+    data = data[:]  # copy
+    for i in range(len(data) - 1, -1, -1):
         if data[i] == 255:
             data[i] = 0
         else:
             data[i] = data[i] + 1
             break
     return data
+
+
+__all__ = ['aes_encrypt', 'key_expansion', 'aes_ctr_decrypt', 'aes_cbc_decrypt', 'aes_decrypt_text']