blob: b102fb224d2db847c82bf75f1f5cb0efd6e634b1 [file] [log] [blame]
Raef Coles59cf5d82024-12-09 15:41:13 +00001#!/usr/bin/env python3
2#-------------------------------------------------------------------------------
3# SPDX-FileCopyrightText: Copyright The TrustedFirmware-M Contributors
4#
5# SPDX-License-Identifier: BSD-3-Clause
6#
7#-------------------------------------------------------------------------------
8
9from cryptography.hazmat.primitives.asymmetric import utils, ec
10from cryptography.hazmat.primitives import hashes
11from cryptography.hazmat.primitives.serialization import load_pem_private_key, Encoding, PublicFormat
12
13import pyhsslms
14
15from cryptography.hazmat.primitives.ciphers.aead import AESCCM
16import secrets
17
18import argparse
19import logging
20logger = logging.getLogger("TF-M")
21from arg_utils import *
22
def _asn1_sig_to_raw(sig: bytes, curve: ec.EllipticCurve) -> bytes:
    """Convert a DER-encoded ECDSA signature into the raw ``r || s`` form.

    ``r`` and ``s`` are each written big-endian and left-padded to
    ``ceil(curve.key_size / 8)`` bytes, so the result always has a fixed
    length of twice that size.

    :param sig: DER-encoded signature, as returned by cryptography's
                ECDSA ``sign()``.
    :param curve: Curve the signature was produced on; only its
                  ``key_size`` is used.
    :returns: Raw concatenation of ``r`` and ``s``.
    """
    # Round up rather than floor: for curves whose size is not a multiple of
    # 8 bits (e.g. SECP521R1, key_size 521) floor division yields a field one
    # byte too short and int.to_bytes() raises OverflowError.
    point_size = (curve.key_size + 7) // 8
    r, s = utils.decode_dss_signature(sig)
    return r.to_bytes(point_size, byteorder="big") + s.to_bytes(point_size, byteorder="big")
27
def _pubkey_ecdsa(key : str) -> bytes:
    """Return the public half of a PEM private key as DER SubjectPublicKeyInfo.

    :param key: Path to an unencrypted PEM-encoded private key file.
    :returns: DER-encoded SubjectPublicKeyInfo of the matching public key.
    """
    with open(key, "rb") as key_file:
        pem_bytes = key_file.read()

    public_key = load_pem_private_key(pem_bytes, password=None).public_key()
    return public_key.public_bytes(Encoding.DER, PublicFormat.SubjectPublicKeyInfo)
34
def _pubkey_lms(key : str) -> bytes:
    """Return the serialized ``hss_pub.pub`` component of a pyhsslms public key.

    The last four characters of ``key`` are stripped before the path is
    handed to pyhsslms (presumably a file extension such as ``.prv``, with
    pyhsslms adding its own suffix -- confirm against pyhsslms naming).

    :param key: Path to the key file, including its 4-character extension.
    :returns: Serialized public key bytes.
    """
    key_stem = key[:-4]
    hss_public = pyhsslms.HssLmsPublicKey(key_stem).hss_pub
    return hss_public.pub.serialize()
37
# sign_alg name -> function taking a key file path and returning public key bytes.
pubkey_algs = dict(
    ECDSA=_pubkey_ecdsa,
    LMS=_pubkey_lms,
)
42
def _sign_ecdsa(data : bytes,
                key : str,
                curve : ec.EllipticCurve = None,
                hash_alg : hashes.HashAlgorithm = None,
                **kwargs,
                ) -> bytes:
    """Sign ``data`` with an ECDSA private key and return a raw ``r || s`` signature.

    :param data: Message to sign.
    :param key: Path to an unencrypted PEM-encoded EC private key file.
    :param curve: Optional curve the key is required to use (compared by name).
    :param hash_alg: Hash algorithm class (e.g. ``hashes.SHA256``). When not
                     given, a default is picked from the key's curve
                     (SECP256R1 -> SHA256, SECP384R1 -> SHA384).
    :param kwargs: Ignored; accepted so every signer shares one call shape.
    :returns: Raw signature as produced by ``_asn1_sig_to_raw``.
    :raises AssertionError: if ``curve`` is given and does not match the key.
    :raises ValueError: if no ``hash_alg`` is given and the key's curve has no
                        default hash.
    """
    with open(key, "rb") as f:
        key_data = f.read()

    priv_key = load_pem_private_key(key_data, password=None)

    hash_alg_defaults = {
        ec.SECP256R1: hashes.SHA256,
        ec.SECP384R1: hashes.SHA384,
    }

    if curve:
        assert curve.name == priv_key.curve.name, "Key curve {} does not match required curve {}".format(priv_key.curve.name, curve.name)

    if not hash_alg:
        try:
            hash_alg = hash_alg_defaults[type(priv_key.curve)]
        except KeyError:
            raise ValueError(
                "No default hash algorithm for curve {}; specify one explicitly".format(priv_key.curve.name)
            ) from None

    logger.info("Signing with ECDSA key {} (curve {}) and hash {}".format(key, priv_key.curve.name, hash_alg.name))

    # Hash the message once. The digest is needed for the log line, so sign it
    # directly (Prehashed) instead of letting sign() hash `data` a second time.
    digest = hashes.Hash(hash_alg())
    digest.update(data)
    data_hash = digest.finalize()
    logger.info("Signing hash {}".format(data_hash.hex()))

    asn1_sig = priv_key.sign(data_hash, ec.ECDSA(utils.Prehashed(hash_alg())))

    return _asn1_sig_to_raw(asn1_sig, priv_key.curve)
74
def _sign_lms(data : bytes,
              key : str,
              **kwargs : {},
              ) -> bytes:
    """Sign ``data`` with a pyhsslms private key.

    The last four characters of ``key`` are stripped before the path is
    handed to pyhsslms (presumably a file extension -- confirm).

    :param data: Message to sign.
    :param key: Path to the key file, including its 4-character extension.
    :param kwargs: Ignored; accepted so every signer shares one call shape.
    :returns: The pyhsslms signature with its first 4 bytes removed.
    """
    signer = pyhsslms.HssLmsPrivateKey(key[:-4])
    logger.info("Signing with LMS key {}".format(key))
    full_signature = signer.sign(data)
    # Drop the 4-byte prefix. This is presumably the HSS level/Nspk count
    # field from RFC 8554, leaving a bare LMS signature -- confirm.
    return full_signature[4:]
82
def _sign_aes_ccm(data : bytes,
                  key : str,
                  iv : bytes = None,
                  **kwargs,
                  ) -> bytes:
    """Authenticate ``data`` with AES-CCM and return the CCM output.

    ``data`` is supplied as associated data with an empty plaintext, so the
    value returned by ``AESCCM.encrypt`` is only the authentication tag.

    :param data: Data to authenticate.
    :param key: Path to a file holding the raw AES key bytes.
    :param iv: CCM nonce. If empty or None, a random 12-byte nonce is generated.
    :param kwargs: Ignored; accepted so every signer shares one call shape.
    :returns: The AES-CCM tag over ``data``.
    """
    with open(key, "rb") as f:
        key_data = f.read()

    if not iv:
        iv = secrets.token_bytes(12)
        # The generated nonce is not part of the return value and is needed
        # to recompute the tag, so at least make it recoverable from the log.
        logger.info("Generated AES-CCM nonce {}".format(iv.hex()))

    # Empty plaintext: CCM then only authenticates `data` (passed as AAD).
    # The previous code referenced an undefined name `plaintext` here.
    return AESCCM(key_data).encrypt(iv, b"", data)
95
def sign_data(data : bytes,
              sign_key : str,
              sign_alg : str,
              sign_hash_alg : str = None,
              **kwargs,
              ) -> bytes:
    """Sign ``data`` with the signer registered under ``sign_alg`` in ``sign_algs``.

    :param data: Message to sign.
    :param sign_key: Path to the signing key file; must be non-empty.
    :param sign_alg: Key into ``sign_algs`` (e.g. ``"ECDSA"``, ``"LMS"``).
    :param sign_hash_alg: Forwarded to the signer as ``hash_alg``.
    :param kwargs: Forwarded unchanged to the signer.
    :returns: Signature bytes from the selected signer.
    :raises AssertionError: if ``sign_key`` or ``sign_alg`` is falsy.
    :raises KeyError: if ``sign_alg`` is not registered in ``sign_algs``.
    """
    assert sign_key
    assert sign_alg

    signer = sign_algs[sign_alg]
    return signer(data=data, key=sign_key, hash_alg=sign_hash_alg, **kwargs)
111
def get_pubkey(sign_key : str,
               sign_alg : str,
               **kwargs
               ) -> bytes:
    """Return the public key bytes for ``sign_key`` using ``pubkey_algs[sign_alg]``.

    :param sign_key: Path to the key file.
    :param sign_alg: Key into ``pubkey_algs`` (e.g. ``"ECDSA"``, ``"LMS"``).
    :param kwargs: Ignored.
    :returns: Public key bytes produced by the selected extractor.
    :raises KeyError: if ``sign_alg`` is not registered in ``pubkey_algs``.
    """
    extract_pubkey = pubkey_algs[sign_alg]
    return extract_pubkey(sign_key)
117
# sign_alg name -> signer function, dispatched by sign_data() and offered as
# the --sign_alg choices. _sign_aes_ccm is not listed, so it is not selectable.
sign_algs = dict(
    ECDSA=_sign_ecdsa,
    LMS=_sign_lms,
)
122
def add_arguments(parser : argparse.ArgumentParser,
                  prefix : str = "",
                  required : bool = True,
                  ) -> None:
    """Register the signing options on ``parser``.

    :param parser: Parser to extend.
    :param prefix: Prefix passed to ``add_prefixed_argument`` for each option.
    :param required: Whether ``sign_key`` and ``sign_alg`` are mandatory;
                     ``sign_hash_alg`` is always optional.
    """
    option_specs = (
        ("sign_key", dict(help="signing key input file",
                          type=str, required=required)),
        ("sign_alg", dict(help="signing algorithm",
                          choices=sign_algs.keys(), required=required)),
        ("sign_hash_alg", dict(help="signing hash algorithm",
                               type=arg_type_hash, required=False)),
    )
    for option_name, option_kwargs in option_specs:
        add_prefixed_argument(parser, option_name, prefix, **option_kwargs)
133
def parse_args(args : argparse.Namespace,
               prefix : str = "",
               ) -> dict:
    """Collect the signing options registered by ``add_arguments`` into a dict.

    :param args: Parsed command-line namespace.
    :param prefix: Same prefix that was given to ``add_arguments``.
    :returns: Result of ``parse_args_automatically`` for ``sign_key``,
              ``sign_alg`` and ``sign_hash_alg``.
    """
    return parse_args_automatically(args, ["sign_key", "sign_alg", "sign_hash_alg"], prefix)
139
script_description = """
Sign some data.
"""

# Command-line entry point: sign --data with the configured key and print the
# signature as hex on stdout.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(allow_abbrev=False,
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                     description=script_description)
    # NOTE(review): _levelToName is a private logging attribute.
    parser.add_argument("--log_level", help="log level", required=False, default="ERROR", choices=logging._levelToName.values())
    parser.add_argument("--data", help="data to sign", type=arg_type_bytes, required=True)

    add_arguments(parser, required=True)

    args = parser.parse_args()

    # With no handler configured, "TF-M" records fall through to
    # logging.lastResort, which only emits WARNING and above, so
    # --log_level INFO/DEBUG would show nothing. basicConfig() attaches a
    # stderr handler to the root logger (a no-op if one already exists).
    # The TF-M logger's own level still does the filtering, and stdout
    # stays reserved for the signature.
    logging.basicConfig()
    logger.setLevel(args.log_level)

    config = parse_args(args)
    config |= parse_args_automatically(args, ["data"])

    print(sign_data(**config).hex())