sign_encrypt.py: refactor with BinaryImage class
Moves manipulations of the TA binary into a BinaryImage class for some
abstraction and better structure of the code for the different
sub-commands.
Acked-by: Jerome Forissier <jerome.forissier@linaro.org>
Acked-by: Etienne Carriere <etienne.carriere@linaro.org>
Signed-off-by: Jens Wiklander <jens.wiklander@linaro.org>
diff --git a/scripts/sign_encrypt.py b/scripts/sign_encrypt.py
index c2af915..06e5d12 100755
--- a/scripts/sign_encrypt.py
+++ b/scripts/sign_encrypt.py
@@ -8,8 +8,10 @@
import math
-algo = {'TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256': 0x70414930,
- 'TEE_ALG_RSASSA_PKCS1_V1_5_SHA256': 0x70004830}
+sig_tee_alg = {'TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256': 0x70414930,
+ 'TEE_ALG_RSASSA_PKCS1_V1_5_SHA256': 0x70004830}
+
+enc_tee_alg = {'TEE_ALG_AES_GCM': 0x40000810}
enc_key_type = {'SHDR_ENC_KEY_DEV_SPECIFIC': 0x0,
'SHDR_ENC_KEY_CLASS_WIDE': 0x1}
@@ -18,6 +20,17 @@
SHDR_ENCRYPTED_TA = 2
SHDR_MAGIC = 0x4f545348
SHDR_SIZE = 20
+EHDR_SIZE = 12
+UUID_SIZE = 16
+# Use 12 bytes for nonce per recommendation
+NONCE_SIZE = 12
+TAG_SIZE = 16
+
+
+def value_to_key(db, val):
+    """Return the first key of db whose value equals val, or None."""
+    return next((key for key, item in db.items() if item == val), None)
def uuid_parse(s):
@@ -85,7 +98,7 @@
def arg_add_algo(parser):
parser.add_argument(
- '--algo', required=False, choices=list(algo.keys()),
+ '--algo', required=False, choices=list(sig_tee_alg.keys()),
default='TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256', help='''
The hash and signature algorithm.
Defaults to TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256.''')
@@ -203,244 +216,295 @@
return parsed
-def main():
- from cryptography import exceptions
- from cryptography.hazmat.backends import default_backend
- from cryptography.hazmat.primitives import serialization
- from cryptography.hazmat.primitives import hashes
- from cryptography.hazmat.primitives.asymmetric import padding
- from cryptography.hazmat.primitives.asymmetric import rsa
- from cryptography.hazmat.primitives.asymmetric import utils
- import base64
- import logging
- import os
- import struct
-
- logging.basicConfig()
- global logger
- logger = logging.getLogger(os.path.basename(__file__))
-
- args = get_args()
-
- if args.key.startswith('arn:'):
+def load_asymmetric_key(arg_key):
+    """Load the asymmetric key named by arg_key.
+
+    An argument starting with 'arn:' is wrapped in _RSAPrivateKeyInKMS
+    from sign_helper_kms. Anything else is opened as a file and parsed as
+    PEM: first as a private key without password and, if that raises
+    ValueError, as a public key.
+
+    Returns the loaded key object.
+    """
+    if arg_key.startswith('arn:'):
         from sign_helper_kms import _RSAPrivateKeyInKMS
-        key = _RSAPrivateKeyInKMS(args.key)
+        key = _RSAPrivateKeyInKMS(arg_key)
     else:
-        with open(args.key, 'rb') as f:
+        from cryptography.hazmat.backends import default_backend
+        from cryptography.hazmat.primitives.serialization import (
+            load_pem_private_key, load_pem_public_key)
+
+        with open(arg_key, 'rb') as f:
             data = f.read()
             try:
-                key = serialization.load_pem_private_key(
-                    data,
-                    password=None,
-                    backend=default_backend())
+                key = load_pem_private_key(data, password=None,
+                                           backend=default_backend())
             except ValueError:
-                key = serialization.load_pem_public_key(
-                    data,
-                    backend=default_backend())
+                # Not a private key, fall back to a public key
+                key = load_pem_public_key(data, backend=default_backend())
-    with open(args.inf, 'rb') as f:
-        img = f.read()
+    return key
- chosen_hash = hashes.SHA256()
- h = hashes.Hash(chosen_hash, default_backend())
- digest_len = chosen_hash.digest_size
- sig_len = math.ceil(key.key_size / 8)
+class BinaryImage:
+    def __init__(self, arg_inf, arg_key):
+        """Read the input file and load the key.
+
+        arg_inf is the path of the file to read, its raw bytes are kept in
+        self.inf. arg_key is passed to load_asymmetric_key(). The hash is
+        fixed to SHA-256 and the signature length is the key size in
+        bytes, rounded up.
+        """
+        from cryptography.hazmat.primitives import hashes
-    img_size = len(img)
+        # Exactly what inf is holding isn't determined at this stage
+        with open(arg_inf, 'rb') as f:
+            self.inf = f.read()
-    hdr_version = args.ta_version  # struct shdr_bootstrap_ta::ta_version
+        self.key = load_asymmetric_key(arg_key)
-    magic = SHDR_MAGIC
-    if args.enc_key:
-        img_type = SHDR_ENCRYPTED_TA
-    else:
-        img_type = SHDR_BOOTSTRAP_TA
+        self.chosen_hash = hashes.SHA256()
+        self.digest_len = self.chosen_hash.digest_size
+        self.sig_len = math.ceil(self.key.key_size / 8)
- shdr = struct.pack('<IIIIHH',
- magic, img_type, img_size, algo[args.algo],
- digest_len, sig_len)
- shdr_uuid = args.uuid.bytes
- shdr_version = struct.pack('<I', hdr_version)
+    def __pack_img(self, img_type, sign_algo):
+        """Build the signed header for self.img and store it in self.shdr.
+
+        The header is packed little-endian as '<IIIIHH': magic, image
+        type, length of self.img, signature algorithm value, digest
+        length and signature length. img_type and sign_algo are also
+        recorded on the instance. self.img must be set before the call.
+        """
+        import struct
-    if args.enc_key:
+        self.sig_algo = sign_algo
+        self.img_type = img_type
+        self.shdr = struct.pack('<IIIIHH', SHDR_MAGIC, img_type, len(self.img),
+                                sig_tee_alg[sign_algo], self.digest_len,
+                                self.sig_len)
+
+    def __calc_digest(self):
+        """Return the digest over the headers and the image.
+
+        Hashed in order: shdr, TA UUID, TA version, then ehdr, nonce and
+        tag (only when an encryption header is present) and finally
+        self.img.
+        """
+        from cryptography.hazmat.backends import default_backend
+        from cryptography.hazmat.primitives import hashes
+
+        parts = [self.shdr, self.ta_uuid, self.ta_version]
+        if hasattr(self, 'ehdr'):
+            parts += [self.ehdr, self.nonce, self.tag]
+        parts.append(self.img)
+
+        hasher = hashes.Hash(self.chosen_hash, default_backend())
+        for part in parts:
+            hasher.update(part)
+        return hasher.finalize()
+
+    def encrypt_ta(self, enc_key, key_type, sig_algo, uuid, ta_version):
+        """Encrypt the input with AES-GCM and set up an encrypted TA.
+
+        enc_key is the AES key as a hex string and key_type is a key of
+        enc_key_type. A random nonce of NONCE_SIZE bytes is generated and
+        the output of the cipher is split into self.ciphertext and
+        self.tag. The encryption header, signed header, UUID, version and
+        digest are then filled in. self.img keeps the unencrypted input,
+        which is what the digest is calculated over.
+        """
         from cryptography.hazmat.primitives.ciphers.aead import AESGCM
-        cipher = AESGCM(bytes.fromhex(args.enc_key))
-        # Use 12 bytes for nonce per recommendation
-        nonce = os.urandom(12)
-        out = cipher.encrypt(nonce, img, None)
-        ciphertext = out[:-16]
-        # Authentication Tag is always the last 16 bytes
-        tag = out[-16:]
+        import struct
+        import os
-        enc_algo = 0x40000810       # TEE_ALG_AES_GCM
-        flags = enc_key_type[args.enc_key_type]
-        ehdr = struct.pack('<IIHH',
-                           enc_algo, flags, len(nonce), len(tag))
+        self.img = self.inf
-    h.update(shdr)
-    h.update(shdr_uuid)
-    h.update(shdr_version)
-    if args.enc_key:
-        h.update(ehdr)
-        h.update(nonce)
-        h.update(tag)
-    h.update(img)
-    img_digest = h.finalize()
+        cipher = AESGCM(bytes.fromhex(enc_key))
+        self.nonce = os.urandom(NONCE_SIZE)
+        out = cipher.encrypt(self.nonce, self.img, None)
+        self.ciphertext = out[:-TAG_SIZE]
+        # Authentication Tag is always the last TAG_SIZE bytes
+        self.tag = out[-TAG_SIZE:]
-    def write_image_with_signature(sig):
-        with open(args.outf, 'wb') as f:
-            f.write(shdr)
-            f.write(img_digest)
-            f.write(sig)
-            f.write(shdr_uuid)
-            f.write(shdr_version)
-            if args.enc_key:
-                f.write(ehdr)
-                f.write(nonce)
-                f.write(tag)
-                f.write(ciphertext)
-            else:
-                f.write(img)
+        enc_algo = enc_tee_alg['TEE_ALG_AES_GCM']
+        flags = enc_key_type[key_type]
+        self.ehdr = struct.pack('<IIHH', enc_algo, flags, len(self.nonce),
+                                len(self.tag))
-    def sign_encrypt_ta():
-        if not isinstance(key, rsa.RSAPrivateKey):
-            logger.error('Provided key cannot be used for signing, ' +
-                         'please use offline-signing mode.')
-            sys.exit(1)
-        else:
-            if args.algo == 'TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256':
-                sig = key.sign(
-                    img_digest,
-                    padding.PSS(
-                        mgf=padding.MGF1(chosen_hash),
-                        salt_length=digest_len
-                    ),
-                    utils.Prehashed(chosen_hash)
-                )
-            elif args.algo == 'TEE_ALG_RSASSA_PKCS1_V1_5_SHA256':
-                sig = key.sign(
-                    img_digest,
-                    padding.PKCS1v15(),
-                    utils.Prehashed(chosen_hash)
-                )
+        self.__pack_img(SHDR_ENCRYPTED_TA, sig_algo)
+        self.ta_uuid = uuid.bytes
+        self.ta_version = struct.pack('<I', ta_version)
+        self.img_digest = self.__calc_digest()
- if len(sig) != sig_len:
- raise Exception(("Actual signature length is not equal to ",
- "the computed one: {} != {}").
- format(len(sig), sig_len))
- write_image_with_signature(sig)
- logger.info('Successfully signed application.')
+    def set_bootstrap_ta(self, sig_algo, uuid, ta_version):
+        """Set up an unencrypted (bootstrap) TA from the input.
+
+        The input is used as the image as is. The signed header, UUID
+        bytes, little-endian 32-bit version and digest are filled in.
+        """
+        import struct
-    def generate_digest():
-        with open(args.digf, 'wb+') as digfile:
-            digfile.write(base64.b64encode(img_digest))
+        self.img = self.inf
+        self.__pack_img(SHDR_BOOTSTRAP_TA, sig_algo)
+        self.ta_uuid = uuid.bytes
+        self.ta_version = struct.pack('<I', ta_version)
+        self.img_digest = self.__calc_digest()
- def stitch_ta():
- try:
- with open(args.sigf, 'r') as sigfile:
- sig = base64.b64decode(sigfile.read())
- except IOError:
- if not os.path.exists(args.digf):
- generate_digest()
- logger.error('No signature file found. Please sign\n %s\n' +
- 'offline and place the signature at \n %s\n' +
- 'or pass a different location ' +
- 'using the --sig argument.\n',
- args.digf, args.sigf)
- sys.exit(1)
- else:
- try:
- if args.algo == 'TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256':
- key.verify(
- sig,
- img_digest,
- padding.PSS(
- mgf=padding.MGF1(chosen_hash),
- salt_length=digest_len
- ),
- utils.Prehashed(chosen_hash)
- )
- elif args.algo == 'TEE_ALG_RSASSA_PKCS1_V1_5_SHA256':
- key.verify(
- sig,
- img_digest,
- padding.PKCS1v15(),
- utils.Prehashed(chosen_hash)
- )
- except exceptions.InvalidSignature:
- logger.error('Verification failed, ignoring given signature.')
- sys.exit(1)
+    def parse(self):
+        """Split self.inf, holding a signed TA, into its parts.
+
+        The layout read is the one emitted by write(): signed header,
+        digest, signature, UUID, version and then either the image or,
+        for an encrypted TA, encryption header, nonce, tag and
+        ciphertext. The parts are stored in self.shdr, self.img_digest,
+        self.sig, self.ta_uuid, self.ta_version and self.img, or
+        self.ehdr, self.nonce, self.tag and self.ciphertext.
+
+        Raises Exception on an unexpected magic, algorithm, digest
+        length, nonce length, tag length, image size or image type.
+        """
+        import struct
-            write_image_with_signature(sig)
-            logger.info('Successfully applied signature.')
-
-    def verify_ta():
-        # Extract header
-        [magic,
-         img_type,
-         img_size,
-         algo_value,
-         digest_len,
-         sig_len] = struct.unpack('<IIIIHH', img[:SHDR_SIZE])
-
-        # Extract digest and signature
-        start, end = SHDR_SIZE, SHDR_SIZE + digest_len
-        digest = img[start:end]
-
-        start, end = end, SHDR_SIZE + digest_len + sig_len
-        signature = img[start:end]
-
-        # Extract UUID and TA version
-        start, end = end, end + 16 + 4
-        [uuid, ta_version] = struct.unpack('<16sI', img[start:end])
+        offs = 0
+        self.shdr = self.inf[offs:offs + SHDR_SIZE]
+        [magic, img_type, img_size, algo_value, digest_len,
+         sig_len] = struct.unpack('<IIIIHH', self.shdr)
+        offs += SHDR_SIZE
         if magic != SHDR_MAGIC:
             raise Exception("Unexpected magic: 0x{:08x}".format(magic))
-        if img_type != SHDR_BOOTSTRAP_TA:
-            raise Exception("Unsupported image type: {}".format(img_type))
-
-        if algo_value not in algo.values():
+        if algo_value not in sig_tee_alg.values():
             raise Exception('Unrecognized algorithm: 0x{:08x}'
                             .format(algo_value))
+        self.sig_algo = value_to_key(sig_tee_alg, algo_value)
-        # Verify signature against hash digest
-        if algo_value == 0x70414930:
-            key.verify(
-                signature,
-                digest,
-                padding.PSS(
-                    mgf=padding.MGF1(chosen_hash),
-                    salt_length=digest_len
-                ),
-                utils.Prehashed(chosen_hash)
-            )
+        if digest_len != self.digest_len:
+            raise Exception("Unexpected digest len: {}".format(digest_len))
+
+        self.img_digest = self.inf[offs:offs + digest_len]
+        offs += digest_len
+        self.sig = self.inf[offs:offs + sig_len]
+        offs += sig_len
+
+        if img_type == SHDR_BOOTSTRAP_TA or img_type == SHDR_ENCRYPTED_TA:
+            self.ta_uuid = self.inf[offs:offs + UUID_SIZE]
+            offs += UUID_SIZE
+            self.ta_version = self.inf[offs:offs + 4]
+            offs += 4
+            if img_type == SHDR_ENCRYPTED_TA:
+                self.ehdr = self.inf[offs:offs + EHDR_SIZE]
+                offs += EHDR_SIZE
+                [enc_algo, flags, nonce_len,
+                 tag_len] = struct.unpack('<IIHH', self.ehdr)
+                if enc_algo not in enc_tee_alg.values():
+                    raise Exception('Unrecognized encrypt algorithm: 0x{:08x}'
+                                    .format(enc_algo))
+                if nonce_len != NONCE_SIZE:
+                    raise Exception("Unexpected nonce len: {}"
+                                    .format(nonce_len))
+                self.nonce = self.inf[offs:offs + nonce_len]
+                offs += nonce_len
+
+                if tag_len != TAG_SIZE:
+                    raise Exception("Unexpected tag len: {}".format(tag_len))
+                # write() emits the tag right after the nonce, ahead of
+                # the ciphertext, so it has to be read from there.
+                self.tag = self.inf[offs:offs + tag_len]
+                offs += tag_len
+                self.ciphertext = self.inf[offs:]
+                if len(self.ciphertext) != img_size:
+                    raise Exception("Unexpected ciphertext size: "
+                                    "got {}, expected {}"
+                                    .format(len(self.ciphertext), img_size))
+            else:
+                self.img = self.inf[offs:]
+                if len(self.img) != img_size:
+                    raise Exception("Unexpected img size: got {}, expected {}"
+                                    .format(len(self.img), img_size))
         else:
-            key.verify(
-                signature,
-                digest,
-                padding.PKCS1v15(),
-                utils.Prehashed(chosen_hash)
-            )
+            raise Exception("Unsupported image type: {}".format(img_type))
- h = hashes.Hash(chosen_hash, default_backend())
+    def decrypt_ta(self, enc_key):
+        """Decrypt self.ciphertext into self.img.
+
+        enc_key is the AES key as a hex string. self.nonce,
+        self.ciphertext and self.tag must be set before the call.
+        """
+        from cryptography.hazmat.primitives.ciphers.aead import AESGCM
-        # sizeof(struct shdr)
-        h.update(img[:SHDR_SIZE])
+        cipher = AESGCM(bytes.fromhex(enc_key))
+        # AESGCM.decrypt() expects the authentication tag appended to the
+        # ciphertext, but the two are kept apart in this class.
+        self.img = cipher.decrypt(self.nonce,
+                                  self.ciphertext + self.tag, None)
- # sizeof(struct shdr_bootstrap_ta)
- h.update(img[start:end])
+    def __get_padding(self):
+        """Return the RSA padding object matching self.sig_algo.
+
+        PSS uses MGF1 with the chosen hash and a salt as long as the
+        digest. Raises Exception for an algorithm that isn't handled
+        instead of failing on an unbound local.
+        """
+        from cryptography.hazmat.primitives.asymmetric import padding
-        # raw image
-        start = end
-        end += img_size
-        h.update(img[start:end])
+        if self.sig_algo == 'TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256':
+            pad = padding.PSS(mgf=padding.MGF1(self.chosen_hash),
+                              salt_length=self.digest_len)
+        elif self.sig_algo == 'TEE_ALG_RSASSA_PKCS1_V1_5_SHA256':
+            pad = padding.PKCS1v15()
+        else:
+            raise Exception('Unsupported signature algorithm: {}'
+                            .format(self.sig_algo))
-        if digest != h.finalize():
+        return pad
+
+    def sign(self):
+        """Sign self.img_digest with the loaded key into self.sig.
+
+        Logs an error and exits with status 1 if the key isn't an RSA
+        private key. Raises Exception if the signature length differs
+        from self.sig_len.
+        """
+        from cryptography.hazmat.primitives.asymmetric import utils
+        from cryptography.hazmat.primitives.asymmetric import rsa
+
+        if not isinstance(self.key, rsa.RSAPrivateKey):
+            logger.error('Provided key cannot be used for signing, ' +
+                         'please use offline-signing mode.')
+            sys.exit(1)
+
+        self.sig = self.key.sign(self.img_digest, self.__get_padding(),
+                                 utils.Prehashed(self.chosen_hash))
+
+        if len(self.sig) != self.sig_len:
+            # Adjacent literals without a comma form one string, a tuple
+            # of strings has no format() method.
+            raise Exception(("Actual signature length is not equal to "
+                             "the computed one: {} != {}").
+                            format(len(self.sig), self.sig_len))
+
+    def add_signature(self, sigf):
+        """Load a base64 encoded signature from the file sigf into self.sig.
+
+        Raises Exception if the length of the decoded signature differs
+        from self.sig_len.
+        """
+        import base64
+
+        with open(sigf, 'r') as f:
+            self.sig = base64.b64decode(f.read())
+
+        if len(self.sig) != self.sig_len:
+            # Adjacent literals without a comma form one string, a tuple
+            # of strings has no format() method.
+            raise Exception(("Actual signature length is not equal to "
+                             "the expected one: {} != {}").
+                            format(len(self.sig), self.sig_len))
+
+    def verify_signature(self):
+        """Verify self.sig against self.img_digest.
+
+        The public half is used when the loaded key is an RSA private
+        key. Logs an error and exits with status 1 if the signature is
+        invalid.
+        """
+        from cryptography import exceptions
+        from cryptography.hazmat.primitives.asymmetric import rsa
+        from cryptography.hazmat.primitives.asymmetric import utils
+
+        is_private = isinstance(self.key, rsa.RSAPrivateKey)
+        pkey = self.key.public_key() if is_private else self.key
+
+        try:
+            pkey.verify(self.sig, self.img_digest, self.__get_padding(),
+                        utils.Prehashed(self.chosen_hash))
+        except exceptions.InvalidSignature:
+            logger.error('Verification failed, ignoring given signature.')
+            sys.exit(1)
+
+    def verify_digest(self):
+        """Raise Exception if self.img_digest isn't the calculated digest."""
+        if self.img_digest != self.__calc_digest():
             raise Exception('Hash digest does not match')
+    def verify_uuid(self, uuid):
+        """Raise Exception unless uuid.bytes equals self.ta_uuid."""
+        if uuid.bytes == self.ta_uuid:
+            return
+        raise Exception('UUID does not match')
+
+    def write(self, outf):
+        """Write the signed TA to the file outf.
+
+        Written in order: shdr, digest, signature, UUID and version,
+        followed by ehdr, nonce, tag and ciphertext when an encryption
+        header is present, else by the image.
+        """
+        with open(outf, 'wb') as f:
+            for name in ('shdr', 'img_digest', 'sig', 'ta_uuid',
+                         'ta_version'):
+                f.write(getattr(self, name))
+
+            if hasattr(self, 'ehdr'):
+                tail = ('ehdr', 'nonce', 'tag', 'ciphertext')
+            else:
+                tail = ('img',)
+            for name in tail:
+                f.write(getattr(self, name))
+
+
+def load_ta_image(args):
+    """Return a BinaryImage set up from the parsed arguments.
+
+    The TA is encrypted when args.enc_key is given, else it is set up as
+    a bootstrap TA.
+    """
+    image = BinaryImage(args.inf, args.key)
+
+    if not args.enc_key:
+        image.set_bootstrap_ta(args.algo, args.uuid, args.ta_version)
+        return image
+
+    image.encrypt_ta(args.enc_key, args.enc_key_type,
+                     args.algo, args.uuid, args.ta_version)
+    return image
+
+
+def main():
+ import logging
+ import os
+
+ global logger
+ logging.basicConfig()
+ logger = logging.getLogger(os.path.basename(__file__))
+
+ args = get_args()
+
+    def sign_encrypt_ta():
+        """Load, sign and write the TA to args.outf."""
+        image = load_ta_image(args)
+        image.sign()
+        image.write(args.outf)
+        logger.info('Successfully signed application.')
+
+    def generate_digest():
+        """Write the base64 encoded digest of the TA to args.digf."""
+        # main() doesn't import base64 any longer, so do it here
+        import base64
+
+        ta_image = load_ta_image(args)
+        with open(args.digf, 'wb+') as digfile:
+            digfile.write(base64.b64encode(ta_image.img_digest))
+
+    def stitch_ta():
+        """Add the signature from args.sigf, verify it and write the TA."""
+        image = load_ta_image(args)
+        image.add_signature(args.sigf)
+        image.verify_signature()
+        image.write(args.outf)
+        logger.info('Successfully applied signature.')
+
+    def verify_ta():
+        """Parse the signed TA in args.inf and verify it.
+
+        An encrypted TA is decrypted first, which needs args.enc_key.
+        Signature, digest and UUID are then checked in that order.
+        """
+        image = BinaryImage(args.inf, args.key)
+        image.parse()
+        encrypted = hasattr(image, 'ciphertext')
+        if encrypted and args.enc_key is None:
+            logger.error('--enc_key needed to decrypt TA')
+            sys.exit(1)
+        if encrypted:
+            image.decrypt_ta(args.enc_key)
+        image.verify_signature()
+        image.verify_digest()
+        image.verify_uuid(args.uuid)
         logger.info('Trusted application is correctly verified.')
# dispatch command