blob: 2ee2bbac9b31ec1a4f66f1ad8c5ae4f28ecbc818 [file] [log] [blame]
Mate Toth-Palbb187d02022-04-26 16:01:51 +02001# -----------------------------------------------------------------------------
2# Copyright (c) 2019-2022, Arm Limited. All rights reserved.
3#
4# SPDX-License-Identifier: BSD-3-Clause
5#
6# -----------------------------------------------------------------------------
7
8import logging
Mate Toth-Pald10a9142022-04-28 15:34:13 +02009from abc import ABC, abstractmethod
Mate Toth-Palbb187d02022-04-26 16:01:51 +020010
11import cbor2
12
# Shared logger for every verifier in this module; AttestationTokenVerifier
# routes its error() and warning() reports through it.
logger = logging.getLogger(name='iat-verifiers')
14
class AttestationClaim(ABC):
    """
    This class represents a claim.

    This class is abstract. A concrete claim has to be derived from this
    class, and it has to implement all the abstract methods.

    This class contains methods that are not abstract. These provide a
    default behaviour that a derived class may either keep or override.

    A token is built up as a hierarchy of claim classes. Note that claim
    objects do not hold a 'value' field: the actual parsed token is stored in
    a map structure. Operations, such as verifying a token, are executed on a
    token map, and those operations are defined by the claim classes/objects.
    """

    MANDATORY = 0
    RECOMMENDED = 1
    OPTIONAL = 2

    def __init__(self, verifier, *, necessity=MANDATORY):
        """Initialise a claim.

        verifier -- the verifier this claim belongs to. Its 'config' attribute
            is cached as self.config, and its error() method is used to report
            problems found in claim values.
        necessity -- one of MANDATORY, RECOMMENDED or OPTIONAL.
        """
        self.config = verifier.config
        self.verifier = verifier
        self.necessity = necessity
        # Incremented by verify() implementations; read by claim_found().
        self.verify_count = 0

    # Abstract methods

    @abstractmethod
    def verify(self, value):
        """Verify this claim.

        Throw an exception if the claim is not valid."""
        raise NotImplementedError

    @abstractmethod
    def get_claim_key(self=None):
        """Get the key of this claim.

        Returns the key of this claim. The implementation has to support
        calling this method with or without an instance."""
        raise NotImplementedError

    @abstractmethod
    def get_claim_name(self=None):
        """Get the name of this claim.

        Returns the name of this claim. The implementation has to support
        calling this method with or without an instance."""
        raise NotImplementedError

    # Default methods that a derived class might override

    def get_contained_claim_key_list(self):
        """Return a dictionary of the claims that can be present in this claim.

        Return a dictionary where keys are the claim keys (the same that is
        returned by get_claim_key), and the values are the claim classes for
        that key. A plain claim contains no other claims.
        """
        return {}

    def decode(self, value):
        """Decode the value of the claim if the value is a UTF-8 string.

        If is_utf_8() is False the value is returned unchanged. If decoding
        fails, the error is reported through the verifier; when that does not
        raise, the repr of the bytes with the b'...' wrapper stripped is
        returned instead.
        """
        if self.is_utf_8():
            try:
                return value.decode()
            except UnicodeDecodeError as e:
                msg = 'Error decodeing value for "{}": {}'
                self.verifier.error(msg.format(self.get_claim_name(), e))
                return str(value)[2:-1]
        else:  # not a UTF-8 value, i.e. a bytestring
            return value

    def add_value_to_dict(self, token, value):
        """Add 'value' to the dict 'token' under this claim's name.

        bytes values are passed through decode() first."""
        entry_name = self.get_claim_name()
        if isinstance(value, bytes):
            value = self.decode(value)
        token[entry_name] = value

    def claim_found(self):
        """Return True if verify was called on this claim instance."""
        return self.verify_count > 0

    def _check_type(self, name, value, expected_type):
        """Check that a value's type is as expected.

        Reports an error through the verifier and returns False on mismatch,
        returns True otherwise."""
        if not isinstance(value, expected_type):
            msg = 'Invalid {}: must be a(n) {}: found {}'
            self.verifier.error(msg.format(name, expected_type, type(value)))
            return False
        return True

    def _validate_bytestring_length_equals(self, value, name, expected_len):
        """Check that a bytestring's length is exactly expected_len.

        If value is not bytes, only the type error is reported: calling len()
        on it could raise (e.g. for an int) or yield a meaningless length."""
        if not self._check_type(name, value, bytes):
            return

        value_len = len(value)
        if value_len != expected_len:
            msg = 'Invalid {} length: must be exactly {} bytes, found {} bytes'
            self.verifier.error(msg.format(name, expected_len, value_len))

    def _validate_bytestring_length_is_at_least(self, value, name, minimal_length):
        """Check that a bytestring is at least minimal_length bytes long.

        If value is not bytes, only the type error is reported: calling len()
        on it could raise (e.g. for an int) or yield a meaningless length."""
        if not self._check_type(name, value, bytes):
            return

        value_len = len(value)
        if value_len < minimal_length:
            msg = 'Invalid {} length: must be at least {} bytes, found {} bytes'
            self.verifier.error(msg.format(name, minimal_length, value_len))

    @staticmethod
    def parse_raw(raw_value):
        """Parse a raw value, as it appears in a yaml file.

        The default implementation returns the value unchanged.
        """
        return raw_value

    @staticmethod
    def get_formatted_value(value):
        """Format the value according to this claim (unchanged by default)."""
        return value

    def is_utf_8(self):
        """Return whether the value of this claim should be UTF-8."""
        return False

    def check_cross_claim_requirements(self):
        """Check whether the claims inside this claim satisfy requirements."""
Mate Toth-Palbb187d02022-04-26 16:01:51 +0200148
149
class NonVerifiedClaim(AttestationClaim):
    """A claim whose value is accepted without any checking.

    verify() only records that the claim was encountered, so that
    claim_found() reports it as present. Subclasses still have to provide
    get_claim_key() and get_claim_name(), which remain abstract here.
    """

    def verify(self, value):
        # The value is deliberately ignored; only its presence is recorded.
        self.verify_count = self.verify_count + 1
153
class CompositeAttestClaim(AttestationClaim):
    """
    This class represents a composite claim.

    This class is still abstract, but can contain other claims. This means
    that a value representing this claim is a dictionary (or, when 'is_list'
    is set, a list of dictionaries). This claim contains further claims which
    represent the possible key-value pairs in the value for this claim.
    """

    def __init__(self, verifier, *, claims, is_list, necessity=AttestationClaim.MANDATORY):
        """Initialise a composite claim.

        claims -- a sequence of (claim_class, kwargs) pairs. A fresh instance
            of every contained claim is created from them each time one is
            needed (see _get_contained_claims).
        is_list -- if False, the expected type of value is a dictionary
            containing the necessary claims determined by the 'claims' list.
            If True, the expected type of value is a list of such
            dictionaries.
        necessity -- one of the AttestationClaim necessity constants.
        """
        super().__init__(verifier, necessity=necessity)
        self.is_list = is_list
        self.claims = claims

    def _get_contained_claims(self):
        """Instantiate a new object for every contained claim description."""
        return [claim(self.verifier, **args) for claim, args in self.claims]

    def get_contained_claim_key_list(self):
        """Map each contained claim's key to its claim class."""
        return {claim.get_claim_key(): claim.__class__
                for claim in self._get_contained_claims()}

    def _verify_dict(self, entry_number, value):
        """Verify one dictionary against the contained claims.

        entry_number -- index of the dictionary inside the list value, or
            None when this claim is not a list. Only used in messages.
        """
        if not self._check_type(self.get_claim_name(), value, dict):
            return

        # Fresh claim objects per dictionary, so that claim_found() reflects
        # only the entries of this dictionary.
        claims = {v.get_claim_key(): v for v in self._get_contained_claims()}
        for k, v in value.items():
            if k not in claims:
                if self.config.strict:
                    msg = 'Unexpected {} claim: {}'
                    self.verifier.error(msg.format(self.get_claim_name(), k))
                # There is no claim object that could verify this entry.
                continue
            try:
                claims[k].verify(v)
            except Exception:
                if not self.config.keep_going:
                    raise

        # Check claims' necessity
        location = ''
        if entry_number is not None:
            location = ' at index {}'.format(entry_number)
        for claim in claims.values():
            if claim.claim_found():
                continue
            if claim.necessity == AttestationClaim.MANDATORY:
                msg = ('Invalid IAT: missing MANDATORY claim "{}" '
                       'from {}').format(claim.get_claim_name(),
                                         self.get_claim_name())
                self.verifier.error(msg + location)
            elif claim.necessity == AttestationClaim.RECOMMENDED:
                msg = ('Missing RECOMMENDED claim "{}" '
                       'from {}').format(claim.get_claim_name(),
                                         self.get_claim_name())
                self.verifier.warning(msg + location)

    def verify(self, value):
        """
        Verify a composite claim.

        Type errors are reported through the verifier. In that case the claim
        is not counted as found.
        """
        if self.is_list:
            if not self._check_type(self.get_claim_name(), value, list):
                return

            for entry_number, entry in enumerate(value):
                self._verify_dict(entry_number, entry)
        else:
            self._verify_dict(None, value)

        self.verify_count += 1

    def _decode_dict(self, raw_dict):
        """Return a copy of raw_dict keyed by claim names instead of keys.

        Keys with no matching contained claim are kept as-is when not in
        strict mode. In strict mode they are dropped, or the KeyError is
        re-raised when keep_going is not set.
        """
        decoded_dict = {}
        names = {claim.get_claim_key(): claim.get_claim_name()
                 for claim in self._get_contained_claims()}
        for k, v in raw_dict.items():
            if isinstance(v, bytes):
                v = self.decode(v)
            try:
                decoded_dict[names[k]] = v
            except KeyError:
                if self.config.strict:
                    if not self.config.keep_going:
                        raise
                else:
                    decoded_dict[k] = v
        return decoded_dict

    def add_value_to_dict(self, token, value):
        """Add the decoded value of this claim to the dict 'token'.

        For a list claim the entry is a list of decoded dictionaries. For a
        non-list claim it is a single decoded dictionary. A value of the wrong
        shape is reported through the verifier.
        """
        entry_name = self.get_claim_name()
        try:
            if self.is_list:
                token[entry_name] = []
                for raw_dict in value:
                    token[entry_name].append(self._decode_dict(raw_dict))
            else:
                token[entry_name] = self._decode_dict(value)
        except (TypeError, AttributeError):
            # TypeError: value is not iterable.
            # AttributeError: an entry (or the value itself) is not a dict.
            self.verifier.error('Invalid {} value: {}'.format(self.get_claim_name(), value))
262
Mate Toth-Palbb187d02022-04-26 16:01:51 +0200263
264
class VerifierConfiguration:
    """Switches controlling how verification problems are handled.

    keep_going -- when true, the verifier's error() logs instead of raising,
        and claim verification continues after a failure.
    strict -- when true, claims that no registered claim class handles are
        reported as errors.
    """

    def __init__(self, keep_going=False, strict=False):
        self.strict = strict
        self.keep_going = keep_going
269
class AttestationTokenVerifier:
    """Decode a CBOR attestation token and verify it against registered claims.

    Claims are registered with add_claims(). decode_and_validate_iat() then
    checks an encoded token against them and returns a dict view of it.
    """

    # Registry mapping every claim key registered through add_claims() (by
    # any instance: this is a class attribute) to the claim class handling it.
    all_known_claims = {}

    SIGN_METHOD_SIGN1 = "sign"
    SIGN_METHOD_MAC0 = "mac"
    SIGN_METHOD_RAW = "raw"

    COSE_ALG_ES256 = "ES256"
    COSE_ALG_ES384 = "ES384"
    COSE_ALG_ES512 = "ES512"
    COSE_ALG_HS256_64 = "HS256/64"
    COSE_ALG_HS256 = "HS256"
    COSE_ALG_HS384 = "HS384"
    COSE_ALG_HS512 = "HS512"

    def __init__(self, method, cose_alg, configuration=None):
        """Create a verifier.

        method -- stored as-is; presumably one of the SIGN_METHOD_* values.
        cose_alg -- stored as-is; presumably one of the COSE_ALG_* values.
        configuration -- a VerifierConfiguration (anything with 'keep_going'
            and 'strict' attributes); a default one is created for None.
        """
        self.method = method
        self.cose_alg = cose_alg
        self.config = configuration if configuration is not None else VerifierConfiguration()
        self.claims = []

        self.seen_errors = False

    def add_claims(self, claims):
        """Register claim objects with this verifier.

        Each claim's key (and the keys of any claims it contains) is also
        recorded in the class-wide all_known_claims registry. An already
        registered key keeps its original class, while contained claims
        overwrite existing entries.
        """
        # Materialise first: a one-shot iterable would otherwise be exhausted
        # by the loop below and nothing would be added to self.claims.
        claims = list(claims)
        for claim in claims:
            key = claim.get_claim_key()
            if key not in AttestationTokenVerifier.all_known_claims:
                AttestationTokenVerifier.all_known_claims[key] = claim.__class__

            AttestationTokenVerifier.all_known_claims.update(claim.get_contained_claim_key_list())
        self.claims.extend(claims)

    def check_cross_claim_requirements(self):
        """Hook for token-level cross-claim checks; does nothing by default."""

    def decode_and_validate_iat(self, encoded_iat):
        """Decode a CBOR-encoded token and verify it against the registered claims.

        Returns a dict mapping claim names to decoded values. Entries for
        which no registered claim exists are stored under their raw key.

        Raises ValueError if encoded_iat is not valid CBOR or does not contain
        a map. Verification failures are reported through error(), which
        raises ValueError unless config.keep_going is set.
        """
        try:
            raw_token = cbor2.loads(encoded_iat)
        except Exception as e:
            msg = 'Invalid CBOR: {}'
            raise ValueError(msg.format(e)) from e

        # The same claim objects are reused for every token this verifier
        # checks. Forget what a previous token contained; otherwise a claim
        # missing from this token would still count as found.
        for claim in self.claims:
            claim.verify_count = 0

        claims = {v.get_claim_key(): v for v in self.claims}

        while not hasattr(raw_token, 'items'):
            # TODO: token map is not a map. We are assuming that it is a tag
            # whose 'value' holds the (possibly further wrapped) map.
            if not hasattr(raw_token, 'value'):
                raise ValueError('Invalid IAT: token is not a map: {}'.format(
                    type(raw_token).__name__))
            raw_token = raw_token.value

        token = {}
        for entry, value in raw_token.items():
            try:
                claim = claims[entry]
            except KeyError:
                if self.config.strict:
                    self.error('Invalid IAT claim: {}'.format(entry))
                token[entry] = value
                continue

            claim.verify(value)
            claim.add_value_to_dict(token, value)

        # Check claims' necessity
        for claim in claims.values():
            if not claim.claim_found():
                if claim.necessity == AttestationClaim.MANDATORY:
                    msg = 'Invalid IAT: missing MANDATORY claim "{}"'
                    self.error(msg.format(claim.get_claim_name()))
                elif claim.necessity == AttestationClaim.RECOMMENDED:
                    msg = 'Missing RECOMMENDED claim "{}"'
                    self.warning(msg.format(claim.get_claim_name()))

            claim.check_cross_claim_requirements()

        self.check_cross_claim_requirements()

        return token

    def get_wrapping_tag(self=None):
        """The value of the tag that the token is wrapped in.

        The function should return None if the token is not wrapped.
        """
        return None

    def error(self, message):
        """Record that an error was seen, then log it when config.keep_going
        is set, or raise ValueError(message) otherwise."""
        self.seen_errors = True
        if self.config.keep_going:
            logger.error(message)
        else:
            raise ValueError(message)

    def warning(self, message):
        """Log a non-fatal verification problem."""
        logger.warning(message)