Separate general and token-specific code

Change-Id: I36f19eb2583398c2badcf7b078684412ed99922b
Signed-off-by: Mate Toth-Pal <mate.toth-pal@arm.com>
diff --git a/iat-verifier/iatverifier/attest_token_verifier.py b/iat-verifier/iatverifier/attest_token_verifier.py
new file mode 100644
index 0000000..e89dc6b
--- /dev/null
+++ b/iat-verifier/iatverifier/attest_token_verifier.py
@@ -0,0 +1,206 @@
+# -----------------------------------------------------------------------------
+# Copyright (c) 2019-2022, Arm Limited. All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import logging
+
+import cbor2
+
+logger = logging.getLogger('iat-verifiers')  # sink for the verifier's errors/warnings
+
+class AttestationClaim:
+    """Base class for verifying one claim of an attestation token."""
+    MANDATORY, RECOMMENDED, OPTIONAL = 0, 1, 2  # allowed values of `necessity`
+
+    def __init__(self, verifier, necessity=MANDATORY):
+        self.config = verifier.config
+        self.verifier = verifier
+        self.necessity = necessity
+        self.verify_count = 0  # bumped by subclasses' verify(); read by claim_found()
+
+    def verify(self, value):
+        raise NotImplementedError
+
+    def get_claim_key(self=None):
+        raise NotImplementedError
+
+    def get_claim_name(self=None):
+        raise NotImplementedError
+
+    def get_contained_claim_key_list(self):
+        return {}  # presumably {nested claim key: claim class}; merged by add_claims()
+
+    def decode(self, value):
+        """Return value decoded to str for UTF-8 claims, otherwise unchanged."""
+        if not self.is_utf_8():
+            return value
+        try:
+            return value.decode()
+        except UnicodeDecodeError as e:
+            msg = 'Error decoding value for "{}": {}'
+            self.verifier.error(msg.format(self.get_claim_name(), e))
+            return str(value)[2:-1]  # fallback: str(bytes) without the b'' wrapper
+
+    def add_tokens_to_dict(self, token, value):
+        """Store value in token under the claim name, decoding bytes first."""
+        if isinstance(value, bytes):
+            value = self.decode(value)
+        token[self.get_claim_name()] = value
+
+    def claim_found(self):
+        return self.verify_count > 0
+
+    def _check_type(self, name, value, expected_type):
+        """Report an error unless value is an expected_type; return whether it is."""
+        if isinstance(value, expected_type):
+            return True
+        msg = 'Invalid {}: must be a(n) {}: found {}'
+        self.verifier.error(msg.format(name, expected_type, type(value)))
+        return False
+
+    def _validate_bytestring_length_equals(self, value, name, expected_len):
+        """Report an error unless value is a bytestring of exactly expected_len bytes."""
+        if not self._check_type(name, value, bytes):
+            return  # type error already reported; len() may not even apply
+        if len(value) != expected_len:
+            msg = 'Invalid {} length: must be exactly {} bytes, found {} bytes'
+            self.verifier.error(msg.format(name, expected_len, len(value)))
+
+    def _validate_bytestring_length_is_at_least(self, value, name, minimal_length):
+        """Report an error unless value is a bytestring of at least minimal_length bytes."""
+        if not self._check_type(name, value, bytes):
+            return  # type error already reported; len() may not even apply
+        if len(value) < minimal_length:
+            msg = 'Invalid {} length: must be at least {} bytes, found {} bytes'
+            self.verifier.error(msg.format(name, minimal_length, len(value)))
+
+    @staticmethod
+    def parse_raw(raw_value):
+        return raw_value
+
+    @staticmethod
+    def get_formatted_value(value):
+        return value
+
+    def is_utf_8(self):
+        return False
+
+    def check_cross_claim_requirements(self):
+        """Hook for checks involving several claims; does nothing by default."""
+
+
+class NonVerifiedClaim(AttestationClaim):
+    """A claim whose value is accepted without any check.
+
+    verify() only records that the claim was seen, so claim_found() reports it.
+    get_claim_key() and get_claim_name() are inherited and still raise
+    NotImplementedError, so concrete claims must define them.
+    """
+    def verify(self, value):
+        self.verify_count = self.verify_count + 1
+
+
+class VerifierConfiguration:
+    """Options: keep_going (log errors instead of raising), strict (flag unknown claims)."""
+    def __init__(self, keep_going=False, strict=False):
+        self.keep_going, self.strict = keep_going, strict
+
+class AttestationTokenVerifier:
+    """Decodes a CBOR attestation token and checks it against registered claims."""
+    all_known_claims = {}  # shared by all instances: claim key -> claim class
+
+    SIGN_METHOD_SIGN1 = "sign"
+    SIGN_METHOD_MAC0 = "mac"
+    SIGN_METHOD_RAW = "raw"
+
+    COSE_ALG_ES256 = "ES256"
+    COSE_ALG_ES384 = "ES384"
+    COSE_ALG_ES512 = "ES512"
+    COSE_ALG_HS256_64 = "HS256/64"
+    COSE_ALG_HS256 = "HS256"
+    COSE_ALG_HS384 = "HS384"
+    COSE_ALG_HS512 = "HS512"
+
+    def __init__(self, method, cose_alg, configuration=None):
+        self.method = method
+        self.cose_alg = cose_alg
+        self.config = configuration if configuration is not None else VerifierConfiguration()
+        self.claims = []
+        # Set by error(); stays True even when keep_going makes error() only log.
+        self.seen_errors = False
+
+    def add_claims(self, claims):
+        """Append claims to this verifier and record their keys in all_known_claims."""
+        registry = AttestationTokenVerifier.all_known_claims
+        for claim in claims:
+            # The first class registered for a key wins; contained keys overwrite.
+            registry.setdefault(claim.get_claim_key(), claim.__class__)
+            registry.update(claim.get_contained_claim_key_list())
+        self.claims.extend(claims)
+
+    def check_cross_claim_requirements(self):
+        """Hook for token-level checks spanning several claims; no-op by default."""
+
+    def decode_and_validate_iat(self, encoded_iat):
+        """Decode a CBOR token, verify its claims and return the decoded claims as a dict."""
+        try:
+            raw_token = cbor2.loads(encoded_iat)
+        except Exception as e:
+            msg = 'Invalid CBOR: {}'
+            raise ValueError(msg.format(e)) from e
+
+        claims = {v.get_claim_key(): v for v in self.claims}
+
+        token = {}
+        # The claims map may be wrapped in (nested) CBOR tags; unwrap them via .value.
+        while not hasattr(raw_token, 'items'):
+            if not hasattr(raw_token, 'value'):
+                raise ValueError('Invalid IAT: token is not a CBOR map')
+            raw_token = raw_token.value
+        for entry, value in raw_token.items():
+            claim = claims.get(entry)
+            if claim is None:  # no registered claim for this key
+                if self.config.strict:
+                    self.error('Invalid IAT claim: {}'.format(entry))
+                token[entry] = value
+                continue
+
+            claim.verify(value)
+            claim.add_tokens_to_dict(token, value)
+
+        # Check claims' necessity
+        # NOTE(review): verify_count isn't reset here; reusing a verifier may mask missing claims.
+        for claim in claims.values():
+            if not claim.claim_found():
+                if claim.necessity == AttestationClaim.MANDATORY:
+                    msg = 'Invalid IAT: missing MANDATORY claim "{}"'
+                    self.error(msg.format(claim.get_claim_name()))
+                elif claim.necessity == AttestationClaim.RECOMMENDED:
+                    msg = 'Missing RECOMMENDED claim "{}"'
+                    self.warning(msg.format(claim.get_claim_name()))
+
+            claim.check_cross_claim_requirements()
+
+        self.check_cross_claim_requirements()
+
+        return token
+
+    def get_wrapping_tag(self=None):
+        """The value of the tag that the token is wrapped in.
+
+        The function should return None if the token is not wrapped.
+        """
+        return None
+
+    def error(self, message):
+        """Report a verification error: log it if keep_going is set, else raise ValueError."""
+        self.seen_errors = True
+        if not self.config.keep_going:
+            raise ValueError(message)
+        logger.error(message)
+
+    def warning(self, message):
+        logger.warning(message)