test_client_credentials.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import base64
  2. import json
  3. import time
  4. from unittest import mock
  5. import pytest
  6. from clean_python.oauth2.client_credentials import CCTokenGateway
  7. from clean_python.oauth2.client_credentials import is_token_usable
  8. from clean_python.oauth2.client_credentials import OAuth2CCSettings
  9. from clean_python.oauth2.client_credentials import SyncCCTokenGateway
# NOTE(review): SECRET_KEY is not referenced in the visible part of this
# file -- confirm whether it is still needed.
SECRET_KEY = "abcd1234"
# Dotted path of the module under test; the fixtures and cache tests in this
# file prepend it to attribute names to build ``mock.patch`` targets.
MODULE = "clean_python.oauth2.client_credentials"
  12. def get_token(claims: dict, expires_in: int = 3600) -> str:
  13. claims["exp"] = int(time.time()) + expires_in
  14. payload = base64.b64encode(json.dumps(claims).encode()).decode()
  15. return f"header.{payload}.signature"
  16. @pytest.fixture
  17. def settings() -> OAuth2CCSettings:
  18. return OAuth2CCSettings(
  19. client_id="cid",
  20. client_secret="secret",
  21. token_url="https://authserver/token",
  22. scope="all",
  23. )
  24. @pytest.fixture
  25. def gateway(settings) -> CCTokenGateway:
  26. with mock.patch(MODULE + ".ApiProvider", autospec=True):
  27. yield CCTokenGateway(settings)
  28. @pytest.fixture
  29. def sync_gateway(settings) -> SyncCCTokenGateway:
  30. with mock.patch(MODULE + ".SyncApiProvider", autospec=True):
  31. yield SyncCCTokenGateway(settings)
  32. @pytest.mark.parametrize(
  33. "expires_in,leeway,expected",
  34. [
  35. (3600, 0, True),
  36. (-10, 0, False),
  37. (60, 300, False),
  38. ],
  39. )
  40. def test_is_token_usable(expires_in, leeway, expected):
  41. token = get_token({"user": "foo"}, expires_in=expires_in)
  42. assert is_token_usable(token, leeway) is expected
  43. async def test_fetch_token(gateway: CCTokenGateway):
  44. gateway.provider.request.return_value = {"access_token": "foo"}
  45. token = await gateway._fetch_token()
  46. assert token == "foo"
  47. gateway.provider.request.assert_awaited_once_with(
  48. method="POST",
  49. path="",
  50. fields={"grant_type": "client_credentials", "scope": "all"},
  51. timeout=1.0,
  52. )
  53. async def test_fetch_token_cache(gateway: CCTokenGateway):
  54. # empty cache: provider gets called
  55. token = get_token({})
  56. gateway.provider.request.return_value = {"access_token": token}
  57. actual = await gateway.fetch_token()
  58. assert actual == token
  59. assert gateway.provider.request.called
  60. gateway.provider.request.reset_mock()
  61. # cache is filled: provider is not called
  62. actual = await gateway.fetch_token()
  63. assert actual == token
  64. assert not gateway.provider.request.called
  65. gateway.provider.request.reset_mock()
  66. # token is not usable so it is refreshed:
  67. with mock.patch(MODULE + ".is_token_usable", side_effect=(False, True)):
  68. actual = await gateway.fetch_token()
  69. assert actual == token
  70. assert gateway.provider.request.called
  71. def test_fetch_token_sync(sync_gateway: SyncCCTokenGateway):
  72. sync_gateway.provider.request.return_value = {"access_token": "foo"}
  73. token = sync_gateway._fetch_token()
  74. assert token == "foo"
  75. sync_gateway.provider.request.assert_called_once_with(
  76. method="POST",
  77. path="",
  78. fields={"grant_type": "client_credentials", "scope": "all"},
  79. timeout=1.0,
  80. )
  81. def test_fetch_token_sync_cache(sync_gateway: SyncCCTokenGateway):
  82. # empty cache: provider gets called
  83. token = get_token({})
  84. sync_gateway.provider.request.return_value = {"access_token": token}
  85. actual = sync_gateway.fetch_token()
  86. assert actual == token
  87. assert sync_gateway.provider.request.called
  88. sync_gateway.provider.request.reset_mock()
  89. # cache is filled: provider is not called
  90. actual = sync_gateway.fetch_token()
  91. assert actual == token
  92. assert not sync_gateway.provider.request.called
  93. sync_gateway.provider.request.reset_mock()
  94. # token is not usable so it is refreshed:
  95. with mock.patch(MODULE + ".is_token_usable", side_effect=(False, True)):
  96. actual = sync_gateway.fetch_token()
  97. assert actual == token
  98. assert sync_gateway.provider.request.called
  99. async def test_fetch_headers(gateway: CCTokenGateway):
  100. gateway.provider.request.return_value = {"access_token": "foo"}
  101. await gateway.fetch_headers() == {"Authorization": "Bearer foo"}
  102. def test_fetch_headers_sync(sync_gateway: SyncCCTokenGateway):
  103. sync_gateway.provider.request.return_value = {"access_token": "foo"}
  104. assert sync_gateway.fetch_headers() == {"Authorization": "Bearer foo"}