imgtool: Update RSA code
Replace RSA code with one using the python 'cryptography' library. This
library is much more complete, and will make it easier to add support
for password-protected keys and for separate public keys.
There is, however, one significant consequence of this change: private
keys are now stored in PKCS#8 format instead of the raw format that was
used previously. PKCS#8 is a more modern format with several
advantages, including stronger password protection and the ability to
determine the key type when the key is read.
This tool will still support reading the old-style private keys, but
other tools that use these keys will need to be updated in order to work
with the new format.
The new code comes with unit tests that provide some basic sanity
checking.
Signed-off-by: David Brown <david.brown@linaro.org>
diff --git a/scripts/imgtool/keys/__init__.py b/scripts/imgtool/keys/__init__.py
index ee54a0f..8a2c50f 100644
--- a/scripts/imgtool/keys/__init__.py
+++ b/scripts/imgtool/keys/__init__.py
@@ -16,91 +16,11 @@
Cryptographic key management for imgtool.
"""
-from Crypto.Hash import SHA256
-from Crypto.PublicKey import RSA
-from Crypto.Signature import PKCS1_v1_5, PKCS1_PSS
-from ecdsa import SigningKey, NIST256p, util
-import hashlib
-from pyasn1.type import namedtype, univ
-from pyasn1.codec.der.encoder import encode
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
-# By default, we use RSA-PSS (PKCS 2.1). That can be overridden on
-# the command line to support the older (less secure) PKCS1.5
-sign_rsa_pss = True
-
-AUTOGEN_MESSAGE = "/* Autogenerated by imgtool.py, do not edit. */"
-
-class RSAPublicKey(univ.Sequence):
- componentType = namedtype.NamedTypes(
- namedtype.NamedType('modulus', univ.Integer()),
- namedtype.NamedType('publicExponent', univ.Integer()))
-
-class RSA2048():
- def __init__(self, key):
- """Construct an RSA2048 key with the given key data"""
- self.key = key
-
- @staticmethod
- def generate():
- return RSA2048(RSA.generate(2048))
-
- def export_private(self, path):
- with open(path, 'wb') as f:
- f.write(self.key.exportKey('PEM'))
-
- def get_public_bytes(self):
- node = RSAPublicKey()
- node['modulus'] = self.key.n
- node['publicExponent'] = self.key.e
- return bytearray(encode(node))
-
- def emit_c(self):
- print(AUTOGEN_MESSAGE)
- print("const unsigned char rsa_pub_key[] = {", end='')
- encoded = self.get_public_bytes()
- for count, b in enumerate(encoded):
- if count % 8 == 0:
- print("\n\t", end='')
- else:
- print(" ", end='')
- print("0x{:02x},".format(b), end='')
- print("\n};")
- print("const unsigned int rsa_pub_key_len = {};".format(len(encoded)))
-
- def emit_rust(self):
- print(AUTOGEN_MESSAGE)
- print("static RSA_PUB_KEY: &'static [u8] = &[", end='')
- encoded = self.get_public_bytes()
- for count, b in enumerate(encoded):
- if count % 8 == 0:
- print("\n ", end='')
- else:
- print(" ", end='')
- print("0x{:02x},".format(b), end='')
- print("\n];")
-
- def sig_type(self):
- """Return the type of this signature (as a string)"""
- if sign_rsa_pss:
- return "PKCS1_PSS_RSA2048_SHA256"
- else:
- return "PKCS15_RSA2048_SHA256"
-
- def sig_len(self):
- return 256
-
- def sig_tlv(self):
- return "RSA2048"
-
- def sign(self, payload):
- sha = SHA256.new(payload)
- if sign_rsa_pss:
- signer = PKCS1_PSS.new(self.key)
- else:
- signer = PKCS1_v1_5.new(self.key)
- signature = signer.sign(sha)
- assert len(signature) == self.sig_len()
- return signature
+from .rsa import RSA2048, RSA2048Public, RSAUsageError
class ECDSA256P1():
def __init__(self, key):
@@ -168,16 +88,66 @@
     def sig_tlv(self):
         return "ECDSA256"
 
-def load(path):
+class PasswordRequired(Exception):
+    """Raised to indicate that the key is password protected, but a
+    password was not specified.
+
+    Note: load() currently reports this condition by returning None
+    instead of raising this exception.
+    """
+    pass
+
+def load(path, passwd=None):
+    """Try loading a key from the given path.
+
+    The file is first parsed as a PEM private key, decrypted with
+    'passwd' (bytes, e.g. b'secret') when one is given.  If that fails
+    with a ValueError, the file is parsed as a PEM public key instead.
+
+    Returns an RSA2048 for a 2048-bit RSA private key, an RSA2048Public
+    for a 2048-bit RSA public key, or None if the key is password
+    protected but no password was specified.
+
+    Raises TypeError if a password was specified for a key that is not
+    encrypted, ValueError if the data is neither a loadable private key
+    nor a public key, and Exception for RSA keys that are not 2048 bits
+    or for any other key type.
+
+    NOTE(review): unlike the previous implementation, ECDSA keys are
+    not handled here and end up in the "Unknown key type" error.
+    """
     with open(path, 'rb') as f:
-        pem = f.read()
+        raw_pem = f.read()
     try:
-        key = RSA.importKey(pem)
-        if key.n.bit_length() != 2048:
-            raise Exception("Unsupported RSA bit length, only 2048 supported")
-        return RSA2048(key)
-    except ValueError:
-        key = SigningKey.from_pem(pem)
-        if key.curve.name != 'NIST256p':
-            raise Exception("Unsupported ECDSA curve")
-        return ECDSA256P1(key)
+        pk = serialization.load_pem_private_key(
+            raw_pem,
+            password=passwd,
+            backend=default_backend())
+    # This is a bit nonsensical of an exception, but it is what
+    # cryptography raises both when a password is needed and when one
+    # was given for a key that is not encrypted.  Only the former means
+    # "ask for a password"; report the latter to the caller.
+    except TypeError:
+        if passwd is None:
+            return None
+        raise
+    except ValueError as priv_err:
+        # This seems to happen if the key is a public key (or if the
+        # password was wrong), so try loading it as a public key.  If
+        # that fails too, re-raise the private-key error, which carries
+        # the more useful message (e.g. a bad password).
+        try:
+            pk = serialization.load_pem_public_key(
+                raw_pem,
+                backend=default_backend())
+        except ValueError:
+            raise priv_err
+
+    if isinstance(pk, (RSAPrivateKey, RSAPublicKey)):
+        # key_size is an int, so it must be converted for the message.
+        if pk.key_size != 2048:
+            raise Exception("Unsupported RSA key size: " + str(pk.key_size))
+        if isinstance(pk, RSAPrivateKey):
+            return RSA2048(pk)
+        return RSA2048Public(pk)
+    raise Exception("Unknown key type: " + str(type(pk)))
diff --git a/scripts/imgtool/keys/general.py b/scripts/imgtool/keys/general.py
new file mode 100644
index 0000000..3ba34cb
--- /dev/null
+++ b/scripts/imgtool/keys/general.py
@@ -0,0 +1,56 @@
+"""General key class."""
+
+import sys
+
+AUTOGEN_MESSAGE = "/* Autogenerated by imgtool.py, do not edit. */"
+
+class KeyClass(object):
+    """Base class providing the C and Rust public-key emitters.
+
+    Subclasses must provide shortname(), the prefix for the generated
+    identifiers (e.g. "rsa"), and get_public_bytes(), the encoded
+    public key whose bytes are written out.
+    """
+
+    def _public_emit(self, header, trailer, indent, file=None, len_format=None):
+        """Write the public key to 'file' as a source-code byte array.
+
+        'header' opens the array, each line of up to 8 bytes starts with
+        'indent', and 'trailer' closes the array.  If 'len_format' is
+        given, it is formatted with the key length in bytes and written
+        after the array.
+        """
+        # 'file' defaults to None rather than sys.stdout: a sys.stdout
+        # default is bound once, at import time, so later redirection of
+        # sys.stdout (e.g. contextlib.redirect_stdout) would be ignored.
+        # print() itself treats file=None as the current sys.stdout.
+        print(AUTOGEN_MESSAGE, file=file)
+        print(header, end='', file=file)
+        encoded = self.get_public_bytes()
+        for count, b in enumerate(encoded):
+            # Start a new, indented line every 8 bytes.
+            if count % 8 == 0:
+                print("\n" + indent, end='', file=file)
+            else:
+                print(" ", end='', file=file)
+            print("0x{:02x},".format(b), end='', file=file)
+        print("\n" + trailer, file=file)
+        if len_format is not None:
+            print(len_format.format(len(encoded)), file=file)
+
+    def emit_c(self, file=None):
+        """Write the public key as a C array plus a length variable."""
+        self._public_emit(
+            header="const unsigned char {}_pub_key[] = {{".format(self.shortname()),
+            trailer="};",
+            indent=" ",
+            len_format="const unsigned int {}_pub_key_len = {{}};".format(self.shortname()),
+            file=file)
+
+    def emit_rust(self, file=None):
+        """Write the public key as a Rust static byte slice."""
+        self._public_emit(
+            header="static {}_PUB_KEY: &'static [u8] = &[".format(self.shortname().upper()),
+            trailer="];",
+            indent=" ",
+            file=file)
diff --git a/scripts/imgtool/keys/rsa.py b/scripts/imgtool/keys/rsa.py
new file mode 100644
index 0000000..8d5d048
--- /dev/null
+++ b/scripts/imgtool/keys/rsa.py
@@ -0,0 +1,116 @@
+"""
+RSA Key management
+"""
+
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import rsa
+from cryptography.hazmat.primitives.asymmetric.padding import PSS, MGF1
+from cryptography.hazmat.primitives.hashes import SHA256
+
+from .general import KeyClass
+
+class RSAUsageError(Exception):
+    """Raised when an operation requires the private key, but only the
+    public key is available."""
+    pass
+
+class RSA2048Public(KeyClass):
+    """
+    Wrapper around the public half of a 2048-bit RSA key.
+
+    The public key can only do a few operations; those that need the
+    private key raise RSAUsageError.
+    """
+
+    def __init__(self, key):
+        """The key should be a public key from cryptography"""
+        self.key = key
+
+    def shortname(self):
+        """Prefix for the identifiers in emitted C and Rust code."""
+        return "rsa"
+
+    def _unsupported(self, name):
+        raise RSAUsageError("Operation {} requires private key".format(name))
+
+    def _get_public(self):
+        """Return the cryptography public key object."""
+        return self.key
+
+    def get_public_bytes(self):
+        # The key embedded into MCUboot is in PKCS1 format.
+        return self._get_public().public_bytes(
+            encoding=serialization.Encoding.DER,
+            format=serialization.PublicFormat.PKCS1)
+
+    def export_private(self, path, passwd=None):
+        self._unsupported('export_private')
+
+    def export_public(self, path):
+        """Write the public key to the given file."""
+        pem = self._get_public().public_bytes(
+            encoding=serialization.Encoding.PEM,
+            format=serialization.PublicFormat.SubjectPublicKeyInfo)
+        with open(path, 'wb') as f:
+            f.write(pem)
+
+    def sig_type(self):
+        return "PKCS1_PSS_RSA2048_SHA256"
+
+    def sig_tlv(self):
+        return "RSA2048"
+
+    def sig_len(self):
+        """Length in bytes of a signature: 2048 bits / 8 = 256."""
+        return 256
+
+    def sign(self, payload):
+        self._unsupported('sign')
+
+class RSA2048(RSA2048Public):
+    """
+    Wrapper around a 2048-bit RSA private key, with imgtool support.
+    """
+
+    def __init__(self, key):
+        """The key should be a private key from cryptography"""
+        self.key = key
+
+    @staticmethod
+    def generate():
+        """Generate a new 2048-bit RSA key with public exponent 65537."""
+        pk = rsa.generate_private_key(
+            public_exponent=65537,
+            key_size=2048,
+            backend=default_backend())
+        return RSA2048(pk)
+
+    def _get_public(self):
+        return self.key.public_key()
+
+    def export_private(self, path, passwd=None):
+        """Write the private key to the given file as PKCS#8 PEM,
+        encrypting it with the optional password (bytes)."""
+        if passwd is None:
+            enc = serialization.NoEncryption()
+        else:
+            enc = serialization.BestAvailableEncryption(passwd)
+        pem = self.key.private_bytes(
+            encoding=serialization.Encoding.PEM,
+            format=serialization.PrivateFormat.PKCS8,
+            encryption_algorithm=enc)
+        with open(path, 'wb') as f:
+            f.write(pem)
+
+    def sign(self, payload):
+        """Sign 'payload' using RSA-PSS with SHA256 and MGF1(SHA256)."""
+        # The salt length is fixed at 32 bytes, the SHA256 digest size.
+        # That is what the previous PyCrypto PKCS1_PSS signer used by
+        # default, and MCUboot's verification code only accepts a salt
+        # as long as the hash.  PSS.MAX_LENGTH would instead pick the
+        # largest salt the key allows, which MCUboot would reject.
+        return self.key.sign(
+            data=payload,
+            padding=PSS(mgf=MGF1(SHA256()), salt_length=32),
+            algorithm=SHA256())
diff --git a/scripts/imgtool/keys/rsa_test.py b/scripts/imgtool/keys/rsa_test.py
new file mode 100644
index 0000000..0d3dfd8
--- /dev/null
+++ b/scripts/imgtool/keys/rsa_test.py
@@ -0,0 +1,109 @@
+"""
+Tests for RSA keys
+"""
+
+import io
+import os
+import sys
+import tempfile
+import unittest
+
+from cryptography.exceptions import InvalidSignature
+from cryptography.hazmat.primitives.asymmetric.padding import PSS, MGF1
+from cryptography.hazmat.primitives.hashes import SHA256
+
+# Setup sys path so 'imgtool' is in it.
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
+
+from imgtool.keys import load, RSA2048, RSAUsageError
+
+# MCUboot verifies RSA-PSS signatures with a salt as long as the
+# SHA256 hash, so that is the only padding the tests accept.
+MCUBOOT_PSS = PSS(mgf=MGF1(SHA256()), salt_length=32)
+
+class KeyGeneration(unittest.TestCase):
+
+    def setUp(self):
+        self.test_dir = tempfile.TemporaryDirectory()
+
+    def tname(self, base):
+        return os.path.join(self.test_dir.name, base)
+
+    def tearDown(self):
+        self.test_dir.cleanup()
+
+    def test_keygen(self):
+        name1 = self.tname("keygen.pem")
+        k = RSA2048.generate()
+        k.export_private(name1, b'secret')
+
+        # Try loading the key without a password.
+        self.assertIsNone(load(name1))
+
+        k2 = load(name1, b'secret')
+
+        pubname = self.tname('keygen-pub.pem')
+        k2.export_public(pubname)
+        pk2 = load(pubname)
+
+        # We should be able to export the public key from the loaded
+        # public key, but not the private key, and it cannot sign.
+        pk2.export_public(self.tname('keygen-pub2.pem'))
+        self.assertRaises(RSAUsageError, pk2.export_private, self.tname('keygen-priv2.pem'))
+        self.assertRaises(RSAUsageError, pk2.sign, b'This is the message')
+        self.assertEqual(pk2.sig_len(), 256)
+
+    def test_emit(self):
+        """Basic sanity check on the code emitters."""
+        k = RSA2048.generate()
+
+        ccode = io.StringIO()
+        k.emit_c(ccode)
+        self.assertIn("rsa_pub_key", ccode.getvalue())
+        self.assertIn("rsa_pub_key_len", ccode.getvalue())
+
+        rustcode = io.StringIO()
+        k.emit_rust(rustcode)
+        self.assertIn("RSA_PUB_KEY", rustcode.getvalue())
+
+    def test_emit_pub(self):
+        """Basic sanity check on the code emitters, from public key."""
+        pubname = self.tname("public.pem")
+        k = RSA2048.generate()
+        k.export_public(pubname)
+
+        k2 = load(pubname)
+
+        ccode = io.StringIO()
+        k2.emit_c(ccode)
+        self.assertIn("rsa_pub_key", ccode.getvalue())
+        self.assertIn("rsa_pub_key_len", ccode.getvalue())
+
+        rustcode = io.StringIO()
+        k2.emit_rust(rustcode)
+        self.assertIn("RSA_PUB_KEY", rustcode.getvalue())
+
+    def test_sig(self):
+        k = RSA2048.generate()
+        buf = b'This is the message'
+        sig = k.sign(buf)
+        self.assertEqual(len(sig), k.sig_len())
+
+        # The code doesn't have any verification, so verify this
+        # manually, using the salt length MCUboot requires.
+        k.key.public_key().verify(
+            signature=sig,
+            data=buf,
+            padding=MCUBOOT_PSS,
+            algorithm=SHA256())
+
+        # Modify the message to make sure the signature fails.
+        self.assertRaises(InvalidSignature,
+                          k.key.public_key().verify,
+                          signature=sig,
+                          data=b'This is thE message',
+                          padding=MCUBOOT_PSS,
+                          algorithm=SHA256())
+
+if __name__ == '__main__':
+    unittest.main()