import json
import base64

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa, padding


def decode_jwt(token, key=None):
    """Decode a JWT and verify its RS256 signature.

    If ``key`` (a JWK, as a dict or JSON string) is provided, e.g. on refresh,
    it is used to check the signature. Otherwise, e.g. on registration, the
    JWK found under the ``"key"`` claim of the token's own payload is used.

    Note: on the no-``key`` path the token is verified against a key it
    carries itself, so success only proves the sender holds the matching
    private key -- not that the key is trusted.

    Returns:
        A tuple ``(decoded_header, decoded_payload, verify_succeeded)``. On
        any failure (malformed token, undecodable or non-object JSON,
        unusable key, bad signature) this is ``(None, None, False)``.
    """
    # Deliberately broad: callers get a failure tuple instead of an exception
    # for every kind of malformed or forged token.
    try:
        header, payload, signature = token.split('.')
        decoded_header = decode_base64_json(header)
        decoded_payload = decode_base64_json(payload)

        # A JWT header and claims set must be non-empty JSON objects;
        # anything else (lists, strings, numbers, {}) counts as a decode failure.
        if not (isinstance(decoded_header, dict) and decoded_header):
            return None, None, False
        if not (isinstance(decoded_payload, dict) and decoded_payload):
            return None, None, False

        if key is None:
            key = decoded_payload.get('key')
        public_key = _jwk_to_public_key(key)

        # Always verified as RS256, whatever the header's "alg" says, so a
        # forged "alg" (e.g. "none") cannot bypass verification.
        # Raises if the signature does not match.
        verify_rs256_signature(header, payload, signature, public_key)
        return decoded_header, decoded_payload, True
    except Exception:
        return None, None, False


def _jwk_to_public_key(jwk_data):
    """Build an RSA public key object from a JWK given as a dict or JSON string.

    Raises:
        ValueError: if the JWK's ``kty`` is not ``"RSA"`` (or the JSON/base64
            is malformed).
        KeyError: if the ``n`` or ``e`` member is missing.
    """
    jwk = json.loads(jwk_data) if isinstance(jwk_data, str) else jwk_data
    key_type = jwk.get("kty")
    if key_type != "RSA":
        raise ValueError(f"Unsupported key type: {key_type}")
    # JWK stores the modulus and exponent as big-endian base64url integers.
    n = int.from_bytes(decode_base64url(jwk["n"]), 'big')
    e = int.from_bytes(decode_base64url(jwk["e"]), 'big')
    return rsa.RSAPublicNumbers(e, n).public_key()


def jwk_to_pem(jwk_data):
    """Convert an RSA JWK (dict or JSON string) to PEM SubjectPublicKeyInfo bytes.

    Raises:
        ValueError: if the key type is not ``"RSA"``.
        KeyError: if the ``n`` or ``e`` member is missing.
    """
    public_key = _jwk_to_public_key(jwk_data)
    return public_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )


def verify_rs256_signature(encoded_header, encoded_payload, signature, public_key):
    """Verify an RS256 (RSASSA-PKCS1-v1_5 with SHA-256) JWT signature.

    The signing input is the still-encoded ``header.payload`` string exactly
    as it appeared in the token.

    Raises:
        cryptography.exceptions.InvalidSignature: if verification fails.
    """
    message = f'{encoded_header}.{encoded_payload}'.encode('utf-8')
    signature_bytes = decode_base64(signature)
    public_key.verify(
        signature_bytes,
        message,
        padding.PKCS1v15(),
        hashes.SHA256(),
    )


def add_base64_padding(encoded_data):
    """Return ``encoded_data`` with '=' appended up to a multiple of 4 chars.

    JWTs use unpadded base64url, which the stdlib decoders reject.
    """
    return encoded_data + '=' * (-len(encoded_data) % 4)


def decode_base64url(encoded_data):
    """Decode (possibly unpadded) base64url text into bytes."""
    return decode_base64(encoded_data)


def decode_base64(encoded_data):
    """Decode (possibly unpadded) base64url text into bytes."""
    return base64.urlsafe_b64decode(add_base64_padding(encoded_data))


def decode_base64_json(encoded_data):
    """Decode a base64url segment and parse it as JSON."""
    return json.loads(decode_base64(encoded_data))