Add option for COSE protected header

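Add an option to generate and check the 'alg' parameter in the protected
header of the COSE envelope. Also rename get_cose_mac0_pyload to
get_cose_mac0_payload and fix an undefined variable in the error path of
get_cose_payload.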

Change-Id: I5d298c5a5bb90ba32443c731d75400169c06de1c
Signed-off-by: Mate Toth-Pal <mate.toth-pal@arm.com>
diff --git a/iat-verifier/iatverifier/util.py b/iat-verifier/iatverifier/util.py
index 1025355..9418721 100644
--- a/iat-verifier/iatverifier/util.py
+++ b/iat-verifier/iatverifier/util.py
@@ -13,6 +13,7 @@
 import cbor2
 import yaml
 from ecdsa import SigningKey, VerifyingKey
+from pycose.attributes import CoseAttrs
 from pycose.sign1message import Sign1Message
 from pycose.mac0message import Mac0Message
 from iatverifier.verifiers import AttestationTokenVerifier
@@ -20,8 +21,11 @@
 
 _logger = logging.getLogger("util")
 
-def sign_eat(token, verifier, key=None):
-    signed_msg = Sign1Message()
+def sign_eat(token, verifier, *, add_p_header, key=None):
+    protected_header = CoseAttrs()
+    if add_p_header and key:
+        protected_header['alg'] = verifier.cose_alg
+    signed_msg = Sign1Message(p_header=protected_header)
     signed_msg.payload = token
     if key:
         signed_msg.key = key
@@ -29,13 +33,16 @@
     return signed_msg.encode()
 
 
-def hmac_eat(token, verifier, key=None):
-    hmac_msg = Mac0Message(payload=token, key=key)
+def hmac_eat(token, verifier, *, add_p_header, key=None):
+    p_header = CoseAttrs()
+    if key and add_p_header:
+        p_header['alg'] = verifier.cose_alg
+    hmac_msg = Mac0Message(p_header=p_header, payload=token, key=key)
     hmac_msg.compute_auth_tag(alg=verifier.cose_alg)
     return hmac_msg.encode()
 
 
-def convert_map_to_token_files(mapfile, keyfile, verifier, outfile):
+def convert_map_to_token_files(mapfile, keyfile, verifier, outfile, add_p_header):
     token_map = read_token_map(mapfile)
 
     if verifier.method == 'sign':
@@ -46,10 +53,10 @@
             signing_key = fh.read()
 
     with open(outfile, 'wb') as wfh:
-        convert_map_to_token(token_map, signing_key, verifier, wfh)
+        convert_map_to_token(token_map, signing_key, verifier, wfh, add_p_header)
 
 
-def convert_map_to_token(token_map, signing_key, verifier, wfh):
+def convert_map_to_token(token_map, signing_key, verifier, wfh, add_p_header):
     wrapping_tag = verifier.get_wrapping_tag()
     if wrapping_tag is not None:
         token = cbor2.dumps(CBORTag(wrapping_tag, token_map))
@@ -59,9 +66,9 @@
     if verifier.method == AttestationTokenVerifier.SIGN_METHOD_RAW:
         signed_token = token
     elif verifier.method == AttestationTokenVerifier.SIGN_METHOD_SIGN1:
-        signed_token = sign_eat(token, verifier, signing_key)
+        signed_token = sign_eat(token, verifier, add_p_header=add_p_header, key=signing_key)
     elif verifier.method == AttestationTokenVerifier.SIGN_METHOD_MAC0:
-        signed_token = hmac_eat(token, verifier, signing_key)
+        signed_token = hmac_eat(token, verifier, add_p_header=add_p_header, key=signing_key)
     else:
         err_msg = 'Unexpected method "{}"; must be one of: raw, sign, mac'
         raise ValueError(err_msg.format(method))
@@ -70,7 +77,7 @@
 
 
 def convert_token_to_map(raw_data, verifier):
-    payload = get_cose_payload(raw_data, verifier)
+    payload = get_cose_payload(raw_data, verifier, key=None, check_p_header=False)
     token_map = cbor2.loads(payload)
     return _relabel_keys(token_map)
 
@@ -85,29 +92,38 @@
     return _parse_raw_token(raw)
 
 
-def extract_iat_from_cose(keyfile, tokenfile, verifier):
+def extract_iat_from_cose(keyfile, tokenfile, verifier, check_p_header=False):
     key = read_keyfile(keyfile, verifier.method)

     try:
         with open(tokenfile, 'rb') as wfh:
-            return get_cose_payload(wfh.read(), verifier, key)
+            return get_cose_payload(wfh.read(), verifier, check_p_header=check_p_header, key=key)
     except Exception as e:
         msg = 'Bad COSE file "{}": {}'
         raise ValueError(msg.format(tokenfile, e))
 
 
-def get_cose_payload(cose, verifier, key=None):
+def get_cose_payload(cose, verifier, key=None, *, check_p_header=False):
     if verifier.method == AttestationTokenVerifier.SIGN_METHOD_SIGN1:
-        return get_cose_sign1_payload(cose, verifier, key)
+        return get_cose_sign1_payload(cose, verifier, check_p_header=check_p_header, key=key)
     if verifier.method == AttestationTokenVerifier.SIGN_METHOD_MAC0:
-        return get_cose_mac0_pyload(cose, verifier, key)
+        return get_cose_mac0_payload(cose, verifier, check_p_header=check_p_header, key=key)
     err_msg = 'Unexpected method "{}"; must be one of: sign, mac'
-    raise ValueError(err_msg.format(method))
+    raise ValueError(err_msg.format(verifier.method))
 
+def parse_protected_header(msg, alg):
+    try:
+        msg_alg = msg.protected_header['alg']
+    except KeyError:
+        raise ValueError('Missing alg from protected header (expected {})'.format(alg)) from None
+    if alg != msg_alg:
+        raise ValueError('Unexpected alg in protected header (expected {} instead of {})'.format(alg, msg_alg))
 
-def get_cose_sign1_payload(cose, verifier, key=None):
+def get_cose_sign1_payload(cose, verifier, *, check_p_header, key=None):
     msg = Sign1Message.decode(cose)
     if key:
+        if check_p_header:
+            parse_protected_header(msg, verifier.cose_alg)
         msg.key = key
         msg.signature = msg.signers
         try:
@@ -117,9 +133,11 @@
     return msg.payload
 
 
-def get_cose_mac0_pyload(cose, verifier, key=None):
+def get_cose_mac0_payload(cose, verifier, *, check_p_header, key=None):
     msg = Mac0Message.decode(cose)
     if key:
+        if check_p_header:
+            parse_protected_header(msg, verifier.cose_alg)
         msg.key = key
         try:
             msg.verify_auth_tag(alg=verifier.cose_alg)