Refactor token and map parsing
The aim of this change is to make it possible to verify nested EATs.
This requires finer-grained control over how the token structure is
parsed, as CBOR envelopes can now be present inside the tree.
This change therefore makes parsing of the token and the map a recursive
operation, calling the necessary methods of the objects at each level.
Change-Id: I4c1e29deae7b238f2d82a73bd95c533f89492d40
Signed-off-by: Mate Toth-Pal <mate.toth-pal@arm.com>
diff --git a/iat-verifier/iatverifier/util.py b/iat-verifier/iatverifier/util.py
index 39af2e3..12f16b8 100644
--- a/iat-verifier/iatverifier/util.py
+++ b/iat-verifier/iatverifier/util.py
@@ -5,155 +5,67 @@
#
# -----------------------------------------------------------------------------
+"""Helper utilities for CLI tools and tests"""
+
from collections.abc import Iterable
from copy import deepcopy
import logging
import base64
-import cbor2
import yaml
from ecdsa import SigningKey, VerifyingKey
-from pycose.attributes import CoseAttrs
-from pycose.sign1message import Sign1Message
-from pycose.mac0message import Mac0Message
from iatverifier.attest_token_verifier import AttestationTokenVerifier
-from cbor2 import CBORTag
+from cbor2 import CBOREncoder
_logger = logging.getLogger("util")
-def sign_eat(token, verifier, *, add_p_header, key=None):
- protected_header = CoseAttrs()
- if add_p_header and key:
- protected_header['alg'] = verifier.cose_alg
- signed_msg = Sign1Message(p_header=protected_header)
- signed_msg.payload = token
- if key:
- signed_msg.key = key
- signed_msg.signature = signed_msg.compute_signature(alg=verifier.cose_alg)
- return signed_msg.encode()
+_known_curves = {
+ "NIST256p": AttestationTokenVerifier.COSE_ALG_ES256,
+ "NIST384p": AttestationTokenVerifier.COSE_ALG_ES384,
+ "NIST521p": AttestationTokenVerifier.COSE_ALG_ES512,
+}
+
+def convert_map_to_token(token_map, verifier, wfh, *, add_p_header, name_as_key, parse_raw_value):
+ """
+ Convert a map to a token and write the result to a file.
+ """
+ encoder = CBOREncoder(wfh)
+ verifier.convert_map_to_token(
+ encoder,
+ token_map,
+ add_p_header=add_p_header,
+ name_as_key=name_as_key,
+ parse_raw_value=parse_raw_value,
+ root=True)
-def hmac_eat(token, verifier, *, add_p_header, key=None):
- protected_header = CoseAttrs()
- if add_p_header and key:
- protected_header['alg'] = verifier.cose_alg
- hmac_msg = Mac0Message(payload=token, key=key, p_header=protected_header)
- hmac_msg.compute_auth_tag(alg=verifier.cose_alg)
- return hmac_msg.encode()
-
-
-def convert_map_to_token_files(mapfile, keyfile, verifier, outfile, add_p_header):
- token_map = read_token_map(mapfile)
-
- if verifier.method == 'sign':
- with open(keyfile) as fh:
- signing_key = SigningKey.from_pem(fh.read())
+def read_token_map(file):
+ """
+ Read a yaml file and return a map
+ """
+ if hasattr(file, 'read'):
+ raw = yaml.safe_load(file)
else:
- with open(keyfile, 'rb') as fh:
- signing_key = fh.read()
+ with open(file, encoding="utf8") as file_obj:
+ raw = yaml.safe_load(file_obj)
- with open(outfile, 'wb') as wfh:
- convert_map_to_token(token_map, signing_key, verifier, wfh, add_p_header)
+ return raw
-def convert_map_to_token(token_map, signing_key, verifier, wfh, add_p_header):
- wrapping_tag = verifier.get_wrapping_tag()
- if wrapping_tag is not None:
- token = cbor2.dumps(CBORTag(wrapping_tag, token_map))
- else:
- token = cbor2.dumps(token_map)
+def recursive_bytes_to_strings(token, in_place=False):
+ """
+ Transform the map in 'token' by changing bytes to base64 encoded form.
- if verifier.method == AttestationTokenVerifier.SIGN_METHOD_RAW:
- signed_token = token
- elif verifier.method == AttestationTokenVerifier.SIGN_METHOD_SIGN1:
- signed_token = sign_eat(token, verifier, add_p_header=add_p_header, key=signing_key)
- elif verifier.method == AttestationTokenVerifier.SIGN_METHOD_MAC0:
- signed_token = hmac_eat(token, verifier, add_p_header=add_p_header, key=signing_key)
- else:
- err_msg = 'Unexpected method "{}"; must be one of: raw, sign, mac'
- raise ValueError(err_msg.format(method))
-
- wfh.write(signed_token)
-
-
-def convert_token_to_map(raw_data, verifier):
- payload = get_cose_payload(raw_data, verifier, check_p_header=False)
- token_map = cbor2.loads(payload)
- return _relabel_keys(token_map)
-
-
-def read_token_map(f):
- if hasattr(f, 'read'):
- raw = yaml.safe_load(f)
- else:
- with open(f) as fh:
- raw = yaml.safe_load(fh)
-
- return _parse_raw_token(raw)
-
-
-def extract_iat_from_cose(keyfile, tokenfile, verifier, check_p_header):
- key = read_keyfile(keyfile, verifier.method)
-
- try:
- with open(tokenfile, 'rb') as wfh:
- return get_cose_payload(wfh.read(), verifier, check_p_header=check_p_header, key=key)
- except Exception as e:
- msg = 'Bad COSE file "{}": {}'
- raise ValueError(msg.format(tokenfile, e))
-
-
-def get_cose_payload(cose, verifier, *, check_p_header, key=None):
- if verifier.method == AttestationTokenVerifier.SIGN_METHOD_SIGN1:
- return get_cose_sign1_payload(cose, verifier, check_p_header=check_p_header, key=key)
- if verifier.method == AttestationTokenVerifier.SIGN_METHOD_MAC0:
- return get_cose_mac0_payload(cose, verifier, check_p_header=check_p_header, key=key)
- err_msg = 'Unexpected method "{}"; must be one of: sign, mac'
- raise ValueError(err_msg.format(verifier.method))
-
-def parse_protected_header(msg, alg):
- try:
- msg_alg = msg.protected_header['alg']
- except KeyError:
- raise ValueError('Missing alg from protected header (expected {})'.format(alg))
- if alg != msg_alg:
- raise ValueError('Unexpected alg in protected header (expected {} instead of {})'.format(alg, msg_alg))
-
-def get_cose_sign1_payload(cose, verifier, *, check_p_header, key=None):
- msg = Sign1Message.decode(cose)
- if key:
- if check_p_header:
- parse_protected_header(msg, verifier.cose_alg)
- msg.key = key
- msg.signature = msg.signers
- try:
- msg.verify_signature(alg=verifier.cose_alg)
- except Exception as e:
- raise ValueError('Bad signature ({})'.format(e))
- return msg.payload
-
-
-def get_cose_mac0_payload(cose, verifier, *, check_p_header, key=None):
- msg = Mac0Message.decode(cose)
- if key:
- if check_p_header:
- parse_protected_header(msg, verifier.cose_alg)
- msg.key = key
- try:
- msg.verify_auth_tag(alg=verifier.cose_alg)
- except Exception as e:
- raise ValueError('Bad signature ({})'.format(e))
- return msg.payload
-
-def recursive_bytes_to_strings(d, in_place=False):
+ If 'in_place' is True, 'token' is modified in place; otherwise a new map is returned.
+ """
if in_place:
- result = d
+ result = token
else:
- result = deepcopy(d)
+ result = deepcopy(token)
if hasattr(result, 'items'):
- for k, v in result.items():
- result[k] = recursive_bytes_to_strings(v, in_place=True)
+ for key, value in result.items():
+ result[key] = recursive_bytes_to_strings(value, in_place=True)
elif (isinstance(result, Iterable) and
not isinstance(result, (str, bytes))):
result = [recursive_bytes_to_strings(r, in_place=True)
@@ -165,105 +77,45 @@
def read_keyfile(keyfile, method=AttestationTokenVerifier.SIGN_METHOD_SIGN1):
+ """
+ Read a keyfile and return the key
+ """
if keyfile:
if method == AttestationTokenVerifier.SIGN_METHOD_SIGN1:
- return read_sign1_key(keyfile)
+ return _read_sign1_key(keyfile)
if method == AttestationTokenVerifier.SIGN_METHOD_MAC0:
- return read_hmac_key(keyfile)
+ return _read_hmac_key(keyfile)
err_msg = 'Unexpected method "{}"; must be one of: sign, mac'
raise ValueError(err_msg.format(method))
return None
+def get_cose_alg_from_key(key):
+ """Extract the algorithm from the key if possible
-def read_sign1_key(keyfile):
+ Returns the signature algorithm ID defined by COSE
+ """
+ if not hasattr(key, "curve"):
+ raise ValueError("Key has no curve specified in it.")
+ return _known_curves[key.curve.name]
+
+def _read_sign1_key(keyfile):
+ with open(keyfile, 'rb') as file_obj:
+ raw_key = file_obj.read()
try:
- key = SigningKey.from_pem(open(keyfile, 'rb').read())
- except Exception as e:
- signing_key_error = str(e)
+ key = SigningKey.from_pem(raw_key)
+ except Exception as exc:
+ signing_key_error = str(exc)
try:
- key = VerifyingKey.from_pem(open(keyfile, 'rb').read())
- except Exception as e:
- verifying_key_error = str(e)
+ key = VerifyingKey.from_pem(raw_key)
+ except Exception as vexc:
+ verifying_key_error = str(vexc)
msg = 'Bad key file "{}":\n\tpubkey error: {}\n\tprikey error: {}'
- raise ValueError(msg.format(keyfile, verifying_key_error, signing_key_error))
+ raise ValueError(msg.format(keyfile, verifying_key_error, signing_key_error)) from vexc
return key
-def read_hmac_key(keyfile):
+def _read_hmac_key(keyfile):
return open(keyfile, 'rb').read()
-
-def _get_known_claims():
- if logging.DEBUG >= logging.root.level:
- _logger.debug("Known claims are:")
- for _, claim_class in AttestationTokenVerifier.all_known_claims.items():
- _logger.debug(f" {claim_class.get_claim_key():8} '{claim_class.get_claim_name()}'")
- for _, claim_class in AttestationTokenVerifier.all_known_claims.items():
- yield claim_class
-
-def _parse_raw_token(raw):
- result = {}
- field_names = {cc.get_claim_name(): cc for cc in _get_known_claims()}
- for raw_key, raw_value in raw.items():
- if isinstance(raw_key, int):
- key = raw_key
- else:
- field_name = raw_key.upper()
- try:
- claim_class = field_names[field_name]
- key = claim_class.get_claim_key()
- except KeyError:
- msg = 'Unknown field "{}" in token.'.format(field_name)
- raise ValueError(msg)
-
- if hasattr(raw_value, 'items'):
- value = _parse_raw_token(raw_value)
- elif (isinstance(raw_value, Iterable) and
- not isinstance(raw_value, (str, bytes))):
- value = []
- for v in raw_value:
- if hasattr(v, 'items'):
- value.append(_parse_raw_token(v))
- else:
- value.append(claim_class.parse_raw(v))
- else:
- value = claim_class.parse_raw(raw_value)
-
- result[key] = value
-
- return result
-
-def _format_value(names, key, value):
- if key in names:
- value = names[key].get_formatted_value(value)
- return value
-
-def _relabel_keys(token_map):
- result = {}
- while not hasattr(token_map, 'items'):
- # TODO: token map is not a map. We are assuming that it is a tag
- token_map = token_map.value
- names = {v.get_claim_key(): v for v in _get_known_claims()}
- for key, value in token_map.items():
- if hasattr(value, 'items'):
- value = _relabel_keys(value)
- elif (isinstance(value, Iterable) and
- not isinstance(value, (str, bytes))):
- new_value = []
- for item in value:
- if hasattr(item, 'items'):
- new_value.append(_relabel_keys(item))
- else:
- new_value.append(_format_value(names, key, item))
- value = new_value
- else:
- value = _format_value(names, key, value)
-
- if key in names:
- new_key = names[key].get_claim_name().lower()
- else:
- new_key = key
- result[new_key] = value
- return result