Browse Source

Client credentials (#22)

Casper van der Wel 1 year ago
parent
commit
614eff9f39

+ 3 - 1
CHANGES.md

@@ -4,7 +4,9 @@
 0.6.7 (unreleased)
 ------------------
 
-- Nothing changed yet.
+- Adapt call signature of the `fetch_token` callable in `ApiProvicer`.
+
+- Add `clean_python.oauth.client_credentials`.
 
 
 0.6.6 (2023-10-04)

+ 7 - 12
clean_python/api_client/api_provider.py

@@ -3,6 +3,7 @@ import re
 from http import HTTPStatus
 from typing import Awaitable
 from typing import Callable
+from typing import Dict
 from typing import Optional
 from urllib.parse import quote
 from urllib.parse import urlencode
@@ -13,7 +14,6 @@ from aiohttp import ClientResponse
 from aiohttp import ClientSession
 from pydantic import AnyHttpUrl
 
-from clean_python import ctx
 from clean_python import Json
 
 from .exceptions import ApiException
@@ -62,7 +62,7 @@ class ApiProvider:
 
     Args:
         url: The url of the API (with trailing slash)
-        fetch_token: Callable that returns a token for a tenant id
+        fetch_token: Coroutine that returns headers for authorization
         retries: Total number of retries per request
         backoff_factor: Multiplier for retry delay times (1, 2, 4, ...)
     """
@@ -70,12 +70,13 @@ class ApiProvider:
     def __init__(
         self,
         url: AnyHttpUrl,
-        fetch_token: Callable[[ClientSession, int], Awaitable[Optional[str]]],
+        fetch_token: Callable[[], Awaitable[Dict[str, str]]],
         retries: int = 3,
         backoff_factor: float = 1.0,
     ):
         self._url = str(url)
-        assert self._url.endswith("/")
+        if not self._url.endswith("/"):
+            self._url += "/"
         self._fetch_token = fetch_token
         assert retries > 0
         self._retries = retries
@@ -91,27 +92,21 @@ class ApiProvider:
         fields: Optional[Json],
         timeout: float,
     ) -> ClientResponse:
-        assert ctx.tenant is not None
-        headers = {}
         request_kwargs = {
             "method": method,
             "url": add_query_params(join(self._url, quote(path)), params),
             "timeout": timeout,
             "json": json,
             "data": fields,
+            "headers": await self._fetch_token(),
         }
-        token = await self._fetch_token(self._session, ctx.tenant.id)
-        if token is not None:
-            headers["Authorization"] = f"Bearer {token}"
         for attempt in range(self._retries):
             if attempt > 0:
                 backoff = self._backoff_factor * 2 ** (attempt - 1)
                 await asyncio.sleep(backoff)
 
             try:
-                response = await self._session.request(
-                    headers=headers, **request_kwargs
-                )
+                response = await self._session.request(**request_kwargs)
                 await response.read()
             except (aiohttp.ClientError, asyncio.exceptions.TimeoutError):
                 if attempt == self._retries - 1:

+ 5 - 7
clean_python/api_client/sync_api_provider.py

@@ -1,6 +1,7 @@
 import json as json_lib
 from http import HTTPStatus
 from typing import Callable
+from typing import Dict
 from typing import Optional
 from urllib.parse import quote
 
@@ -8,7 +9,6 @@ from pydantic import AnyHttpUrl
 from urllib3 import PoolManager
 from urllib3 import Retry
 
-from clean_python import ctx
 from clean_python import Json
 
 from .api_provider import add_query_params
@@ -36,12 +36,13 @@ class SyncApiProvider:
     def __init__(
         self,
         url: AnyHttpUrl,
-        fetch_token: Callable[[PoolManager, int], Optional[str]],
+        fetch_token: Callable[[], Dict[str, str]],
         retries: int = 3,
         backoff_factor: float = 1.0,
     ):
         self._url = str(url)
-        assert self._url.endswith("/")
+        if not self._url.endswith("/"):
+            self._url += "/"
         self._fetch_token = fetch_token
         self._pool = PoolManager(retries=Retry(retries, backoff_factor=backoff_factor))
 
@@ -54,7 +55,6 @@ class SyncApiProvider:
         fields: Optional[Json],
         timeout: float,
     ):
-        assert ctx.tenant is not None
         headers = {}
         request_kwargs = {
             "method": method,
@@ -69,9 +69,7 @@ class SyncApiProvider:
             headers["Content-Type"] = "application/json"
         elif fields is not None:
             request_kwargs["fields"] = fields
-        token = self._fetch_token(self._pool, ctx.tenant.id)
-        if token is not None:
-            headers["Authorization"] = f"Bearer {token}"
+        headers.update(self._fetch_token())
         return self._pool.request(headers=headers, **request_kwargs)
 
     def request(

+ 114 - 0
clean_python/oauth2/client_credentials.py

@@ -0,0 +1,114 @@
+import base64
+import json
+import time
+from functools import lru_cache
+
+from aiohttp import BasicAuth
+from async_lru import alru_cache
+from pydantic import AnyHttpUrl
+from pydantic import BaseModel
+
+from clean_python.api_client import ApiProvider
+from clean_python.api_client import SyncApiProvider
+
+__all__ = ["CCTokenGateway", "SyncCCTokenGateway", "OAuth2CCSettings"]
+
+
+REFRESH_TIME_DELTA = 5 * 60  # in seconds
+
+
+def decode_jwt(token):
+    """Decode a JWT without checking its signature"""
+    # JWT consists of {header}.{payload}.{signature}
+    _, payload, _ = token.split(".")
+    # JWT should be padded with = (base64.b64decode expects this)
+    payload += "=" * (-len(payload) % 4)
+    return json.loads(base64.b64decode(payload))
+
+
+def is_token_usable(token: str, leeway: int) -> bool:
+    """Determine whether the token has expired"""
+    try:
+        claims = decode_jwt(token)
+    except Exception:
+        return False
+
+    exp = claims["exp"]
+    refresh_on = exp - leeway
+    return refresh_on >= int(time.time())
+
+
+class OAuth2CCSettings(BaseModel):
+    token_url: AnyHttpUrl
+    client_id: str
+    client_secret: str
+    scope: str
+    timeout: float = 1.0  # in seconds
+    leeway: int = 5 * 60  # in seconds
+
+
+class CCTokenGateway:
+    def __init__(self, settings: OAuth2CCSettings):
+        self.scope = settings.scope
+        self.timeout = settings.timeout
+        self.leeway = settings.leeway
+
+        async def fetch_token():
+            auth = BasicAuth(settings.client_id, settings.client_secret)
+            return {"Authorization": auth.encode()}
+
+        self.provider = ApiProvider(url=settings.token_url, fetch_token=fetch_token)
+        # This binds the cache to the CCTokenGateway instance (and not the class)
+        self.cached_fetch_token = alru_cache(self._fetch_token)
+
+    async def _fetch_token(self) -> str:
+        response = await self.provider.request(
+            method="POST",
+            path="",
+            fields={"grant_type": "client_credentials", "scope": self.scope},
+            timeout=self.timeout,
+        )
+        assert response is not None
+        return response["access_token"]
+
+    async def fetch_token(self) -> str:
+        token_str = await self.cached_fetch_token()
+        if not is_token_usable(token_str, self.leeway):
+            self.cached_fetch_token.cache_clear()
+            token_str = await self.cached_fetch_token()
+        return token_str
+
+
+# Copy-paste of async version:
+
+
+class SyncCCTokenGateway:
+    def __init__(self, settings: OAuth2CCSettings):
+        self.scope = settings.scope
+        self.timeout = settings.timeout
+        self.leeway = settings.leeway
+
+        def fetch_token():
+            auth = BasicAuth(settings.client_id, settings.client_secret)
+            return {"Authorization": auth.encode()}
+
+        self.provider = SyncApiProvider(url=settings.token_url, fetch_token=fetch_token)
+        # This binds the cache to the SyncCCTokenGateway instance (and not the class)
+        self.cached_fetch_token = lru_cache(self._fetch_token)
+
+    def _fetch_token(self) -> str:
+        response = self.provider.request(
+            method="POST",
+            path="",
+            fields={"grant_type": "client_credentials", "scope": self.scope},
+            timeout=self.timeout,
+        )
+        assert response is not None
+        return response["access_token"]
+
+    def fetch_token(self) -> str:
+        token_str = self.cached_fetch_token()
+        if not is_token_usable(token_str, self.leeway):
+            self.cached_fetch_token.cache_clear()
+            token_str = self.cached_fetch_token()
+        return token_str

+ 33 - 0
integration_tests/fastapi_example/presentation.py

@@ -1,3 +1,6 @@
+import base64
+import json
+import time
 from http import HTTPStatus
 from typing import Optional
 
@@ -5,6 +8,9 @@ from fastapi import Depends
 from fastapi import Form
 from fastapi import Response
 from fastapi import UploadFile
+from fastapi.responses import JSONResponse
+from fastapi.security import HTTPBasic
+from fastapi.security import HTTPBasicCredentials
 
 from clean_python import DoesNotExist
 from clean_python import Page
@@ -33,6 +39,9 @@ class BookUpdate(ValueObject):
     title: Optional[str] = None
 
 
+basic = HTTPBasic()
+
+
 class V1Books(Resource, version=v(1), name="books"):
     def __init__(self):
         self.manager = ManageBook()
@@ -73,3 +82,27 @@ class V1Books(Resource, version=v(1), name="books"):
     @put("/urlencode/{name}", response_model=Author)
     async def urlencode(self, name: str):
         return {"name": name}
+
+    @post("/token")
+    def token(
+        self,
+        grant_type: str = Form(),
+        scope: str = Form(),
+        credentials: HTTPBasicCredentials = Depends(basic),
+    ):
+        """For testing client credentials grant"""
+        if grant_type != "client_credentials":
+            return JSONResponse({"error": "invalid_grant"})
+        if credentials.username != "testclient":
+            return JSONResponse({"error": "invalid_client"})
+        if credentials.password != "supersecret":
+            return JSONResponse({"error": "invalid_client"})
+        if scope != "all":
+            return JSONResponse({"error": "invalid_grant"})
+        claims = {"user": "foo", "exp": int(time.time()) + 3600}
+        payload = base64.b64encode(json.dumps(claims).encode()).decode()
+        return {
+            "access_token": f"header.{payload}.signature",
+            "token_type": "Bearer",
+            "expires_in": 3600,
+        }

+ 2 - 2
integration_tests/test_int_api_gateway.py

@@ -12,8 +12,8 @@ class BooksGateway(ApiGateway, path="v1/books/{id}"):
     pass
 
 
-async def fake_token(a, b):
-    return "token"
+async def fake_token():
+    return {"Authorization": "Bearer token"}
 
 
 @pytest.fixture

+ 2 - 2
integration_tests/test_int_api_provider.py

@@ -8,8 +8,8 @@ from clean_python.api_client import ApiException
 from clean_python.api_client import ApiProvider
 
 
-async def fake_token(a, b):
-    return "token"
+async def fake_token():
+    return {"Authorization": "Bearer token"}
 
 
 @pytest.fixture

+ 37 - 0
integration_tests/test_int_client_credentials.py

@@ -0,0 +1,37 @@
+import pytest
+
+from clean_python.oauth2.client_credentials import CCTokenGateway
+from clean_python.oauth2.client_credentials import is_token_usable
+from clean_python.oauth2.client_credentials import OAuth2CCSettings
+from clean_python.oauth2.client_credentials import SyncCCTokenGateway
+
+
+@pytest.fixture
+def settings(fastapi_example_app) -> OAuth2CCSettings:
+    # these settings match those hardcoded in the example app
+    return OAuth2CCSettings(
+        token_url=fastapi_example_app + "/v1/token",
+        client_id="testclient",
+        client_secret="supersecret",
+        scope="all",
+    )
+
+
+@pytest.fixture
+def gateway(settings) -> CCTokenGateway:
+    return CCTokenGateway(settings)
+
+
+async def test_fetch_token(gateway: CCTokenGateway):
+    response = await gateway._fetch_token()
+    assert is_token_usable(response, 0)
+
+
+@pytest.fixture
+def sync_gateway(settings) -> SyncCCTokenGateway:
+    return SyncCCTokenGateway(settings)
+
+
+def test_fetch_token_sync(sync_gateway: SyncCCTokenGateway):
+    response = sync_gateway._fetch_token()
+    assert is_token_usable(response, 0)

+ 3 - 1
integration_tests/test_int_sync_api_gateway.py

@@ -17,7 +17,9 @@ class BooksGateway(SyncApiGateway, path="v1/books/{id}"):
 @pytest.fixture
 def provider(fastapi_example_app) -> SyncApiProvider:
     ctx.tenant = Tenant(id=2, name="")
-    yield SyncApiProvider(fastapi_example_app + "/", lambda a, b: "token")
+    yield SyncApiProvider(
+        fastapi_example_app + "/", lambda: {"Authorization": "Bearer token"}
+    )
     ctx.tenant = None
 
 

+ 3 - 1
integration_tests/test_int_sync_api_provider.py

@@ -13,7 +13,9 @@ from clean_python.api_client import SyncApiProvider
 @pytest.fixture
 def provider(fastapi_example_app) -> SyncApiProvider:
     ctx.tenant = Tenant(id=2, name="")
-    yield SyncApiProvider(fastapi_example_app + "/", lambda a, b: "token")
+    yield SyncApiProvider(
+        fastapi_example_app + "/", lambda: {"Authorization": "Bearer token"}
+    )
     ctx.tenant = None
 
 

+ 1 - 1
pyproject.toml

@@ -10,7 +10,7 @@ license = {text = "MIT"}
 classifiers = ["Programming Language :: Python"]
 keywords = []
 requires-python = ">=3.7"
-dependencies = ["pydantic==2.*", "inject==4.*", "asgiref", "blinker"]
+dependencies = ["pydantic==2.*", "inject==4.*", "asgiref", "blinker", "async-lru"]
 dynamic = ["version"]
 
 [project.optional-dependencies]

+ 4 - 4
tests/api_client/test_api_provider.py

@@ -12,12 +12,12 @@ from clean_python.api_client import ApiProvider
 MODULE = "clean_python.api_client.api_provider"
 
 
-async def fake_token(a, b):
-    return f"tenant-{b}"
+async def fake_token():
+    return {"Authorization": f"Bearer tenant-{ctx.tenant.id}"}
 
 
-async def no_token(a, b):
-    return None
+async def no_token():
+    return {}
 
 
 @pytest.fixture

+ 2 - 4
tests/api_client/test_sync_api_provider.py

@@ -34,7 +34,7 @@ def api_provider(tenant, response) -> SyncApiProvider:
     with mock.patch(MODULE + ".PoolManager"):
         api_provider = SyncApiProvider(
             url="http://testserver/foo/",
-            fetch_token=lambda a, b: f"tenant-{b}",
+            fetch_token=lambda: {"Authorization": f"Bearer tenant-{ctx.tenant.id}"},
         )
         api_provider._pool.request.return_value = response
         yield api_provider
@@ -136,9 +136,7 @@ def test_error_response(api_provider: SyncApiProvider, response, status):
 
 @mock.patch(MODULE + ".PoolManager", new=mock.Mock())
 def test_no_token(response, tenant):
-    api_provider = SyncApiProvider(
-        url="http://testserver/foo/", fetch_token=lambda a, b: None
-    )
+    api_provider = SyncApiProvider(url="http://testserver/foo/", fetch_token=lambda: {})
     api_provider._pool.request.return_value = response
     api_provider.request("GET", "")
     assert api_provider._pool.request.call_args[1]["headers"] == {}

+ 133 - 0
tests/oauth2/test_client_credentials.py

@@ -0,0 +1,133 @@
+import base64
+import json
+import time
+from unittest import mock
+
+import pytest
+
+from clean_python.oauth2.client_credentials import CCTokenGateway
+from clean_python.oauth2.client_credentials import is_token_usable
+from clean_python.oauth2.client_credentials import OAuth2CCSettings
+from clean_python.oauth2.client_credentials import SyncCCTokenGateway
+
+SECRET_KEY = "abcd1234"
+MODULE = "clean_python.oauth2.client_credentials"
+
+
+def get_token(claims: dict, expires_in: int = 3600) -> str:
+    claims["exp"] = int(time.time()) + expires_in
+    payload = base64.b64encode(json.dumps(claims).encode()).decode()
+    return f"header.{payload}.signature"
+
+
+@pytest.fixture
+def settings() -> OAuth2CCSettings:
+    return OAuth2CCSettings(
+        client_id="cid",
+        client_secret="secret",
+        token_url="https://authserver/token",
+        scope="all",
+    )
+
+
+@pytest.fixture
+def gateway(settings) -> CCTokenGateway:
+    with mock.patch(MODULE + ".ApiProvider", autospec=True):
+        yield CCTokenGateway(settings)
+
+
+@pytest.fixture
+def sync_gateway(settings) -> SyncCCTokenGateway:
+    with mock.patch(MODULE + ".SyncApiProvider", autospec=True):
+        yield SyncCCTokenGateway(settings)
+
+
+@pytest.mark.parametrize(
+    "expires_in,leeway,expected",
+    [
+        (3600, 0, True),
+        (-10, 0, False),
+        (60, 300, False),
+    ],
+)
+def test_is_token_usable(expires_in, leeway, expected):
+    token = get_token({"user": "foo"}, expires_in=expires_in)
+    assert is_token_usable(token, leeway) is expected
+
+
+async def test_fetch_token(gateway: CCTokenGateway):
+    gateway.provider.request.return_value = {"access_token": "foo"}
+
+    token = await gateway._fetch_token()
+
+    assert token == "foo"
+
+    gateway.provider.request.assert_awaited_once_with(
+        method="POST",
+        path="",
+        fields={"grant_type": "client_credentials", "scope": "all"},
+        timeout=1.0,
+    )
+
+
+async def test_fetch_token_cache(gateway: CCTokenGateway):
+    # empty cache: provider gets called
+    token = get_token({})
+    gateway.provider.request.return_value = {"access_token": token}
+    actual = await gateway.fetch_token()
+    assert actual == token
+    assert gateway.provider.request.called
+
+    gateway.provider.request.reset_mock()
+
+    # cache is filled: provider is not called
+    actual = await gateway.fetch_token()
+    assert actual == token
+    assert not gateway.provider.request.called
+
+    gateway.provider.request.reset_mock()
+
+    # token is not usable so it is refreshed:
+    with mock.patch(MODULE + ".is_token_usable", side_effect=(False, True)):
+        actual = await gateway.fetch_token()
+        assert actual == token
+        assert gateway.provider.request.called
+
+
+def test_fetch_token_sync(sync_gateway: SyncCCTokenGateway):
+    sync_gateway.provider.request.return_value = {"access_token": "foo"}
+
+    token = sync_gateway._fetch_token()
+
+    assert token == "foo"
+
+    sync_gateway.provider.request.assert_called_once_with(
+        method="POST",
+        path="",
+        fields={"grant_type": "client_credentials", "scope": "all"},
+        timeout=1.0,
+    )
+
+
+def test_fetch_token_sync_cache(sync_gateway: SyncCCTokenGateway):
+    # empty cache: provider gets called
+    token = get_token({})
+    sync_gateway.provider.request.return_value = {"access_token": token}
+    actual = sync_gateway.fetch_token()
+    assert actual == token
+    assert sync_gateway.provider.request.called
+
+    sync_gateway.provider.request.reset_mock()
+
+    # cache is filled: provider is not called
+    actual = sync_gateway.fetch_token()
+    assert actual == token
+    assert not sync_gateway.provider.request.called
+
+    sync_gateway.provider.request.reset_mock()
+
+    # token is not usable so it is refreshed:
+    with mock.patch(MODULE + ".is_token_usable", side_effect=(False, True)):
+        actual = sync_gateway.fetch_token()
+        assert actual == token
+        assert sync_gateway.provider.request.called