"""Tests for a multitenant ``SQLGateway``.

Each test drives a gateway backed by ``FakeSQLDatabase`` and asserts on the
SQL text recorded in ``provider.queries``; in particular that the tenant id
from ``ctx.tenant`` is added to every statement.
"""
import pytest
from sqlalchemy import Column
from sqlalchemy import DateTime
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import Table
from sqlalchemy import Text

from clean_python import ctx
from clean_python import Filter
from clean_python import Tenant
from clean_python.sql import SQLGateway
from clean_python.sql.testing import assert_query_equal
from clean_python.sql.testing import FakeSQLDatabase

# Table under test; the "tenant" column is what makes it usable with
# ``multitenant=True`` (see test_missing_tenant_column for the negative case).
writer = Table(
    "writer",
    MetaData(),
    Column("id", Integer, primary_key=True, autoincrement=True),
    Column("value", Text, nullable=False),
    Column("updated_at", DateTime(timezone=True), nullable=False),
    Column("tenant", Integer, nullable=False),
)


# Column list expected in the SELECT / RETURNING clauses asserted below.
ALL_FIELDS = "writer.id, writer.value, writer.updated_at, writer.tenant"


class TstSQLGateway(SQLGateway, table=writer, multitenant=True):
    """Multitenant gateway over the ``writer`` table used by all tests."""


@pytest.fixture
def sql_gateway():
    """Return a fresh ``TstSQLGateway`` backed by a ``FakeSQLDatabase``."""
    return TstSQLGateway(FakeSQLDatabase())


@pytest.fixture
def tenant():
    """Set ``ctx.tenant`` for the duration of one test and yield it.

    The previous value is restored on teardown so the tenant does not leak
    into tests that rely on no tenant being set (e.g. ``test_no_tenant``),
    regardless of test execution order.
    """
    previous = ctx.tenant
    ctx.tenant = Tenant(id=2, name="foo")
    try:
        yield ctx.tenant
    finally:
        ctx.tenant = previous


async def test_no_tenant(sql_gateway):
    """Without a tenant in the context, filtering raises and emits no SQL."""
    with pytest.raises(RuntimeError):
        await sql_gateway.filter([])
    assert len(sql_gateway.provider.queries) == 0


async def test_missing_tenant_column():
    """Declaring a multitenant gateway on a table without "tenant" fails."""
    table = Table(
        "notenant",
        MetaData(),
        Column("id", Integer, primary_key=True, autoincrement=True),
    )

    # The error is expected at class-definition time, not at instantiation.
    with pytest.raises(ValueError):

        class Foo(SQLGateway, table=table, multitenant=True):
            pass


async def test_filter(sql_gateway, tenant):
    """``filter`` returns the provider result and adds the tenant clause."""
    sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
    assert await sql_gateway.filter([Filter(field="id", values=[1])]) == [
        {"id": 2, "value": "foo"}
    ]
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        f"SELECT {ALL_FIELDS} FROM writer WHERE writer.id = 1 AND writer.tenant = {tenant.id}",
    )


async def test_count(sql_gateway, tenant):
    """``count`` returns the "count" value and adds the tenant clause."""
    sql_gateway.provider.result.return_value = [{"count": 4}]
    assert await sql_gateway.count([Filter(field="id", values=[1])]) == 4
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        f"SELECT count(*) AS count FROM writer WHERE writer.id = 1 AND writer.tenant = {tenant.id}",
    )


async def test_add(sql_gateway, tenant):
    """``add`` inserts the tenant id even though the input omits it."""
    records = [{"id": 2, "value": "foo", "tenant": tenant.id}]
    sql_gateway.provider.result.return_value = records
    assert await sql_gateway.add({"value": "foo"}) == records[0]
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        (
            f"INSERT INTO writer (value, tenant) VALUES ('foo', {tenant.id}) RETURNING {ALL_FIELDS}"
        ),
    )


async def test_update(sql_gateway, tenant):
    """``update`` sets the tenant and restricts the WHERE clause to it."""
    records = [{"id": 2, "value": "foo", "tenant": tenant.id}]
    sql_gateway.provider.result.return_value = records
    assert await sql_gateway.update({"id": 2, "value": "foo"}) == records[0]
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        (
            f"UPDATE writer SET id=2, value='foo', tenant={tenant.id} "
            f"WHERE writer.id = 2 AND writer.tenant = {tenant.id} "
            f"RETURNING {ALL_FIELDS}"
        ),
    )


async def test_remove(sql_gateway, tenant):
    """``remove`` returns True and deletes only within the current tenant."""
    sql_gateway.provider.result.return_value = [{"id": 2}]
    assert (await sql_gateway.remove(2)) is True
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        (
            f"DELETE FROM writer WHERE writer.id = 2 AND writer.tenant = {tenant.id} "
            f"RETURNING writer.id"
        ),
    )


async def test_upsert(sql_gateway, tenant):
    """``upsert`` inserts with the tenant and conflicts on (id, tenant)."""
    record = {"id": 2, "value": "foo", "tenant": tenant.id}
    sql_gateway.provider.result.return_value = [record]
    assert await sql_gateway.upsert({"id": 2, "value": "foo"}) == record
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        (
            f"INSERT INTO writer (id, value, tenant) VALUES (2, 'foo', {tenant.id}) "
            f"ON CONFLICT (id, tenant) DO UPDATE SET "
            f"id = %(param_1)s, value = %(param_2)s, tenant = %(param_3)s "
            f"RETURNING {ALL_FIELDS}"
        ),
    )


async def test_update_transactional(sql_gateway, tenant):
    """``update_transactional`` selects FOR UPDATE, then updates, per tenant."""
    existing = {"id": 2, "value": "foo", "tenant": tenant.id}
    expected = {"id": 2, "value": "bar", "tenant": tenant.id}
    # First provider call answers the SELECT, second one the UPDATE.
    sql_gateway.provider.result.side_effect = ([existing], [expected])
    actual = await sql_gateway.update_transactional(
        2, lambda x: {"id": x["id"], "value": "bar"}
    )
    assert actual == expected

    # Exactly one recorded group holding both statements
    # (presumably one transaction -- confirm against FakeSQLDatabase).
    (queries,) = sql_gateway.provider.queries
    assert len(queries) == 2
    assert_query_equal(
        queries[0],
        (
            f"SELECT {ALL_FIELDS} FROM writer WHERE writer.id = 2 "
            f"AND writer.tenant = {tenant.id} FOR UPDATE"
        ),
    )
    assert_query_equal(
        queries[1],
        (
            f"UPDATE writer SET id=2, value='bar', tenant={tenant.id} "
            f"WHERE writer.id = 2 AND writer.tenant = {tenant.id} RETURNING {ALL_FIELDS}"
        ),
    )


async def test_exists(sql_gateway, tenant):
    """``exists`` returns True and limits the lookup to the current tenant."""
    sql_gateway.provider.result.return_value = [{"exists": True}]
    assert await sql_gateway.exists([Filter(field="id", values=[1])]) is True
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        (
            f"SELECT true AS exists FROM writer "
            f"WHERE writer.id = 1 AND writer.tenant = {tenant.id} LIMIT 1"
        ),
    )