76 lines
2.9 KiB
Python
76 lines
2.9 KiB
Python
import json
|
|
import base64
|
|
from cryptography.hazmat.primitives import serialization
|
|
from cryptography.hazmat.primitives import hashes
|
|
from cryptography.hazmat.primitives.asymmetric import rsa, padding
|
|
|
|
# This method decodes the JWT and verifies the signature. If a key is provided,
|
|
# that will be used for signature verification. Otherwise, the key sent within
|
|
# the JWT payload will be used instead.
|
|
# This returns a tuple of (decoded_header, decoded_payload, verify_succeeded).
|
|
def decode_jwt(token, key=None):
|
|
try:
|
|
# Decode the header and payload.
|
|
header, payload, signature = token.split('.')
|
|
decoded_header = decode_base64_json(header)
|
|
decoded_payload = decode_base64_json(payload)
|
|
|
|
# If decoding failed, return nothing.
|
|
if not decoded_header or not decoded_payload:
|
|
return None, None, False
|
|
|
|
# If there is a key passed in (for refresh), use that for checking the signature below.
|
|
# Otherwise (for registration), use the key sent within the JWT to check the signature.
|
|
if key == None:
|
|
key = decoded_payload.get('key')
|
|
public_key = serialization.load_pem_public_key(jwk_to_pem(key))
|
|
# Verifying the signature will throw an exception if it fails.
|
|
verify_rs256_signature(header, payload, signature, public_key)
|
|
return decoded_header, decoded_payload, True
|
|
except Exception:
|
|
return None, None, False
|
|
|
|
def jwk_to_pem(jwk_data):
|
|
jwk = json.loads(jwk_data) if isinstance(jwk_data, str) else jwk_data
|
|
key_type = jwk.get("kty")
|
|
|
|
if key_type != "RSA":
|
|
raise ValueError(f"Unsupported key type: {key_type}")
|
|
|
|
n = int.from_bytes(decode_base64url(jwk["n"]), 'big')
|
|
e = int.from_bytes(decode_base64url(jwk["e"]), 'big')
|
|
public_key = rsa.RSAPublicNumbers(e, n).public_key()
|
|
pem_public_key = public_key.public_bytes(
|
|
encoding=serialization.Encoding.PEM,
|
|
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
|
)
|
|
return pem_public_key
|
|
|
|
def verify_rs256_signature(encoded_header, encoded_payload, signature, public_key):
|
|
message = (f'{encoded_header}.{encoded_payload}').encode('utf-8')
|
|
signature_bytes = decode_base64(signature)
|
|
# This will throw an exception if verification fails.
|
|
public_key.verify(
|
|
signature_bytes,
|
|
message,
|
|
padding.PKCS1v15(),
|
|
hashes.SHA256()
|
|
)
|
|
|
|
def add_base64_padding(encoded_data):
|
|
remainder = len(encoded_data) % 4
|
|
if remainder > 0:
|
|
encoded_data += '=' * (4 - remainder)
|
|
return encoded_data
|
|
|
|
def decode_base64url(encoded_data):
|
|
encoded_data = add_base64_padding(encoded_data)
|
|
encoded_data = encoded_data.replace("-", "+").replace("_", "/")
|
|
return base64.b64decode(encoded_data)
|
|
|
|
def decode_base64(encoded_data):
|
|
encoded_data = add_base64_padding(encoded_data)
|
|
return base64.urlsafe_b64decode(encoded_data)
|
|
|
|
def decode_base64_json(encoded_data):
|
|
return json.loads(decode_base64(encoded_data))
|