Source code for ox.utils.jwt

from datetime import datetime
from typing import Any, Type

import jwt
from pydantic import BaseModel, field_validator

from django.core.exceptions import PermissionDenied
from django.conf import settings
from django.utils.translation import gettext_lazy as _


__all__ = ("JWToken", "JWTokenViewMixin", "encode", "decode")


class JWToken(BaseModel):
    """A JSON Web Token payload to encode or decode.

    This is a pydantic model: :py:meth:`decode` validates the payload of a raw
    token string into an instance, and :py:meth:`encode` serializes the
    instance back into a signed token string.

    When :py:attr:`expiration` is provided, model validation checks that it is
    not in the past (see :py:meth:`validate_expiration`).
    """

    expiration: datetime | None = None

    def encode(self, **kwargs):
        """Serialize this token into a signed JWT string.

        :param **kwargs: extra arguments forwarded to the module-level
            :py:func:`encode` function.
        """
        # JSON mode makes datetimes ISO strings so pyjwt can serialize them.
        return encode(self.model_dump(mode="json"), **kwargs)

    @classmethod
    def decode(cls, jwt: str, **kwargs):
        """Decode a raw JWT string and validate its payload into ``cls``.

        :param jwt: the raw token string. (This name shadows the pyjwt module
            inside this method; it is kept for keyword-argument callers.)
        :param **kwargs: extra arguments forwarded to the module-level
            :py:func:`decode` function.
        :raises pydantic.ValidationError: the payload does not match the model,
            including an expired :py:attr:`expiration`. Errors raised by the
            module-level :py:func:`decode` propagate unchanged.
        """
        data = decode(jwt, **kwargs)
        return cls.model_validate(data)

    @field_validator("expiration", mode="after")
    @classmethod
    def validate_expiration(cls, value: datetime | None) -> datetime | None:
        """Reject an expiration that lies in the past.

        :raises ValueError: the token is expired (pydantic reports it as a
            validation error).
        """
        if value is None:
            return value
        # Build "now" in the value's own timezone. Naive values are still
        # assumed to be server-local time (``tzinfo`` is None -> naive local
        # now). Aware values, e.g. parsed from an ISO string with an offset,
        # get an aware "now". Mixing naive and aware datetimes would raise
        # TypeError instead of a validation error.
        now = datetime.now(tz=value.tzinfo)
        if value < now:
            raise ValueError("This token is expired")
        return value
class JWTokenViewMixin:
    """View mixin that validates an input JWT.

    The JWT is provided either as the ``jwt`` URL keyword argument or as the
    ``jwt`` GET parameter, and is read at dispatch. It is meant to be mixed
    into a view exposing ``self.request`` and a parent ``dispatch`` method.

    :py:class:`~django.core.exceptions.PermissionDenied` is raised if the
    token cannot be decoded. Otherwise :py:attr:`jwt` is set to the decoded
    token and :py:attr:`jwt_str` to the raw string.
    """

    jwt: JWToken | None = None
    """ The JWT read from url if any. """
    jwt_str: str | None = None
    """ The provided JWT string. """
    jwtoken_class: Type[JWToken] | None = None
    """ JWToken class to invoke """

    def dispatch(self, *args, jwt: str | None = None, **kwargs):
        """Read and decode the JWT (if any), then continue the dispatch.

        :param jwt: token string from the URL. When empty, the ``jwt`` GET
            parameter is used instead.
        :raises ValueError: a token is provided but :py:attr:`jwtoken_class`
            is not set.
        :raises PermissionDenied: the provided token is not valid.
        """
        # The URL kwarg takes precedence over the GET parameter. It is
        # consumed here and not forwarded to the parent dispatch.
        if jwt := (jwt or self.request.GET.get("jwt")):
            if not self.jwtoken_class:
                raise ValueError("Missing `jwtoken_class` attribute")
            self.jwt = self.get_jwt(jwt)
            self.jwt_str = jwt
        return super().dispatch(*args, **kwargs)

    def get_jwt(self, value: str) -> JWToken:
        """Get JWT from the provided string.

        :param value: the raw token string.
        :return: an instance of :py:attr:`jwtoken_class`.
        :raises PermissionDenied: an error occurred decoding the token. The
            original error is chained as ``__cause__``.
        """
        try:
            return self.jwtoken_class.decode(value)
        except Exception as err:
            # Any decoding or validation failure is reported as a 403. Chain
            # the cause so the real error stays visible in tracebacks/logs.
            raise PermissionDenied(_("This token is not valid")) from err
def encode(payload: dict[str, Any], headers=None, json_encoder=None, algorithm="HS256"):
    """Sign ``payload`` into a JWT using pyjwt and the ``JWT_KEY`` setting.

    :param payload: claims to encode.
    :param headers: extra JWT headers, forwarded to ``jwt.encode``.
    :param json_encoder: custom JSON encoder class, forwarded to ``jwt.encode``.
    :param algorithm: signing algorithm, ``"HS256"`` unless overridden.
    :return: whatever ``jwt.encode`` returns for these arguments.
    """
    signing_key = settings.JWT_KEY
    return jwt.encode(
        payload,
        signing_key,
        algorithm=algorithm,
        headers=headers,
        json_encoder=json_encoder,
    )
def decode(value: str, algorithms=("HS256",), **kwargs) -> dict[str, Any]:
    """Simple wrapper around pyjwt's decode, using configured key (and fallback keys).

    Keys are tried in order: ``JWT_KEY`` first, then every key of the
    optional ``JWT_KEY_FALLBACKS`` setting. The first successful decode is
    returned. If none succeeds, the error raised while decoding with
    ``JWT_KEY`` is re-raised.

    :param value: the token to be decoded.
    :param algorithms: accepted signing algorithms (HS256 only by default).
    :param **kwargs: extra parameters to pass to ``jwt.decode``.
    :return: the decoded payload, as returned by ``jwt.decode``.
    :raises Exception: the error from decoding with ``JWT_KEY`` when no key
        succeeds.
    """
    # ``JWT_KEY_FALLBACKS`` is optional: a project with no rotated keys does
    # not need to define it.
    keys = [settings.JWT_KEY, *getattr(settings, "JWT_KEY_FALLBACKS", ())]
    first_error = None
    for key in keys:
        try:
            return jwt.decode(value, key, algorithms=algorithms, **kwargs)
        except Exception as err:
            # Remember only the primary key's failure: it is the one reported
            # if every fallback fails too.
            if first_error is None:
                first_error = err
    raise first_error