sign_encrypt.py: refactor with BinaryImage class

Moves manipulation of the TA binary into a BinaryImage class, adding
some abstraction and giving the code for the different sub-commands a
better structure.

Acked-by: Jerome Forissier <jerome.forissier@linaro.org>
Acked-by: Etienne Carriere <etienne.carriere@linaro.org>
Signed-off-by: Jens Wiklander <jens.wiklander@linaro.org>
diff --git a/scripts/sign_encrypt.py b/scripts/sign_encrypt.py
index c2af915..06e5d12 100755
--- a/scripts/sign_encrypt.py
+++ b/scripts/sign_encrypt.py
@@ -8,8 +8,10 @@
 import math
 
 
-algo = {'TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256': 0x70414930,
-        'TEE_ALG_RSASSA_PKCS1_V1_5_SHA256': 0x70004830}
+sig_tee_alg = {'TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256': 0x70414930,
+               'TEE_ALG_RSASSA_PKCS1_V1_5_SHA256': 0x70004830}
+# TEE algorithm IDs for encrypting the TA payload (packed into the ehdr)
+enc_tee_alg = {'TEE_ALG_AES_GCM': 0x40000810}

 enc_key_type = {'SHDR_ENC_KEY_DEV_SPECIFIC': 0x0,
                 'SHDR_ENC_KEY_CLASS_WIDE': 0x1}
@@ -18,6 +20,17 @@
 SHDR_ENCRYPTED_TA = 2
 SHDR_MAGIC = 0x4f545348
 SHDR_SIZE = 20
+EHDR_SIZE = 12  # '<IIHH' ehdr: 4 + 4 + 2 + 2 bytes
+UUID_SIZE = 16  # length of uuid.bytes
+# 12-byte (96-bit) nonce, the size recommended for AES-GCM
+NONCE_SIZE = 12
+TAG_SIZE = 16  # AES-GCM authentication tag length in bytes
+
+
+def value_to_key(db, val):
+    """Return the first key of db whose value equals val, else None."""
+    return next((key for key, value in db.items() if value == val),
+                None)
 
 
 def uuid_parse(s):
@@ -85,7 +98,7 @@
 
     def arg_add_algo(parser):
         parser.add_argument(
-            '--algo', required=False, choices=list(algo.keys()),
+            '--algo', required=False, choices=list(sig_tee_alg.keys()),
             default='TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256', help='''
                 The hash and signature algorithm.
                 Defaults to TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256.''')
@@ -203,244 +216,295 @@
     return parsed
 
 
-def main():
-    from cryptography import exceptions
-    from cryptography.hazmat.backends import default_backend
-    from cryptography.hazmat.primitives import serialization
-    from cryptography.hazmat.primitives import hashes
-    from cryptography.hazmat.primitives.asymmetric import padding
-    from cryptography.hazmat.primitives.asymmetric import rsa
-    from cryptography.hazmat.primitives.asymmetric import utils
-    import base64
-    import logging
-    import os
-    import struct
-
-    logging.basicConfig()
-    global logger
-    logger = logging.getLogger(os.path.basename(__file__))
-
-    args = get_args()
-
-    if args.key.startswith('arn:'):
+def load_asymmetric_key(arg_key):
+    if arg_key.startswith('arn:'):  # key held in KMS, given by ARN
         from sign_helper_kms import _RSAPrivateKeyInKMS
-        key = _RSAPrivateKeyInKMS(args.key)
+        key = _RSAPrivateKeyInKMS(arg_key)
     else:
-        with open(args.key, 'rb') as f:
+        from cryptography.hazmat.backends import default_backend
+        from cryptography.hazmat.primitives.serialization import (
+            load_pem_private_key, load_pem_public_key)
+        # The PEM may hold a private key or only a public key
+        with open(arg_key, 'rb') as f:
             data = f.read()

             try:
-                key = serialization.load_pem_private_key(
-                          data,
-                          password=None,
-                          backend=default_backend())
+                key = load_pem_private_key(data, password=None,
+                                           backend=default_backend())
             except ValueError:
-                key = serialization.load_pem_public_key(
-                          data,
-                          backend=default_backend())
+                key = load_pem_public_key(data, backend=default_backend())

-    with open(args.inf, 'rb') as f:
-        img = f.read()
+    return key  # may be public-only if the PEM held no private key

-    chosen_hash = hashes.SHA256()
-    h = hashes.Hash(chosen_hash, default_backend())

-    digest_len = chosen_hash.digest_size
-    sig_len = math.ceil(key.key_size / 8)
+class BinaryImage:  # TA image: built, signed, parsed and verified
+    def __init__(self, arg_inf, arg_key):
+        from cryptography.hazmat.primitives import hashes

-    img_size = len(img)
+        # What inf holds (TA binary or signed image) isn't known at this stage
+        with open(arg_inf, 'rb') as f:
+            self.inf = f.read()

-    hdr_version = args.ta_version  # struct shdr_bootstrap_ta::ta_version
+        self.key = load_asymmetric_key(arg_key)

-    magic = SHDR_MAGIC
-    if args.enc_key:
-        img_type = SHDR_ENCRYPTED_TA
-    else:
-        img_type = SHDR_BOOTSTRAP_TA
+        self.chosen_hash = hashes.SHA256()
+        self.digest_len = self.chosen_hash.digest_size
+        self.sig_len = math.ceil(self.key.key_size / 8)  # bits to bytes

-    shdr = struct.pack('<IIIIHH',
-                       magic, img_type, img_size, algo[args.algo],
-                       digest_len, sig_len)
-    shdr_uuid = args.uuid.bytes
-    shdr_version = struct.pack('<I', hdr_version)
+    def __pack_img(self, img_type, sign_algo):  # sets self.shdr
+        import struct

-    if args.enc_key:
+        self.sig_algo = sign_algo  # read by __get_padding()
+        self.img_type = img_type
+        self.shdr = struct.pack('<IIIIHH', SHDR_MAGIC, img_type, len(self.img),
+                                sig_tee_alg[sign_algo], self.digest_len,
+                                self.sig_len)
+
+    def __calc_digest(self):
+        from cryptography.hazmat.backends import default_backend
+        from cryptography.hazmat.primitives import hashes
+        # Digest input: shdr, TA uuid/version, ehdr/nonce/tag, plain img
+        h = hashes.Hash(self.chosen_hash, default_backend())
+        h.update(self.shdr)
+        h.update(self.ta_uuid)
+        h.update(self.ta_version)
+        if hasattr(self, 'ehdr'):  # encrypted TA only
+            h.update(self.ehdr)
+            h.update(self.nonce)
+            h.update(self.tag)
+        h.update(self.img)
+        return h.finalize()
+
+    def encrypt_ta(self, enc_key, key_type, sig_algo, uuid, ta_version):
         from cryptography.hazmat.primitives.ciphers.aead import AESGCM
-        cipher = AESGCM(bytes.fromhex(args.enc_key))
-        # Use 12 bytes for nonce per recommendation
-        nonce = os.urandom(12)
-        out = cipher.encrypt(nonce, img, None)
-        ciphertext = out[:-16]
-        # Authentication Tag is always the last 16 bytes
-        tag = out[-16:]
+        import struct
+        import os

-        enc_algo = 0x40000810      # TEE_ALG_AES_GCM
-        flags = enc_key_type[args.enc_key_type]
-        ehdr = struct.pack('<IIHH',
-                           enc_algo, flags, len(nonce), len(tag))
+        self.img = self.inf  # plaintext payload, hashed by __calc_digest()

-    h.update(shdr)
-    h.update(shdr_uuid)
-    h.update(shdr_version)
-    if args.enc_key:
-        h.update(ehdr)
-        h.update(nonce)
-        h.update(tag)
-    h.update(img)
-    img_digest = h.finalize()
+        cipher = AESGCM(bytes.fromhex(enc_key))
+        self.nonce = os.urandom(NONCE_SIZE)
+        out = cipher.encrypt(self.nonce, self.img, None)
+        self.ciphertext = out[:-TAG_SIZE]
+        # AESGCM.encrypt() output is ciphertext followed by the tag
+        self.tag = out[-TAG_SIZE:]

-    def write_image_with_signature(sig):
-        with open(args.outf, 'wb') as f:
-            f.write(shdr)
-            f.write(img_digest)
-            f.write(sig)
-            f.write(shdr_uuid)
-            f.write(shdr_version)
-            if args.enc_key:
-                f.write(ehdr)
-                f.write(nonce)
-                f.write(tag)
-                f.write(ciphertext)
-            else:
-                f.write(img)
+        enc_algo = enc_tee_alg['TEE_ALG_AES_GCM']
+        flags = enc_key_type[key_type]
+        self.ehdr = struct.pack('<IIHH', enc_algo, flags, len(self.nonce),
+                                len(self.tag))

-    def sign_encrypt_ta():
-        if not isinstance(key, rsa.RSAPrivateKey):
-            logger.error('Provided key cannot be used for signing, ' +
-                         'please use offline-signing mode.')
-            sys.exit(1)
-        else:
-            if args.algo == 'TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256':
-                sig = key.sign(
-                    img_digest,
-                    padding.PSS(
-                        mgf=padding.MGF1(chosen_hash),
-                        salt_length=digest_len
-                    ),
-                    utils.Prehashed(chosen_hash)
-                )
-            elif args.algo == 'TEE_ALG_RSASSA_PKCS1_V1_5_SHA256':
-                sig = key.sign(
-                    img_digest,
-                    padding.PKCS1v15(),
-                    utils.Prehashed(chosen_hash)
-                )
+        self.__pack_img(SHDR_ENCRYPTED_TA, sig_algo)
+        self.ta_uuid = uuid.bytes
+        self.ta_version = struct.pack('<I', ta_version)
+        self.img_digest = self.__calc_digest()

-            if len(sig) != sig_len:
-                raise Exception(("Actual signature length is not equal to ",
-                                 "the computed one: {} != {}").
-                                format(len(sig), sig_len))
-            write_image_with_signature(sig)
-            logger.info('Successfully signed application.')
+    def set_bootstrap_ta(self, sig_algo, uuid, ta_version):
+        import struct

-    def generate_digest():
-        with open(args.digf, 'wb+') as digfile:
-            digfile.write(base64.b64encode(img_digest))
+        self.img = self.inf  # unencrypted payload is written as-is
+        self.__pack_img(SHDR_BOOTSTRAP_TA, sig_algo)
+        self.ta_uuid = uuid.bytes
+        self.ta_version = struct.pack('<I', ta_version)
+        self.img_digest = self.__calc_digest()

-    def stitch_ta():
-        try:
-            with open(args.sigf, 'r') as sigfile:
-                sig = base64.b64decode(sigfile.read())
-        except IOError:
-            if not os.path.exists(args.digf):
-                generate_digest()
-            logger.error('No signature file found. Please sign\n %s\n' +
-                         'offline and place the signature at \n %s\n' +
-                         'or pass a different location ' +
-                         'using the --sig argument.\n',
-                         args.digf, args.sigf)
-            sys.exit(1)
-        else:
-            try:
-                if args.algo == 'TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256':
-                    key.verify(
-                        sig,
-                        img_digest,
-                        padding.PSS(
-                            mgf=padding.MGF1(chosen_hash),
-                            salt_length=digest_len
-                        ),
-                        utils.Prehashed(chosen_hash)
-                    )
-                elif args.algo == 'TEE_ALG_RSASSA_PKCS1_V1_5_SHA256':
-                    key.verify(
-                        sig,
-                        img_digest,
-                        padding.PKCS1v15(),
-                        utils.Prehashed(chosen_hash)
-                    )
-            except exceptions.InvalidSignature:
-                logger.error('Verification failed, ignoring given signature.')
-                sys.exit(1)
+    def parse(self):  # split signed image self.inf; raises on bad fields
+        import struct

-            write_image_with_signature(sig)
-            logger.info('Successfully applied signature.')
-
-    def verify_ta():
-        # Extract header
-        [magic,
-         img_type,
-         img_size,
-         algo_value,
-         digest_len,
-         sig_len] = struct.unpack('<IIIIHH', img[:SHDR_SIZE])
-
-        # Extract digest and signature
-        start, end = SHDR_SIZE, SHDR_SIZE + digest_len
-        digest = img[start:end]
-
-        start, end = end, SHDR_SIZE + digest_len + sig_len
-        signature = img[start:end]
-
-        # Extract UUID and TA version
-        start, end = end, end + 16 + 4
-        [uuid, ta_version] = struct.unpack('<16sI', img[start:end])
+        offs = 0  # read cursor into self.inf
+        self.shdr = self.inf[offs:offs + SHDR_SIZE]
+        [magic, img_type, img_size, algo_value, digest_len,
+         sig_len] = struct.unpack('<IIIIHH', self.shdr)
+        offs += SHDR_SIZE

         if magic != SHDR_MAGIC:
             raise Exception("Unexpected magic: 0x{:08x}".format(magic))

-        if img_type != SHDR_BOOTSTRAP_TA:
-            raise Exception("Unsupported image type: {}".format(img_type))
-
-        if algo_value not in algo.values():
+        if algo_value not in sig_tee_alg.values():
             raise Exception('Unrecognized algorithm: 0x{:08x}'
                             .format(algo_value))
+        self.sig_algo = value_to_key(sig_tee_alg, algo_value)

-        # Verify signature against hash digest
-        if algo_value == 0x70414930:
-            key.verify(
-                signature,
-                digest,
-                padding.PSS(
-                    mgf=padding.MGF1(chosen_hash),
-                    salt_length=digest_len
-                ),
-                utils.Prehashed(chosen_hash)
-            )
+        if digest_len != self.digest_len:
+            raise Exception("Unexpected digest len: {}".format(digest_len))
+
+        self.img_digest = self.inf[offs:offs + digest_len]
+        offs += digest_len
+        self.sig = self.inf[offs:offs + sig_len]
+        offs += sig_len
+
+        if img_type == SHDR_BOOTSTRAP_TA or img_type == SHDR_ENCRYPTED_TA:
+            self.ta_uuid = self.inf[offs:offs + UUID_SIZE]
+            offs += UUID_SIZE
+            self.ta_version = self.inf[offs:offs + 4]
+            offs += 4
+            if img_type == SHDR_ENCRYPTED_TA:
+                self.ehdr = self.inf[offs: offs + EHDR_SIZE]
+                offs += EHDR_SIZE
+                [enc_algo, flags, nonce_len,
+                 tag_len] = struct.unpack('<IIHH', self.ehdr)
+                if enc_algo not in enc_tee_alg.values():
+                    raise Exception('Unrecognized encrypt algorithm: 0x{:08x}'
+                                    .format(enc_algo))
+                if nonce_len != NONCE_SIZE:
+                    raise Exception("Unexpected nonce len: {}"
+                                    .format(nonce_len))
+                self.nonce = self.inf[offs:offs + nonce_len]
+                offs += nonce_len
+                # Layout (see write()): nonce, tag, then ciphertext
+                if tag_len != TAG_SIZE:
+                    raise Exception("Unexpected tag len: {}".format(tag_len))
+                self.tag = self.inf[offs:offs + tag_len]
+                self.ciphertext = self.inf[offs + tag_len:]
+                if len(self.ciphertext) != img_size:
+                    raise Exception("Unexpected ciphertext size: "
+                                    "got {}, expected {}"
+                                    .format(len(self.ciphertext), img_size))
+            else:
+                self.img = self.inf[offs:]
+                if len(self.img) != img_size:
+                    raise Exception("Unexpected img size: got {}, expected {}"
+                                    .format(len(self.img), img_size))
         else:
-            key.verify(
-                signature,
-                digest,
-                padding.PKCS1v15(),
-                utils.Prehashed(chosen_hash)
-            )
+            raise Exception("Unsupported image type: {}".format(img_type))

-        h = hashes.Hash(chosen_hash, default_backend())
+    def decrypt_ta(self, enc_key):  # raises InvalidTag on bad key/data
         from cryptography.hazmat.primitives.ciphers.aead import AESGCM

-        # sizeof(struct shdr)
-        h.update(img[:SHDR_SIZE])
+        cipher = AESGCM(bytes.fromhex(enc_key))  # decrypt() takes ct + tag
+        self.img = cipher.decrypt(self.nonce, self.ciphertext + self.tag, None)

-        # sizeof(struct shdr_bootstrap_ta)
-        h.update(img[start:end])
+    def __get_padding(self):  # RSA padding for self.sig_algo
+        from cryptography.hazmat.primitives.asymmetric import padding

-        # raw image
-        start = end
-        end += img_size
-        h.update(img[start:end])
+        if self.sig_algo == 'TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256':
+            return padding.PSS(mgf=padding.MGF1(self.chosen_hash),
+                               salt_length=self.digest_len)
+        if self.sig_algo == 'TEE_ALG_RSASSA_PKCS1_V1_5_SHA256':
+            return padding.PKCS1v15()

-        if digest != h.finalize():
+        raise Exception('Unsupported algorithm: {}'.format(self.sig_algo))
+
+    def sign(self):  # sign self.img_digest into self.sig
+        from cryptography.hazmat.primitives.asymmetric import utils
+        from cryptography.hazmat.primitives.asymmetric import rsa
+        # A public-only key cannot sign: exit, suggest offline signing
+        if not isinstance(self.key, rsa.RSAPrivateKey):
+            logger.error('Provided key cannot be used for signing, ' +
+                         'please use offline-signing mode.')
+            sys.exit(1)
+        else:
+            self.sig = self.key.sign(self.img_digest, self.__get_padding(),
+                                     utils.Prehashed(self.chosen_hash))
+
+            if len(self.sig) != self.sig_len:
+                raise Exception(("Actual signature length is not equal to "
+                                 "the computed one: {} != {}").
+                                format(len(self.sig), self.sig_len))
+
+    def add_signature(self, sigf):
+        import base64
+        # sigf holds the base64-encoded raw signature
+        with open(sigf, 'r') as f:
+            self.sig = base64.b64decode(f.read())
+
+        if len(self.sig) != self.sig_len:
+            raise Exception(("Actual signature length is not equal to "
+                             "the expected one: {} != {}").
+                            format(len(self.sig), self.sig_len))
+
+    def verify_signature(self):  # exit(1) on a bad signature
+        from cryptography.hazmat.primitives.asymmetric import utils
+        from cryptography.hazmat.primitives.asymmetric import rsa
+        from cryptography import exceptions
+        # Verify with the public half of whichever key was loaded
+        if isinstance(self.key, rsa.RSAPrivateKey):
+            pkey = self.key.public_key()
+        else:
+            pkey = self.key
+
+        try:
+            pkey.verify(self.sig, self.img_digest, self.__get_padding(),
+                        utils.Prehashed(self.chosen_hash))
+        except exceptions.InvalidSignature:
+            logger.error('Verification failed, ignoring given signature.')
+            sys.exit(1)
+
+    def verify_digest(self):  # recompute and compare to the stored digest
+        if self.img_digest != self.__calc_digest():
             raise Exception('Hash digest does not match')

+    def verify_uuid(self, uuid):  # uuid: object with .bytes (e.g. UUID)
+        if self.ta_uuid != uuid.bytes:
+            raise Exception('UUID does not match')
+
+    def write(self, outf):  # serialize the image fields to outf
+        with open(outf, 'wb') as f:
+            f.write(self.shdr)
+            f.write(self.img_digest)
+            f.write(self.sig)
+            f.write(self.ta_uuid)
+            f.write(self.ta_version)
+            if hasattr(self, 'ehdr'):  # encrypted: tag before ciphertext
+                f.write(self.ehdr)
+                f.write(self.nonce)
+                f.write(self.tag)
+                f.write(self.ciphertext)
+            else:
+                f.write(self.img)
+
+
+def load_ta_image(args):  # digest computed, not signed yet
+    ta_image = BinaryImage(args.inf, args.key)
+    # Encrypt the payload only when an encryption key was given
+    if args.enc_key:
+        ta_image.encrypt_ta(args.enc_key, args.enc_key_type,
+                            args.algo, args.uuid, args.ta_version)
+    else:
+        ta_image.set_bootstrap_ta(args.algo, args.uuid, args.ta_version)
+
+    return ta_image
+
+
+def main():
+    import logging
+    import os
+
+    global logger
+    logging.basicConfig()
+    logger = logging.getLogger(os.path.basename(__file__))
+
+    args = get_args()
+
+    def sign_encrypt_ta():  # requires a private key
+        ta_image = load_ta_image(args)
+        ta_image.sign()
+        ta_image.write(args.outf)
+        logger.info('Successfully signed application.')
+
+    def generate_digest():  # base64 digest for offline signing
+        import base64
+        digest = base64.b64encode(load_ta_image(args).img_digest)
+        with open(args.digf, 'wb+') as digfile:
+            digfile.write(digest)
+    def stitch_ta():
+        ta_image = load_ta_image(args)
+        ta_image.add_signature(args.sigf)  # raises if sigf is missing
+        ta_image.verify_signature()
+        ta_image.write(args.outf)
+        logger.info('Successfully applied signature.')
+
+    def verify_ta():  # check a signed image against args
+        ta_image = BinaryImage(args.inf, args.key)
+        ta_image.parse()
+        if hasattr(ta_image, 'ciphertext'):  # encrypted TA
+            if args.enc_key is None:
+                logger.error('--enc_key needed to decrypt TA')
+                sys.exit(1)
+            ta_image.decrypt_ta(args.enc_key)
+        ta_image.verify_signature()
+        ta_image.verify_digest()
+        ta_image.verify_uuid(args.uuid)
         logger.info('Trusted application is correctly verified.')
 
     # dispatch command