Mate Toth-Pal | bb187d0 | 2022-04-26 16:01:51 +0200 | [diff] [blame] | 1 | # ----------------------------------------------------------------------------- |
| 2 | # Copyright (c) 2019-2022, Arm Limited. All rights reserved. |
| 3 | # |
| 4 | # SPDX-License-Identifier: BSD-3-Clause |
| 5 | # |
| 6 | # ----------------------------------------------------------------------------- |
| 7 | |
| 8 | import logging |
Mate Toth-Pal | d10a914 | 2022-04-28 15:34:13 +0200 | [diff] [blame^] | 9 | from abc import ABC, abstractmethod |
Mate Toth-Pal | bb187d0 | 2022-04-26 16:01:51 +0200 | [diff] [blame] | 10 | |
| 11 | import cbor2 |
| 12 | |
# Module-wide logger; AttestationTokenVerifier.error()/warning() write to it.
logger = logging.getLogger(name='iat-verifiers')
| 14 | |
class AttestationClaim(ABC):
    """
    This class represents a claim.

    This class is abstract. A concrete claim has to be derived from this class,
    and it has to implement all the abstract methods.

    This class also contains methods that are not abstract. They provide a
    default behavior that a derived class might either keep or override.

    A token is built up as a hierarchy of claim classes. It is important that
    claim objects don't have a 'value' field: the actual parsed token is stored
    in a map structure. Operations (for example verifying a token) are executed
    on that token map, and the operations are defined by claim classes/objects.
    """

    # Necessity levels accepted by the 'necessity' constructor argument.
    MANDATORY = 0
    RECOMMENDED = 1
    OPTIONAL = 2

    def __init__(self, verifier, *, necessity=MANDATORY):
        """Initialise a claim.

        :param verifier: the verifier object this claim reports to. Its
            'config' is cached on the claim and its 'error' method is used to
            report problems.
        :param necessity: one of MANDATORY, RECOMMENDED or OPTIONAL.
        """
        self.config = verifier.config
        self.verifier = verifier
        self.necessity = necessity
        # Incremented by verify() implementations; read by claim_found().
        self.verify_count = 0

    # Abstract methods

    @abstractmethod
    def verify(self, value):
        """Verify this claim

        Report an error (through 'self.verifier.error') if the claim is not
        valid. Implementations should increment 'self.verify_count' so that
        claim_found() reports the claim as present."""
        raise NotImplementedError

    @abstractmethod
    def get_claim_key(self=None):
        """Get the key of this claim

        Returns the key of this claim. The implementation has to support
        being called both with and without an instance."""
        raise NotImplementedError

    @abstractmethod
    def get_claim_name(self=None):
        """Get the name of this claim

        Returns the name of this claim. The implementation has to support
        being called both with and without an instance."""
        raise NotImplementedError

    # Default methods that a derived class might override

    def get_contained_claim_key_list(self):
        """Return a dictionary of the claims that can be present in this claim

        Return a dictionary where keys are the claim keys (the same that is
        returned by get_claim_key), and the values are the claim classes for
        that key. The default implementation returns an empty dictionary.
        """
        return {}

    def decode(self, value):
        """Decode the value of the claim if the value is a UTF-8 string.

        If is_utf_8() is False, 'value' is returned unchanged. If the bytes
        are not valid UTF-8, an error is reported and the text of the bytes
        literal (without the surrounding b'') is returned instead.
        """
        if not self.is_utf_8():
            # Not a UTF-8 value, i.e. a bytestring: keep it as-is.
            return value
        try:
            return value.decode()
        except UnicodeDecodeError as e:
            msg = 'Error decodeing value for "{}": {}'
            self.verifier.error(msg.format(self.get_claim_name(), e))
            return str(value)[2:-1]

    def add_value_to_dict(self, token, value):
        """Add 'value' to the dict 'token', keyed by this claim's name.

        Bytes values are passed through decode() first.
        """
        entry_name = self.get_claim_name()
        if isinstance(value, bytes):
            value = self.decode(value)
        token[entry_name] = value

    def claim_found(self):
        """Return True if verify was called on this claim instance"""
        return self.verify_count > 0

    def _check_type(self, name, value, expected_type):
        """Check that a value's type is as expected.

        Reports an error and returns False on mismatch, otherwise returns True.
        """
        if not isinstance(value, expected_type):
            msg = 'Invalid {}: must be a(n) {}: found {}'
            self.verifier.error(msg.format(name, expected_type, type(value)))
            return False
        return True

    def _validate_bytestring_length_equals(self, value, name, expected_len):
        """Check that a bytestring's length is exactly 'expected_len'.

        Problems are reported through 'self.verifier.error'.
        """
        if not self._check_type(name, value, bytes):
            # The type error has already been reported. When the verifier
            # keeps going, measuring a non-bytes value could raise (e.g. for
            # an int) or produce a second, meaningless error.
            return

        value_len = len(value)
        if value_len != expected_len:
            msg = 'Invalid {} length: must be exactly {} bytes, found {} bytes'
            self.verifier.error(msg.format(name, expected_len, value_len))

    def _validate_bytestring_length_is_at_least(self, value, name, minimal_length):
        """Check that a bytestring is at least 'minimal_length' bytes long.

        Problems are reported through 'self.verifier.error'.
        """
        if not self._check_type(name, value, bytes):
            # See _validate_bytestring_length_equals: skip the length check
            # once the type is known to be wrong.
            return

        value_len = len(value)
        if value_len < minimal_length:
            msg = 'Invalid {} length: must be at least {} bytes, found {} bytes'
            self.verifier.error(msg.format(name, minimal_length, value_len))

    @staticmethod
    def parse_raw(raw_value):
        """Parse a raw value

        As it appears in a yaml file. The default returns it unchanged.
        """
        return raw_value

    @staticmethod
    def get_formatted_value(value):
        """Format the value according to this claim (default: unchanged)"""
        return value

    def is_utf_8(self):
        """Returns whether the value of this claim should be UTF-8"""
        return False

    def check_cross_claim_requirements(self):
        """Check whether the claims inside this claim satisfy requirements.

        The default implementation does nothing."""
Mate Toth-Pal | bb187d0 | 2022-04-26 16:01:51 +0200 | [diff] [blame] | 148 | |
| 149 | |
class NonVerifiedClaim(AttestationClaim):
    """A claim whose value is accepted without any checks.

    verify() only records that the claim was present, so that claim_found()
    reports it. Subclasses still provide the key and the name.
    """

    def verify(self, value):
        # Nothing to validate: just count the occurrence.
        self.verify_count = self.verify_count + 1
| 153 | |
class CompositeAttestClaim(AttestationClaim):
    """
    This class represents a composite claim.

    This class is still abstract, but can contain other claims. This means that
    a value representing this claim is a dictionary (or a list of dictionaries
    when 'is_list' is True). This claim contains further claims which represent
    the possible key-value pairs in the value for this claim.
    """

    def __init__(self, verifier, *, claims, is_list, necessity=AttestationClaim.MANDATORY):
        """ Initialise a composite claim.

        'claims' is a list of (claim_class, kwargs) pairs describing the
        contained claims.

        In case 'is_list' is False, the expected type of value is a dictionary,
        containing the necessary claims determined by the 'claims' list.
        In case 'is_list' is True, the expected type of value is a list,
        containing a number of dictionaries, each one containing the necessary
        claims determined by the 'claims' list.
        """
        super().__init__(verifier, necessity=necessity)
        self.is_list = is_list
        self.claims = claims

    def _get_contained_claims(self):
        """Create fresh instances of all contained claims.

        New objects are built on every call, so their verify_count starts at 0.
        """
        return [claim(self.verifier, **args) for claim, args in self.claims]

    def get_contained_claim_key_list(self):
        """Map each contained claim's key to its claim class."""
        return {claim.get_claim_key(): claim.__class__
                for claim in self._get_contained_claims()}

    def _verify_dict(self, entry_number, value):
        """Verify a single dictionary against the contained claims.

        :param entry_number: index of the dictionary within the list when
            'is_list' is True, otherwise None. Only used in messages.
        :param value: the dictionary to verify.
        """
        if not self._check_type(self.get_claim_name(), value, dict):
            return

        claims = {v.get_claim_key(): v for v in self._get_contained_claims()}
        for k, v in value.items():
            if k not in claims:
                if self.config.strict:
                    msg = 'Unexpected {} claim: {}'
                    self.verifier.error(msg.format(self.get_claim_name(), k))
                # There is no claim object that could verify an unknown key.
                continue
            try:
                claims[k].verify(v)
            except Exception:
                if not self.config.keep_going:
                    raise

        # Check claims' necessity
        for claim in claims.values():
            if not claim.claim_found():
                if claim.necessity==AttestationClaim.MANDATORY:
                    msg = ('Invalid IAT: missing MANDATORY claim "{}" '
                           'from {}').format(claim.get_claim_name(),
                                             self.get_claim_name())
                    if entry_number is not None:
                        msg += ' at index {}'.format(entry_number)
                    self.verifier.error(msg)
                elif claim.necessity==AttestationClaim.RECOMMENDED:
                    msg = ('Missing RECOMMENDED claim "{}" '
                           'from {}').format(claim.get_claim_name(),
                                             self.get_claim_name())
                    if entry_number is not None:
                        msg += ' at index {}'.format(entry_number)
                    self.verifier.warning(msg)

    def verify(self, value):
        """
        Verify a composite claim.

        Depending on 'is_list', 'value' is a list of dictionaries or a single
        dictionary.
        """
        if self.is_list:
            if not self._check_type(self.get_claim_name(), value, list):
                return

            for entry_number, entry in enumerate(value):
                self._verify_dict(entry_number, entry)
        else:
            self._verify_dict(None, value)

        self.verify_count += 1

    def _decode_dict(self, raw_dict):
        """Return a copy of 'raw_dict' keyed by contained claim names.

        Bytes values are passed through decode(). Unknown keys are kept as-is
        in non-strict mode. In strict mode they are dropped, and KeyError is
        re-raised unless keep_going is set.
        """
        decoded_dict = {}
        names = {claim.get_claim_key(): claim.get_claim_name() for claim in self._get_contained_claims()}
        for k, v in raw_dict.items():
            if isinstance(v, bytes):
                v = self.decode(v)
            try:
                decoded_dict[names[k]] = v
            except KeyError:
                if self.config.strict:
                    if not self.config.keep_going:
                        raise
                else:
                    decoded_dict[k] = v
        return decoded_dict

    def add_value_to_dict(self, token, value):
        """Add the decoded 'value' of this claim to the dict 'token'.

        For a list claim the entry is a list of decoded dictionaries.
        Otherwise it is a single decoded dictionary.
        """
        entry_name = self.get_claim_name()
        try:
            if self.is_list:
                token[entry_name] = []
                for raw_dict in value:
                    decoded_dict = self._decode_dict(raw_dict)
                    token[entry_name].append(decoded_dict)
            else:
                # A non-list composite value is a single dictionary. Iterating
                # it would yield its keys, not dictionaries.
                token[entry_name] = self._decode_dict(value)
        except (TypeError, AttributeError):
            # 'value' (or one of its elements) does not have the expected
            # shape, e.g. it is not iterable or is not a dict.
            self.verifier.error('Invalid {} value: {}'.format(self.get_claim_name(), value))
| 262 | |
Mate Toth-Pal | bb187d0 | 2022-04-26 16:01:51 +0200 | [diff] [blame] | 263 | |
| 264 | |
class VerifierConfiguration:
    """Options controlling how a token is verified.

    keep_going -- when True, errors are logged instead of raised, so
                  verification continues past the first problem.
    strict     -- when True, claims unknown to the verifier are errors.
    """

    def __init__(self, keep_going=False, strict=False):
        self.strict = strict
        self.keep_going = keep_going
| 269 | |
class AttestationTokenVerifier:
    """Verify attestation tokens against a set of registered claims.

    Claims are registered with add_claims(). decode_and_validate_iat() then
    decodes a CBOR token and verifies the claims found in it. It also checks
    that MANDATORY and RECOMMENDED claims are present, and returns the token
    as a dict.
    """

    # Maps every claim key registered through add_claims() to its claim class.
    # This is a class attribute, so the mapping is shared by all instances.
    all_known_claims = {}

    SIGN_METHOD_SIGN1 = "sign"
    SIGN_METHOD_MAC0 = "mac"
    SIGN_METHOD_RAW = "raw"

    COSE_ALG_ES256="ES256"
    COSE_ALG_ES384="ES384"
    COSE_ALG_ES512="ES512"
    COSE_ALG_HS256_64="HS256/64"
    COSE_ALG_HS256="HS256"
    COSE_ALG_HS384="HS384"
    COSE_ALG_HS512="HS512"

    def __init__(self, method, cose_alg, configuration=None):
        """Create a verifier.

        :param method: presumably one of the SIGN_METHOD_* constants.
        :param cose_alg: presumably one of the COSE_ALG_* constants.
        :param configuration: a VerifierConfiguration; a default one is
            created when None.
        """
        self.method = method
        self.cose_alg = cose_alg
        self.config = configuration if configuration is not None else VerifierConfiguration()
        self.claims = []

        # Set by error(); lets keep_going callers detect reported errors.
        self.seen_errors = False

    def add_claims(self, claims):
        """Register the claim objects in 'claims' with this verifier.

        Their keys, and the keys of any claims they contain, are also recorded
        in the class-wide all_known_claims registry.
        """
        # 'claims' is iterated twice below; materialise it so a one-shot
        # iterable (e.g. a generator) is not exhausted by the first pass.
        claims = list(claims)
        for claim in claims:
            key = claim.get_claim_key()
            if key not in AttestationTokenVerifier.all_known_claims:
                AttestationTokenVerifier.all_known_claims[key] = claim.__class__

            AttestationTokenVerifier.all_known_claims.update(claim.get_contained_claim_key_list())
        self.claims.extend(claims)

    def check_cross_claim_requirements(self):
        """Hook for token-level cross-claim checks; the default does nothing."""
        pass

    def decode_and_validate_iat(self, encoded_iat):
        """Decode a CBOR token, verify its claims and return it as a dict.

        :param encoded_iat: the CBOR encoded token.
        :returns: a dict mapping claim names to decoded values. Entries whose
            key is not a registered claim are kept under their raw key.
        :raises ValueError: if the CBOR cannot be decoded. Also raised when a
            check fails, unless the configuration has keep_going set.
        """
        try:
            raw_token = cbor2.loads(encoded_iat)
        except Exception as e:
            msg = 'Invalid CBOR: {}'
            raise ValueError(msg.format(e)) from e

        # claim_found() relies on each claim's verify_count. Reset it so that
        # a verifier reused for several tokens checks each token on its own.
        for claim in self.claims:
            claim.verify_count = 0

        claims = {v.get_claim_key(): v for v in self.claims}

        token = {}
        # The token map may be wrapped in CBOR tags; unwrap until a map is
        # reached. Anything else (int, list, string, ...) is not a valid token.
        while not hasattr(raw_token, 'items'):
            if not hasattr(raw_token, 'value'):
                self.error('Invalid IAT: token is not a map: {}'.format(type(raw_token)))
                return token
            raw_token = raw_token.value

        for entry, value in raw_token.items():
            if entry not in claims:
                if self.config.strict:
                    self.error('Invalid IAT claim: {}'.format(entry))
                token[entry] = value
                continue

            claim = claims[entry]
            claim.verify(value)
            claim.add_value_to_dict(token, value)

        # Check claims' necessity
        for claim in claims.values():
            if not claim.claim_found():
                if claim.necessity==AttestationClaim.MANDATORY:
                    msg = 'Invalid IAT: missing MANDATORY claim "{}"'
                    self.error(msg.format(claim.get_claim_name()))
                elif claim.necessity==AttestationClaim.RECOMMENDED:
                    msg = 'Missing RECOMMENDED claim "{}"'
                    self.warning(msg.format(claim.get_claim_name()))

            claim.check_cross_claim_requirements()

        self.check_cross_claim_requirements()

        return token


    def get_wrapping_tag(self=None):
        """The value of the tag that the token is wrapped in.

        The function should return None if the token is not wrapped.
        """
        return None

    def error(self, message):
        """Report an error: log it when keep_going is set, otherwise raise.

        :raises ValueError: with 'message' when keep_going is not set.
        """
        self.seen_errors = True
        if self.config.keep_going:
            logger.error(message)
        else:
            raise ValueError(message)

    def warning(self, message):
        """Log a warning; warnings never stop verification."""
        logger.warning(message)