|
@@ -0,0 +1,179 @@
|
|
|
+import pytest
|
|
|
+from sqlalchemy import Column
|
|
|
+from sqlalchemy import DateTime
|
|
|
+from sqlalchemy import Integer
|
|
|
+from sqlalchemy import MetaData
|
|
|
+from sqlalchemy import Table
|
|
|
+from sqlalchemy import Text
|
|
|
+
|
|
|
+from clean_python import ctx
|
|
|
+from clean_python import Filter
|
|
|
+from clean_python import Tenant
|
|
|
+from clean_python.sql import SQLGateway
|
|
|
+from clean_python.sql.testing import assert_query_equal
|
|
|
+from clean_python.sql.testing import FakeSQLDatabase
|
|
|
+
|
|
|
+writer = Table(
|
|
|
+ "writer",
|
|
|
+ MetaData(),
|
|
|
+ Column("id", Integer, primary_key=True, autoincrement=True),
|
|
|
+ Column("value", Text, nullable=False),
|
|
|
+ Column("updated_at", DateTime(timezone=True), nullable=False),
|
|
|
+ Column("tenant", Integer, nullable=False),
|
|
|
+)
|
|
|
+
|
|
|
+
|
|
|
+ALL_FIELDS = "writer.id, writer.value, writer.updated_at, writer.tenant"
|
|
|
+
|
|
|
+
|
|
|
+class TstSQLGateway(SQLGateway, table=writer, multitenant=True):
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
+@pytest.fixture
|
|
|
+def sql_gateway():
|
|
|
+ return TstSQLGateway(FakeSQLDatabase())
|
|
|
+
|
|
|
+
|
|
|
+@pytest.fixture
|
|
|
+def tenant():
|
|
|
+ ctx.tenant = Tenant(id=2, name="foo")
|
|
|
+ return ctx.tenant
|
|
|
+
|
|
|
+
|
|
|
+async def test_no_tenant(sql_gateway):
|
|
|
+ with pytest.raises(RuntimeError):
|
|
|
+ await sql_gateway.filter([])
|
|
|
+ assert len(sql_gateway.provider.queries) == 0
|
|
|
+
|
|
|
+
|
|
|
+async def test_missing_tenant_column():
|
|
|
+ table = Table(
|
|
|
+ "notenant",
|
|
|
+ MetaData(),
|
|
|
+ Column("id", Integer, primary_key=True, autoincrement=True),
|
|
|
+ )
|
|
|
+
|
|
|
+ with pytest.raises(ValueError):
|
|
|
+
|
|
|
+ class Foo(SQLGateway, table=table, multitenant=True):
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
+async def test_filter(sql_gateway, tenant):
|
|
|
+ sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
|
|
|
+ assert await sql_gateway.filter([Filter(field="id", values=[1])]) == [
|
|
|
+ {"id": 2, "value": "foo"}
|
|
|
+ ]
|
|
|
+ assert len(sql_gateway.provider.queries) == 1
|
|
|
+ assert_query_equal(
|
|
|
+ sql_gateway.provider.queries[0][0],
|
|
|
+ f"SELECT {ALL_FIELDS} FROM writer WHERE writer.id = 1 AND writer.tenant = {tenant.id}",
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+async def test_count(sql_gateway, tenant):
|
|
|
+ sql_gateway.provider.result.return_value = [{"count": 4}]
|
|
|
+ assert await sql_gateway.count([Filter(field="id", values=[1])]) == 4
|
|
|
+ assert len(sql_gateway.provider.queries) == 1
|
|
|
+ assert_query_equal(
|
|
|
+ sql_gateway.provider.queries[0][0],
|
|
|
+ f"SELECT count(*) AS count FROM writer WHERE writer.id = 1 AND writer.tenant = {tenant.id}",
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+async def test_add(sql_gateway, tenant):
|
|
|
+ records = [{"id": 2, "value": "foo", "tenant": tenant.id}]
|
|
|
+ sql_gateway.provider.result.return_value = records
|
|
|
+ assert await sql_gateway.add({"value": "foo"}) == records[0]
|
|
|
+ assert len(sql_gateway.provider.queries) == 1
|
|
|
+ assert_query_equal(
|
|
|
+ sql_gateway.provider.queries[0][0],
|
|
|
+ (
|
|
|
+ f"INSERT INTO writer (value, tenant) VALUES ('foo', {tenant.id}) RETURNING {ALL_FIELDS}"
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+async def test_update(sql_gateway, tenant):
|
|
|
+ records = [{"id": 2, "value": "foo", "tenant": tenant.id}]
|
|
|
+ sql_gateway.provider.result.return_value = records
|
|
|
+ assert await sql_gateway.update({"id": 2, "value": "foo"}) == records[0]
|
|
|
+ assert len(sql_gateway.provider.queries) == 1
|
|
|
+ assert_query_equal(
|
|
|
+ sql_gateway.provider.queries[0][0],
|
|
|
+ (
|
|
|
+ f"UPDATE writer SET id=2, value='foo', tenant={tenant.id} "
|
|
|
+ f"WHERE writer.id = 2 AND writer.tenant = {tenant.id} "
|
|
|
+ f"RETURNING {ALL_FIELDS}"
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+async def test_remove(sql_gateway, tenant):
|
|
|
+ sql_gateway.provider.result.return_value = [{"id": 2}]
|
|
|
+ assert (await sql_gateway.remove(2)) is True
|
|
|
+ assert len(sql_gateway.provider.queries) == 1
|
|
|
+ assert_query_equal(
|
|
|
+ sql_gateway.provider.queries[0][0],
|
|
|
+ (
|
|
|
+ f"DELETE FROM writer WHERE writer.id = 2 AND writer.tenant = {tenant.id} "
|
|
|
+ f"RETURNING writer.id"
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+async def test_upsert(sql_gateway, tenant):
|
|
|
+ record = {"id": 2, "value": "foo", "tenant": tenant.id}
|
|
|
+ sql_gateway.provider.result.return_value = [record]
|
|
|
+ assert await sql_gateway.upsert({"id": 2, "value": "foo"}) == record
|
|
|
+ assert len(sql_gateway.provider.queries) == 1
|
|
|
+ assert_query_equal(
|
|
|
+ sql_gateway.provider.queries[0][0],
|
|
|
+ (
|
|
|
+ f"INSERT INTO writer (id, value, tenant) VALUES (2, 'foo', {tenant.id}) "
|
|
|
+ f"ON CONFLICT (id, tenant) DO UPDATE SET "
|
|
|
+ f"id = %(param_1)s, value = %(param_2)s, tenant = %(param_3)s "
|
|
|
+ f"RETURNING {ALL_FIELDS}"
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+async def test_update_transactional(sql_gateway, tenant):
|
|
|
+ existing = {"id": 2, "value": "foo", "tenant": tenant.id}
|
|
|
+ expected = {"id": 2, "value": "bar", "tenant": tenant.id}
|
|
|
+ sql_gateway.provider.result.side_effect = ([existing], [expected])
|
|
|
+ actual = await sql_gateway.update_transactional(
|
|
|
+ 2, lambda x: {"id": x["id"], "value": "bar"}
|
|
|
+ )
|
|
|
+ assert actual == expected
|
|
|
+
|
|
|
+ (queries,) = sql_gateway.provider.queries
|
|
|
+ assert len(queries) == 2
|
|
|
+ assert_query_equal(
|
|
|
+ queries[0],
|
|
|
+ (
|
|
|
+ f"SELECT {ALL_FIELDS} FROM writer WHERE writer.id = 2 "
|
|
|
+ f"AND writer.tenant = {tenant.id} FOR UPDATE"
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ assert_query_equal(
|
|
|
+ queries[1],
|
|
|
+ (
|
|
|
+ f"UPDATE writer SET id=2, value='bar', tenant={tenant.id} "
|
|
|
+ f"WHERE writer.id = 2 AND writer.tenant = {tenant.id} RETURNING {ALL_FIELDS}"
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+async def test_exists(sql_gateway, tenant):
|
|
|
+ sql_gateway.provider.result.return_value = [{"exists": True}]
|
|
|
+ assert await sql_gateway.exists([Filter(field="id", values=[1])]) is True
|
|
|
+ assert len(sql_gateway.provider.queries) == 1
|
|
|
+ assert_query_equal(
|
|
|
+ sql_gateway.provider.queries[0][0],
|
|
|
+ (
|
|
|
+ f"SELECT true AS exists FROM writer "
|
|
|
+ f"WHERE writer.id = 1 AND writer.tenant = {tenant.id} LIMIT 1"
|
|
|
+ ),
|
|
|
+ )
|