blob: e89dc6b19d52793b36b0ac8a885e143369654ffc [file] [log] [blame]
Mate Toth-Palbb187d02022-04-26 16:01:51 +02001# -----------------------------------------------------------------------------
2# Copyright (c) 2019-2022, Arm Limited. All rights reserved.
3#
4# SPDX-License-Identifier: BSD-3-Clause
5#
6# -----------------------------------------------------------------------------
7
8import logging
9
10import cbor2
11
# Shared logger for all IAT verifier classes in this module; used for
# non-fatal errors (keep-going mode) and warnings.
logger = logging.getLogger(name='iat-verifiers')
13
class AttestationClaim:
    """Base class for a single claim that may appear in an attestation token.

    Subclasses implement ``verify``, ``get_claim_key`` and ``get_claim_name``.
    Problems found while checking a value are reported through the owning
    verifier's ``error`` method, which either raises or logs depending on the
    verifier configuration.
    """

    # Necessity levels: how a missing claim is treated by the verifier.
    MANDATORY = 0
    RECOMMENDED = 1
    OPTIONAL = 2

    def __init__(self, verifier, necessity=MANDATORY):
        """Bind the claim to ``verifier`` with the given necessity level."""
        self.config = verifier.config
        self.verifier = verifier
        self.necessity = necessity
        # Subclass ``verify`` implementations are expected to increment this;
        # ``claim_found`` reports the claim as present once it is non-zero.
        self.verify_count = 0

    def verify(self, value):
        """Check ``value``; must be implemented by subclasses."""
        raise NotImplementedError

    def get_claim_key(self=None):
        """Return the map key of this claim; must be implemented by subclasses."""
        raise NotImplementedError

    def get_claim_name(self=None):
        """Return a human-readable claim name; must be implemented by subclasses."""
        raise NotImplementedError

    def get_contained_claim_key_list(self):
        """Return a mapping of nested claim keys to claim classes (none by default)."""
        return {}

    def decode(self, value):
        """Convert a bytes ``value`` for inclusion in the decoded token.

        UTF-8 claims (``is_utf_8()`` is true) are decoded to ``str``. If the
        bytes are not valid UTF-8, an error is reported and the ``bytes`` repr
        without its ``b'...'`` wrapper is returned instead. Other claims are
        returned unchanged.
        """
        if not self.is_utf_8():
            # Not a UTF-8 value, i.e. a bytestring: keep it as-is.
            return value
        try:
            return value.decode()
        except UnicodeDecodeError as e:
            msg = 'Error decoding value for "{}": {}'
            self.verifier.error(msg.format(self.get_claim_name(), e))
            # Only reached when the verifier does not raise (keep-going mode).
            return str(value)[2:-1]

    def add_tokens_to_dict(self, token, value):
        """Store ``value`` (decoded if it is bytes) in ``token`` under the claim name."""
        entry_name = self.get_claim_name()
        if isinstance(value, bytes):
            value = self.decode(value)
        token[entry_name] = value

    def claim_found(self):
        """Return True if ``verify`` has been recorded at least once."""
        return self.verify_count > 0

    def _check_type(self, name, value, expected_type):
        """Report an error and return False if ``value`` is not ``expected_type``."""
        if not isinstance(value, expected_type):
            msg = 'Invalid {}: must be a(n) {}: found {}'
            self.verifier.error(msg.format(name, expected_type, type(value)))
            return False
        return True

    def _validate_bytestring_length_equals(self, value, name, expected_len):
        """Check that ``value`` is bytes of exactly ``expected_len`` bytes.

        Each problem is reported through the verifier. Returns True if the
        value is valid, False otherwise.
        """
        if not self._check_type(name, value, bytes):
            # The length of a non-bytes value is meaningless (and ``len`` may
            # not even be defined for it), so stop after the type error.
            return False

        value_len = len(value)
        if value_len != expected_len:
            msg = 'Invalid {} length: must be exactly {} bytes, found {} bytes'
            self.verifier.error(msg.format(name, expected_len, value_len))
            return False
        return True

    def _validate_bytestring_length_is_at_least(self, value, name, minimal_length):
        """Check that ``value`` is bytes of at least ``minimal_length`` bytes.

        Each problem is reported through the verifier. Returns True if the
        value is valid, False otherwise.
        """
        if not self._check_type(name, value, bytes):
            # See _validate_bytestring_length_equals: skip the length check.
            return False

        value_len = len(value)
        if value_len < minimal_length:
            msg = 'Invalid {} length: must be at least {} bytes, found {} bytes'
            self.verifier.error(msg.format(name, minimal_length, value_len))
            return False
        return True

    @staticmethod
    def parse_raw(raw_value):
        """Convert a raw input value to the claim's value (identity by default)."""
        return raw_value

    @staticmethod
    def get_formatted_value(value):
        """Return ``value`` formatted for display (identity by default)."""
        return value

    def is_utf_8(self):
        """Return True if the claim's bytes value should be decoded as UTF-8."""
        return False

    def check_cross_claim_requirements(self):
        """Hook for checks spanning several claims; no-op by default."""
        pass
93
94
class NonVerifiedClaim(AttestationClaim):
    """A claim whose value is accepted without any content checks.

    Verification only records that the claim was seen. Concrete subclasses
    must still provide ``get_claim_key`` and ``get_claim_name``.
    """

    def verify(self, value):
        # ``value`` is deliberately not inspected; counting the occurrence is
        # enough for ``claim_found`` to treat the claim as present.
        self.verify_count = self.verify_count + 1

    def get_claim_key(self=None):
        raise NotImplementedError

    def get_claim_name(self=None):
        raise NotImplementedError
104
105
class VerifierConfiguration:
    """Options controlling how an attestation token verifier behaves.

    keep_going: when True, verification errors are logged instead of raised.
    strict: when True, claims the verifier does not know are reported as errors.
    """

    def __init__(self, keep_going=False, strict=False):
        self.keep_going, self.strict = keep_going, strict
110
class AttestationTokenVerifier:
    """Base class for decoding and validating a CBOR attestation token.

    Claims are registered with ``add_claims``. ``decode_and_validate_iat``
    checks each claim present in the token and reports missing MANDATORY and
    RECOMMENDED claims.
    """

    # Registry mapping every claim key registered by any verifier to the
    # handling claim class. It is a class attribute, so it is shared by all
    # instances.
    all_known_claims = {}

    SIGN_METHOD_SIGN1 = "sign"
    SIGN_METHOD_MAC0 = "mac"
    SIGN_METHOD_RAW = "raw"

    COSE_ALG_ES256 = "ES256"
    COSE_ALG_ES384 = "ES384"
    COSE_ALG_ES512 = "ES512"
    COSE_ALG_HS256_64 = "HS256/64"
    COSE_ALG_HS256 = "HS256"
    COSE_ALG_HS384 = "HS384"
    COSE_ALG_HS512 = "HS512"

    def __init__(self, method, cose_alg, configuration=None):
        """Create a verifier.

        method: one of the ``SIGN_METHOD_*`` values.
        cose_alg: one of the ``COSE_ALG_*`` values.
        configuration: a ``VerifierConfiguration``; a default one is used if None.
        """
        self.method = method
        self.cose_alg = cose_alg
        self.config = configuration if configuration is not None else VerifierConfiguration()
        self.claims = []

        # Set by ``error`` so callers can detect errors in keep-going mode.
        self.seen_errors = False

    def add_claims(self, claims):
        """Register ``claims`` with this verifier and in ``all_known_claims``.

        ``claims`` may be any iterable of claim objects, including a one-shot
        generator.
        """
        # Materialise once: the claims are iterated here and then stored, so a
        # generator would otherwise be exhausted before ``extend``.
        claims = list(claims)
        for claim in claims:
            key = claim.get_claim_key()
            if key not in AttestationTokenVerifier.all_known_claims:
                AttestationTokenVerifier.all_known_claims[key] = claim.__class__

            AttestationTokenVerifier.all_known_claims.update(claim.get_contained_claim_key_list())
        self.claims.extend(claims)

    def check_cross_claim_requirements(self):
        """Hook for token-level checks spanning several claims; no-op by default."""
        pass

    def decode_and_validate_iat(self, encoded_iat):
        """Decode the CBOR bytes ``encoded_iat`` and verify their claims.

        Returns a dict mapping claim names (for registered claims) or raw keys
        (for unregistered claims) to their values.

        Raises ValueError if the input is not valid CBOR or does not contain a
        map, and, unless ``keep_going`` is configured, on any verification
        error.
        """
        try:
            raw_token = cbor2.loads(encoded_iat)
        except Exception as e:
            msg = 'Invalid CBOR: {}'
            raise ValueError(msg.format(e)) from e

        claims = {v.get_claim_key(): v for v in self.claims}

        token = {}
        while not hasattr(raw_token, 'items'):
            # TODO: token map is not a map. We are assuming that it is a tag
            try:
                raw_token = raw_token.value
            except AttributeError:
                # Neither a map nor a tag wrapping one; decoding cannot go on,
                # so fail even in keep-going mode.
                msg = 'Invalid IAT: token is not a CBOR map (found {})'
                raise ValueError(msg.format(type(raw_token).__name__)) from None

        for entry, value in raw_token.items():
            try:
                claim = claims[entry]
            except KeyError:
                if self.config.strict:
                    self.error('Invalid IAT claim: {}'.format(entry))
                token[entry] = value
                continue

            claim.verify(value)
            claim.add_tokens_to_dict(token, value)

        # Check claims' necessity.
        # NOTE(review): verify_count is never reset here, so reusing one
        # verifier instance for several tokens may mask missing claims in
        # later tokens; confirm whether verifiers are single-use.
        for claim in claims.values():
            if not claim.claim_found():
                if claim.necessity == AttestationClaim.MANDATORY:
                    msg = 'Invalid IAT: missing MANDATORY claim "{}"'
                    self.error(msg.format(claim.get_claim_name()))
                elif claim.necessity == AttestationClaim.RECOMMENDED:
                    msg = 'Missing RECOMMENDED claim "{}"'
                    self.warning(msg.format(claim.get_claim_name()))

            claim.check_cross_claim_requirements()

        self.check_cross_claim_requirements()

        return token

    def get_wrapping_tag(self=None):
        """The value of the tag that the token is wrapped in.

        The function should return None if the token is not wrapped.
        """
        return None

    def error(self, message):
        """Record an error; raise ValueError unless keep_going is configured."""
        self.seen_errors = True
        if self.config.keep_going:
            logger.error(message)
        else:
            raise ValueError(message)

    def warning(self, message):
        """Log a non-fatal verification warning."""
        logger.warning(message)