import logging
from typing import Any
from typing import Dict
from typing import NamedTuple
from typing import Optional
from typing import Set
import jwt
from thrift import TSerialization
from thrift.protocol.TBinaryProtocol import TBinaryProtocolAcceleratedFactory
from baseplate import RequestContext
from baseplate.lib import cached_property
from baseplate.lib.secrets import SecretsStore
from baseplate.thrift.ttypes import Device as TDevice
from baseplate.thrift.ttypes import Loid as TLoid
from baseplate.thrift.ttypes import Request as TRequest
from baseplate.thrift.ttypes import Session as TSession
# Module-level logger; used below to report malformed Edge-Request headers.
logger = logging.getLogger(__name__)
class NoAuthenticationError(Exception):
    """Signal that the authentication token is missing or invalid.

    Raised when code tries to read authentication data from a token that
    was never supplied or that failed validation.

    """
class AuthenticationTokenValidator:
    """Factory that knows how to validate raw authentication tokens.

    :param secrets: secrets store from which the versioned public key at
        ``secret/authentication/public-key`` is read.

    """

    def __init__(self, secrets: SecretsStore):
        self.secrets = secrets

    def validate(self, token: bytes) -> "AuthenticationToken":
        """Validate a raw authentication token and return an object.

        Each version of the public key held by the secrets store is tried in
        turn until one decodes the token.

        :param token: token value originating from the Authentication service
            either directly or from an upstream service
        :return: a :py:class:`ValidatedAuthenticationToken` wrapping the
            decoded payload, or an :py:class:`InvalidAuthenticationToken` if
            the token is empty, expired, or not decodable with any key
            version.

        """
        if not token:
            return InvalidAuthenticationToken()

        secret = self.secrets.get_versioned("secret/authentication/public-key")
        for public_key in secret.all_versions:
            try:
                # PyJWT documents ``algorithms`` as a list of allowed
                # algorithm names; pass a list rather than a bare string so
                # the allow-list holds exactly one entry.
                decoded = jwt.decode(token, public_key, algorithms=["RS256"])
            except jwt.ExpiredSignatureError:
                # An expired token is rejected outright rather than retried
                # against the remaining key versions.
                return InvalidAuthenticationToken()
            except jwt.DecodeError:
                # Not decodable with this key version; try the next one.
                continue
            else:
                return ValidatedAuthenticationToken(decoded)

        return InvalidAuthenticationToken()
class AuthenticationToken:
    """Information about the authenticated user.

    This is the interface shared by all token flavours; every accessor here
    raises :py:exc:`NotImplementedError` and is meant to be overridden.

    :py:class:`EdgeRequestContext` provides high-level helpers for extracting
    data from authentication tokens. Use those instead of direct access through
    this class.

    """

    @property
    def subject(self) -> Optional[str]:
        """Return the raw `subject` that is authenticated."""
        raise NotImplementedError

    @property
    def user_roles(self) -> Set[str]:
        """Return the set of roles carried by the token."""
        raise NotImplementedError

    @property
    def oauth_client_id(self) -> Optional[str]:
        """Return the OAuth client id carried by the token, if any."""
        raise NotImplementedError

    @property
    def oauth_client_type(self) -> Optional[str]:
        """Return the OAuth client type carried by the token, if any."""
        raise NotImplementedError

    @property
    def scopes(self) -> Set[str]:
        """Return the set of scopes carried by the token."""
        raise NotImplementedError

    @property
    def loid(self) -> Optional[str]:
        """Return the LoID id carried by the token, if any."""
        raise NotImplementedError

    @property
    def loid_created_ms(self) -> Optional[int]:
        """Return the LoID creation time carried by the token, if any."""
        raise NotImplementedError
class ValidatedAuthenticationToken(AuthenticationToken):
    """Authentication token backed by a decoded payload.

    Every accessor reads from ``payload``; keys that are absent (or, for the
    collection and LoID accessors, explicitly ``None``) yield ``None`` or an
    empty set rather than raising.

    :param payload: the decoded token claims.

    """

    def __init__(self, payload: Dict[str, Any]):
        self.payload = payload

    @property
    def subject(self) -> Optional[str]:
        """Return the ``sub`` claim, or ``None`` if absent."""
        return self.payload.get("sub")

    @cached_property
    def user_roles(self) -> Set[str]:  # type: ignore  # pylint: disable=W0236
        """Return the ``roles`` claim as a set (empty if absent or ``None``)."""
        # ``or []`` (rather than a ``.get`` default) also covers an explicit
        # ``"roles": None`` in the payload, matching how ``scopes`` is read.
        return set(self.payload.get("roles") or [])

    @property
    def oauth_client_id(self) -> Optional[str]:
        """Return the ``client_id`` claim, or ``None`` if absent."""
        return self.payload.get("client_id")

    @property
    def oauth_client_type(self) -> Optional[str]:
        """Return the ``client_type`` claim, or ``None`` if absent."""
        return self.payload.get("client_type")

    @property
    def scopes(self) -> Set[str]:
        """Return the ``scopes`` claim as a set (empty if absent or ``None``)."""
        return set(self.payload.get("scopes") or [])

    @property
    def loid(self) -> Optional[str]:
        """Return ``loid.id`` from the payload, or ``None`` if absent."""
        return (self.payload.get("loid") or {}).get("id")

    @property
    def loid_created_ms(self) -> Optional[int]:
        """Return ``loid.created_ms`` from the payload, or ``None`` if absent."""
        return (self.payload.get("loid") or {}).get("created_ms")
class InvalidAuthenticationToken(AuthenticationToken):
    """Stand-in token for a missing or invalid authentication token.

    Every accessor raises :py:class:`NoAuthenticationError` instead of
    returning a value.

    """

    @property
    def subject(self) -> Optional[str]:
        """Always raise :py:class:`NoAuthenticationError`."""
        raise NoAuthenticationError

    @property
    def user_roles(self) -> Set[str]:
        """Always raise :py:class:`NoAuthenticationError`."""
        raise NoAuthenticationError

    @property
    def oauth_client_id(self) -> Optional[str]:
        """Always raise :py:class:`NoAuthenticationError`."""
        raise NoAuthenticationError

    @property
    def oauth_client_type(self) -> Optional[str]:
        """Always raise :py:class:`NoAuthenticationError`."""
        raise NoAuthenticationError

    @property
    def scopes(self) -> Set[str]:
        """Always raise :py:class:`NoAuthenticationError`."""
        raise NoAuthenticationError

    @property
    def loid(self) -> Optional[str]:
        """Always raise :py:class:`NoAuthenticationError`."""
        raise NoAuthenticationError

    @property
    def loid_created_ms(self) -> Optional[int]:
        """Always raise :py:class:`NoAuthenticationError`."""
        raise NoAuthenticationError
class Session(NamedTuple):
    """Hold the session values carried by the EdgeRequestContext."""

    # Identifier of the session.
    id: str
class Device(NamedTuple):
    """Hold the device values carried by the EdgeRequestContext."""

    # Identifier of the device.
    id: str
class User(NamedTuple):
    """Wrapper for the user values in AuthenticationToken and the LoId cookie."""

    # Token the authenticated subject and roles are read from.
    authentication_token: AuthenticationToken
    # LoID identifier; reported as the user id in events when logged out.
    loid: str
    # Reported in events as ``cookie_created_timestamp``.
    cookie_created_ms: int

    @property
    def id(self) -> Optional[str]:
        """Return the authenticated account_id for the current User.

        :raises: :py:class:`NoAuthenticationError` if there was no
            authentication token, it was invalid, or the subject is not an
            account.

        """
        subject = self.authentication_token.subject
        # Only subjects carrying the ``t2_`` prefix are treated as accounts.
        if subject and subject.startswith("t2_"):
            return subject
        raise NoAuthenticationError

    @property
    def is_logged_in(self) -> bool:
        """Return if the User has a valid, authenticated id."""
        try:
            account_id = self.id
        except NoAuthenticationError:
            return False
        return account_id is not None

    @property
    def roles(self) -> Set[str]:
        """Return the authenticated roles for the current User.

        :raises: :py:class:`NoAuthenticationError` if there was no
            authentication token or it was invalid

        """
        return self.authentication_token.user_roles

    def has_role(self, role: str) -> bool:
        """Return if the authenticated user has the specified role.

        :param role: Name of the role to check; it is lower-cased before
            being looked up in the user's roles.

        :raises: :py:class:`NoAuthenticationError` if there was no
            authentication token defined for the current context

        """
        wanted = role.lower()
        return wanted in self.roles

    def event_fields(self) -> Dict[str, Any]:
        """Return fields to be added to events.

        ``user_id`` is the account id when logged in and the LoID otherwise.

        """
        user_id = self.id if self.is_logged_in else self.loid
        return {
            "user_id": user_id,
            "logged_in": self.is_logged_in,
            "cookie_created_timestamp": self.cookie_created_ms,
        }
class OAuthClient(NamedTuple):
    """Wrapper for the OAuth2 client values in AuthenticationToken."""

    # Token the client id and client type are read from.
    authentication_token: AuthenticationToken

    @property
    def id(self) -> Optional[str]:
        """Return the authenticated id for the current client.

        :raises: :py:class:`NoAuthenticationError` if there was no
            authentication token defined for the current context

        """
        return self.authentication_token.oauth_client_id

    def is_type(self, *client_types: str) -> bool:
        """Return if the authenticated client type is one of the given types.

        When checking the type of the current OauthClient, you should check
        that the type "is" one of the allowed types rather than checking that
        it "is not" a disallowed type.

        For example::

            if oauth_client.is_type("third_party"):
                ...

        not::

            if not oauth_client.is_type("first_party"):
                ...

        :param client_types: Case-insensitive sequence of client type
            names that you want to check.

        :raises: :py:class:`NoAuthenticationError` if there was no
            authentication token defined for the current context

        """
        current_type = self.authentication_token.oauth_client_type
        if not current_type:
            # A token without a client type never matches.
            return False
        return current_type in (candidate.lower() for candidate in client_types)

    def event_fields(self) -> Dict[str, Any]:
        """Return fields to be added to events.

        ``oauth_client_id`` is ``None`` when there is no valid authentication.

        """
        try:
            return {"oauth_client_id": self.id}
        except NoAuthenticationError:
            return {"oauth_client_id": None}
class Service(NamedTuple):
    """Wrapper for the Service values in AuthenticationToken."""

    # Token the authenticated subject is read from.
    authentication_token: AuthenticationToken

    @property
    def name(self) -> str:
        """Return the authenticated service name.

        The name is the token subject with its ``service/`` prefix removed.

        :raises: :py:class:`NoAuthenticationError` if there was no
            authentication token, it was invalid, or the subject is not a
            service.

        """
        prefix = "service/"
        subject = self.authentication_token.subject
        if subject and subject.startswith(prefix):
            return subject[len(prefix) :]
        raise NoAuthenticationError
class EdgeRequestContextFactory:
    """Factory for creating :py:class:`EdgeRequestContext` objects.

    Every application should set one of these up. Edge services that talk
    directly with clients should use :py:meth:`new` directly. For internal
    services, pass the object off to Baseplate's framework integration
    (Thrift/Pyramid) for automatic use.

    :param baseplate.lib.secrets.SecretsStore secrets: A configured secrets
        store.

    """

    def __init__(self, secrets: SecretsStore):
        self.authn_token_validator = AuthenticationTokenValidator(secrets)

    def new(
        self,
        authentication_token: Optional[bytes] = None,
        loid_id: Optional[str] = None,
        loid_created_ms: Optional[int] = None,
        session_id: Optional[str] = None,
        device_id: Optional[str] = None,
    ) -> "EdgeRequestContext":
        """Return a new EdgeRequestContext object made from scratch.

        Services at the edge that communicate directly with clients should use
        this to pass on the information they get to downstream services. They
        can then use this information to check authentication, run experiments,
        etc.

        To use this, create and attach the context early in your request flow:

        .. code-block:: python

            auth_cookie = request.cookies["authentication"]
            token = request.authentication_service.authenticate_cookie(auth_cookie)
            loid = parse_loid(request.cookies["loid"])
            session = parse_session(request.cookies["session"])
            device_id = request.headers["x-device-id"]

            edge_context = self.edgecontext_factory.new(
                authentication_token=token,
                loid_id=loid.id,
                loid_created_ms=loid.created,
                session_id=session.id,
                device_id=device_id,
            )
            edge_context.attach_context(request)

        :param authentication_token: A raw authentication token as returned by
            the authentication service.
        :param loid_id: ID for the current LoID in fullname format.
        :param loid_created_ms: Epoch milliseconds when the current LoID cookie
            was created.
        :param session_id: ID for the current session cookie.
        :param device_id: ID for the device where the request originated from.

        :raises: :py:class:`ValueError` if ``loid_id`` is given but does not
            start with ``t2_``.

        """
        loid_is_malformed = loid_id is not None and not loid_id.startswith("t2_")
        if loid_is_malformed:
            message = (
                "loid_id <%s> is not in a valid format, it should be in the "
                "fullname format with the '0' padding removed: 't2_loid_id'"
            ) % loid_id
            raise ValueError(message)

        t_loid = TLoid(id=loid_id, created_ms=loid_created_ms)
        t_session = TSession(id=session_id)
        t_device = TDevice(id=device_id)
        t_request = TRequest(
            loid=t_loid,
            session=t_session,
            authentication_token=authentication_token,
            device=t_device,
        )

        header = TSerialization.serialize(t_request, EdgeRequestContext._HEADER_PROTOCOL_FACTORY)
        context = EdgeRequestContext(self.authn_token_validator, header)
        # Set the _t_request property so we can skip the deserialization step
        # since we already have the thrift object.
        context._t_request = t_request
        return context

    def from_upstream(self, edge_header: Optional[bytes]) -> "EdgeRequestContext":
        """Create and return an EdgeRequestContext from an upstream header.

        This is generally used internally to Baseplate by framework
        integrations that automatically pick up context from inbound requests.

        :param edge_header: Raw payload of Edge-Request header from upstream
            service.

        """
        return EdgeRequestContext(self.authn_token_validator, edge_header)
class EdgeRequestContext:
    """Contextual information about the initial request to an edge service.

    Construct this using an
    :py:class:`~baseplate.lib.edge_context.EdgeRequestContextFactory`.

    """

    # Protocol factory used to (de)serialize the Edge-Request header.
    _HEADER_PROTOCOL_FACTORY = TBinaryProtocolAcceleratedFactory()

    def __init__(
        self, authn_token_validator: AuthenticationTokenValidator, header: Optional[bytes]
    ):
        self._authn_token_validator = authn_token_validator
        self._header = header

    def attach_context(self, context: RequestContext) -> None:
        """Attach this to the provided :py:class:`~baseplate.RequestContext`.

        Sets ``request_context`` to this object and ``raw_request_context``
        to the raw header on the given context.

        :param context: request context to attach this to

        """
        context.request_context = self
        context.raw_request_context = self._header

    def event_fields(self) -> Dict[str, Any]:
        """Return fields to be added to events.

        The session id is always included; the device id only when it is set.
        User and OAuth client fields are merged in afterwards.

        """
        fields = {"session_id": self.session.id}
        device_id = self.device.id
        if device_id:
            fields["device_id"] = device_id
        fields.update(self.user.event_fields())
        fields.update(self.oauth_client.event_fields())
        return fields

    @cached_property
    def authentication_token(self) -> AuthenticationToken:
        """Validated form of the raw authentication token in the request."""
        raw_token = self._t_request.authentication_token
        return self._authn_token_validator.validate(raw_token)

    @cached_property
    def user(self) -> User:
        """:py:class:`~baseplate.lib.edge_context.User` object for the current context."""
        t_loid = self._t_request.loid
        return User(
            authentication_token=self.authentication_token,
            loid=t_loid.id,
            cookie_created_ms=t_loid.created_ms,
        )

    @cached_property
    def oauth_client(self) -> OAuthClient:
        """:py:class:`~baseplate.lib.edge_context.OAuthClient` object for the current context."""
        return OAuthClient(self.authentication_token)

    @cached_property
    def device(self) -> Device:
        """:py:class:`~baseplate.lib.edge_context.Device` object for the current context."""
        return Device(id=self._t_request.device.id)

    @cached_property
    def session(self) -> Session:
        """:py:class:`~baseplate.lib.edge_context.Session` object for the current context."""
        return Session(id=self._t_request.session.id)

    @cached_property
    def service(self) -> Service:
        """:py:class:`~baseplate.lib.edge_context.Service` object for the current context."""
        return Service(self.authentication_token)

    @cached_property
    def _t_request(self) -> TRequest:  # pylint: disable=method-hidden
        """Thrift request parsed from the raw header.

        Starts from a request with empty loid, session and device structs so
        those attributes always exist, then deserializes the header into it
        if one was given. A header that fails to deserialize is logged at
        debug level and otherwise ignored.

        """
        t_request = TRequest(loid=TLoid(), session=TSession(), device=TDevice())
        if not self._header:
            return t_request

        try:
            TSerialization.deserialize(t_request, self._header, self._HEADER_PROTOCOL_FACTORY)
        except Exception:
            logger.debug("Invalid Edge-Request header. %s", self._header)
        return t_request