# test_sql_gateway_multitenant.py (5.7 KB)

  1. import pytest
  2. from sqlalchemy import Column
  3. from sqlalchemy import DateTime
  4. from sqlalchemy import Integer
  5. from sqlalchemy import MetaData
  6. from sqlalchemy import Table
  7. from sqlalchemy import Text
  8. from clean_python import ctx
  9. from clean_python import Filter
  10. from clean_python import Tenant
  11. from clean_python.sql import SQLGateway
  12. from clean_python.sql.testing import assert_query_equal
  13. from clean_python.sql.testing import FakeSQLDatabase
  14. writer = Table(
  15. "writer",
  16. MetaData(),
  17. Column("id", Integer, primary_key=True, autoincrement=True),
  18. Column("value", Text, nullable=False),
  19. Column("updated_at", DateTime(timezone=True), nullable=False),
  20. Column("tenant", Integer, nullable=False),
  21. )
  22. ALL_FIELDS = "writer.id, writer.value, writer.updated_at, writer.tenant"
  23. class TstSQLGateway(SQLGateway, table=writer, multitenant=True):
  24. pass
  25. @pytest.fixture
  26. def sql_gateway():
  27. return TstSQLGateway(FakeSQLDatabase())
  28. @pytest.fixture
  29. def tenant():
  30. ctx.tenant = Tenant(id=2, name="foo")
  31. return ctx.tenant
  32. async def test_no_tenant(sql_gateway):
  33. with pytest.raises(RuntimeError):
  34. await sql_gateway.filter([])
  35. assert len(sql_gateway.provider.queries) == 0
  36. async def test_missing_tenant_column():
  37. table = Table(
  38. "notenant",
  39. MetaData(),
  40. Column("id", Integer, primary_key=True, autoincrement=True),
  41. )
  42. with pytest.raises(ValueError):
  43. class Foo(SQLGateway, table=table, multitenant=True):
  44. pass
  45. async def test_filter(sql_gateway, tenant):
  46. sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
  47. assert await sql_gateway.filter([Filter(field="id", values=[1])]) == [
  48. {"id": 2, "value": "foo"}
  49. ]
  50. assert len(sql_gateway.provider.queries) == 1
  51. assert_query_equal(
  52. sql_gateway.provider.queries[0][0],
  53. f"SELECT {ALL_FIELDS} FROM writer WHERE writer.id = 1 AND writer.tenant = {tenant.id}",
  54. )
  55. async def test_count(sql_gateway, tenant):
  56. sql_gateway.provider.result.return_value = [{"count": 4}]
  57. assert await sql_gateway.count([Filter(field="id", values=[1])]) == 4
  58. assert len(sql_gateway.provider.queries) == 1
  59. assert_query_equal(
  60. sql_gateway.provider.queries[0][0],
  61. f"SELECT count(*) AS count FROM writer WHERE writer.id = 1 AND writer.tenant = {tenant.id}",
  62. )
  63. async def test_add(sql_gateway, tenant):
  64. records = [{"id": 2, "value": "foo", "tenant": tenant.id}]
  65. sql_gateway.provider.result.return_value = records
  66. assert await sql_gateway.add({"value": "foo"}) == records[0]
  67. assert len(sql_gateway.provider.queries) == 1
  68. assert_query_equal(
  69. sql_gateway.provider.queries[0][0],
  70. (
  71. f"INSERT INTO writer (value, tenant) VALUES ('foo', {tenant.id}) RETURNING {ALL_FIELDS}"
  72. ),
  73. )
  74. async def test_update(sql_gateway, tenant):
  75. records = [{"id": 2, "value": "foo", "tenant": tenant.id}]
  76. sql_gateway.provider.result.return_value = records
  77. assert await sql_gateway.update({"id": 2, "value": "foo"}) == records[0]
  78. assert len(sql_gateway.provider.queries) == 1
  79. assert_query_equal(
  80. sql_gateway.provider.queries[0][0],
  81. (
  82. f"UPDATE writer SET id=2, value='foo', tenant={tenant.id} "
  83. f"WHERE writer.id = 2 AND writer.tenant = {tenant.id} "
  84. f"RETURNING {ALL_FIELDS}"
  85. ),
  86. )
  87. async def test_remove(sql_gateway, tenant):
  88. sql_gateway.provider.result.return_value = [{"id": 2}]
  89. assert (await sql_gateway.remove(2)) is True
  90. assert len(sql_gateway.provider.queries) == 1
  91. assert_query_equal(
  92. sql_gateway.provider.queries[0][0],
  93. (
  94. f"DELETE FROM writer WHERE writer.id = 2 AND writer.tenant = {tenant.id} "
  95. f"RETURNING writer.id"
  96. ),
  97. )
  98. async def test_upsert(sql_gateway, tenant):
  99. record = {"id": 2, "value": "foo", "tenant": tenant.id}
  100. sql_gateway.provider.result.return_value = [record]
  101. assert await sql_gateway.upsert({"id": 2, "value": "foo"}) == record
  102. assert len(sql_gateway.provider.queries) == 1
  103. assert_query_equal(
  104. sql_gateway.provider.queries[0][0],
  105. (
  106. f"INSERT INTO writer (id, value, tenant) VALUES (2, 'foo', {tenant.id}) "
  107. f"ON CONFLICT (id, tenant) DO UPDATE SET "
  108. f"id = %(param_1)s, value = %(param_2)s, tenant = %(param_3)s "
  109. f"RETURNING {ALL_FIELDS}"
  110. ),
  111. )
  112. async def test_update_transactional(sql_gateway, tenant):
  113. existing = {"id": 2, "value": "foo", "tenant": tenant.id}
  114. expected = {"id": 2, "value": "bar", "tenant": tenant.id}
  115. sql_gateway.provider.result.side_effect = ([existing], [expected])
  116. actual = await sql_gateway.update_transactional(
  117. 2, lambda x: {"id": x["id"], "value": "bar"}
  118. )
  119. assert actual == expected
  120. (queries,) = sql_gateway.provider.queries
  121. assert len(queries) == 2
  122. assert_query_equal(
  123. queries[0],
  124. (
  125. f"SELECT {ALL_FIELDS} FROM writer WHERE writer.id = 2 "
  126. f"AND writer.tenant = {tenant.id} FOR UPDATE"
  127. ),
  128. )
  129. assert_query_equal(
  130. queries[1],
  131. (
  132. f"UPDATE writer SET id=2, value='bar', tenant={tenant.id} "
  133. f"WHERE writer.id = 2 AND writer.tenant = {tenant.id} RETURNING {ALL_FIELDS}"
  134. ),
  135. )
  136. async def test_exists(sql_gateway, tenant):
  137. sql_gateway.provider.result.return_value = [{"exists": True}]
  138. assert await sql_gateway.exists([Filter(field="id", values=[1])]) is True
  139. assert len(sql_gateway.provider.queries) == 1
  140. assert_query_equal(
  141. sql_gateway.provider.queries[0][0],
  142. (
  143. f"SELECT true AS exists FROM writer "
  144. f"WHERE writer.id = 1 AND writer.tenant = {tenant.id} LIMIT 1"
  145. ),
  146. )