"""Tests for ``SQLGateway`` query generation, run against ``FakeSQLDatabase``.

Each test primes ``provider.result`` with canned rows, calls a gateway method,
and checks the SQL that was recorded in ``provider.queries``. Judging from the
indexing used below (``queries[0][0]``, ``(queries,) = ...``), every entry of
``provider.queries`` is one execution batch holding a sequence of queries.
"""
from datetime import datetime
from datetime import timezone
from unittest import mock

import pytest
from sqlalchemy import Column
from sqlalchemy import DateTime
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import Table
from sqlalchemy import Text

from clean_python import Conflict
from clean_python import DoesNotExist
from clean_python import Filter
from clean_python import PageOptions
from clean_python.sql import SQLGateway
from clean_python.sql.testing import assert_query_equal
from clean_python.sql.testing import FakeSQLDatabase

writer = Table(
    "writer",
    MetaData(),
    Column("id", Integer, primary_key=True, autoincrement=True),
    Column("value", Text, nullable=False),
    Column("updated_at", DateTime(timezone=True), nullable=False),
)


book = Table(
    "book",
    MetaData(),
    Column("id", Integer, primary_key=True, autoincrement=True),
    Column("title", Text, nullable=False),
    Column(
        "writer_id",
        Integer,
        ForeignKey("writer.id", ondelete="CASCADE", name="book_writer_id_fkey"),
        nullable=False,
    ),
)


# Column lists the gateways emit in SELECT / RETURNING clauses.
ALL_FIELDS = "writer.id, writer.value, writer.updated_at"
BOOK_FIELDS = "book.id, book.title, book.writer_id"

# (filters, expected SQL suffix) pairs shared by the filter/count/exists tests.
# A filter with no values, or on a column the table does not have, renders
# as "WHERE false"; multiple filters are AND-ed together.
FILTER_CASES = [
    ([], ""),
    ([Filter(field="value", values=[])], " WHERE false"),
    ([Filter(field="value", values=["foo"])], " WHERE writer.value = 'foo'"),
    (
        [Filter(field="value", values=["foo", "bar"])],
        " WHERE writer.value IN ('foo', 'bar')",
    ),
    ([Filter(field="nonexisting", values=["foo"])], " WHERE false"),
    (
        [Filter(field="id", values=[1]), Filter(field="value", values=["foo"])],
        " WHERE writer.id = 1 AND writer.value = 'foo'",
    ),
]


class TstSQLGateway(SQLGateway, table=writer):
    pass


class TstRelatedSQLGateway(SQLGateway, table=book):
    pass


@pytest.fixture
def sql_gateway():
    """Gateway over the ``writer`` table backed by a fake database."""
    return TstSQLGateway(FakeSQLDatabase())


@pytest.fixture
def related_sql_gateway():
    """Gateway over the ``book`` table backed by a fake database."""
    return TstRelatedSQLGateway(FakeSQLDatabase())


@pytest.mark.parametrize("filters,sql", FILTER_CASES)
async def test_filter(sql_gateway, filters, sql):
    """``filter`` renders the filters as the expected WHERE clause."""
    sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
    assert await sql_gateway.filter(filters) == [{"id": 2, "value": "foo"}]
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        f"SELECT {ALL_FIELDS} FROM writer{sql}",
    )


@pytest.mark.parametrize(
    "page_options,sql",
    [
        (None, ""),
        (
            PageOptions(limit=5, order_by="id"),
            " ORDER BY writer.id ASC LIMIT 5 OFFSET 0",
        ),
        (
            PageOptions(limit=5, offset=2, order_by="id", ascending=False),
            " ORDER BY writer.id DESC LIMIT 5 OFFSET 2",
        ),
    ],
)
async def test_filter_with_pagination(sql_gateway, page_options, sql):
    """``PageOptions`` map to ORDER BY / LIMIT / OFFSET (offset defaults to 0)."""
    sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
    assert await sql_gateway.filter([], params=page_options) == [
        {"id": 2, "value": "foo"}
    ]
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        f"SELECT {ALL_FIELDS} FROM writer{sql}",
    )


async def test_filter_with_pagination_and_filter(sql_gateway):
    """The WHERE clause comes before the pagination clauses."""
    sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
    assert await sql_gateway.filter(
        [Filter(field="value", values=["foo"])],
        params=PageOptions(limit=5, order_by="id"),
    ) == [{"id": 2, "value": "foo"}]
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        (
            f"SELECT {ALL_FIELDS} FROM writer "
            "WHERE writer.value = 'foo' "
            "ORDER BY writer.id ASC LIMIT 5 OFFSET 0"
        ),
    )


@pytest.mark.parametrize("filters,sql", FILTER_CASES)
async def test_count(sql_gateway, filters, sql):
    """``count`` issues ``count(*)`` and returns the ``count`` column."""
    sql_gateway.provider.result.return_value = [{"count": 4}]
    assert await sql_gateway.count(filters) == 4
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        f"SELECT count(*) AS count FROM writer{sql}",
    )


@mock.patch.object(SQLGateway, "filter")
async def test_get(filter_m, sql_gateway):
    """``get`` delegates to ``filter`` on ``id`` and returns the first row."""
    filter_m.return_value = [{"id": 2, "value": "foo"}]
    assert await sql_gateway.get(2) == filter_m.return_value[0]
    assert len(sql_gateway.provider.queries) == 0
    filter_m.assert_awaited_once_with([Filter(field="id", values=[2])], params=None)


@mock.patch.object(SQLGateway, "filter")
async def test_get_does_not_exist(filter_m, sql_gateway):
    """``get`` returns None when ``filter`` yields no rows."""
    filter_m.return_value = []
    assert await sql_gateway.get(2) is None
    assert len(sql_gateway.provider.queries) == 0
    filter_m.assert_awaited_once_with([Filter(field="id", values=[2])], params=None)


@pytest.mark.parametrize(
    "record,sql",
    [
        ({}, "DEFAULT VALUES"),
        ({"value": "foo"}, "(value) VALUES ('foo')"),
        ({"id": None, "value": "foo"}, "(value) VALUES ('foo')"),
        ({"id": 2, "value": "foo"}, "(id, value) VALUES (2, 'foo')"),
        ({"value": "foo", "nonexisting": 2}, "(value) VALUES ('foo')"),
    ],
)
async def test_add(sql_gateway, record, sql):
    """``add`` inserts known columns only and drops a None ``id``."""
    records = [{"id": 2, "value": "foo"}]
    sql_gateway.provider.result.return_value = records
    assert await sql_gateway.add(record) == records[0]
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        f"INSERT INTO writer {sql} RETURNING {ALL_FIELDS}",
    )


@pytest.mark.parametrize(
    "record,if_unmodified_since,sql",
    [
        (
            {"id": 2, "value": "foo"},
            None,
            "SET id=2, value='foo' WHERE writer.id = 2",
        ),
        ({"id": 2, "other": "foo"}, None, "SET id=2 WHERE writer.id = 2"),
        (
            {"id": 2, "value": "foo"},
            datetime(2010, 1, 1, tzinfo=timezone.utc),
            (
                "SET id=2, value='foo' WHERE writer.id = 2 "
                "AND writer.updated_at = '2010-01-01 00:00:00+00:00'"
            ),
        ),
    ],
)
async def test_update(sql_gateway, record, if_unmodified_since, sql):
    """``update`` sets known columns; ``if_unmodified_since`` adds an
    ``updated_at`` equality condition."""
    records = [{"id": 2, "value": "foo"}]
    sql_gateway.provider.result.return_value = records
    assert await sql_gateway.update(record, if_unmodified_since) == records[0]
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        f"UPDATE writer {sql} RETURNING {ALL_FIELDS}",
    )


async def test_update_does_not_exist(sql_gateway):
    """An UPDATE returning no rows raises DoesNotExist."""
    sql_gateway.provider.result.return_value = []
    with pytest.raises(DoesNotExist):
        await sql_gateway.update({"id": 2})
    assert len(sql_gateway.provider.queries) == 1


@mock.patch.object(SQLGateway, "get")
async def test_update_if_unmodified_since_does_not_exist(get_m, sql_gateway):
    """A conditional UPDATE that matches nothing, for a missing record,
    raises DoesNotExist."""
    get_m.return_value = None
    sql_gateway.provider.result.return_value = []
    with pytest.raises(DoesNotExist):
        await sql_gateway.update(
            {"id": 2}, if_unmodified_since=datetime(2010, 1, 1, tzinfo=timezone.utc)
        )
    assert len(sql_gateway.provider.queries) == 1
    get_m.assert_awaited_once_with(2)


@mock.patch.object(SQLGateway, "get")
async def test_update_if_unmodified_since_conflict(get_m, sql_gateway):
    """A conditional UPDATE that matches nothing, for an existing record,
    raises Conflict."""
    get_m.return_value = {"id": 2, "value": "foo"}
    sql_gateway.provider.result.return_value = []
    with pytest.raises(Conflict):
        await sql_gateway.update(
            {"id": 2}, if_unmodified_since=datetime(2010, 1, 1, tzinfo=timezone.utc)
        )
    assert len(sql_gateway.provider.queries) == 1
    get_m.assert_awaited_once_with(2)


async def test_remove(sql_gateway):
    """``remove`` returns True when the DELETE returned a row."""
    sql_gateway.provider.result.return_value = [{"id": 2}]
    assert (await sql_gateway.remove(2)) is True
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        "DELETE FROM writer WHERE writer.id = 2 RETURNING writer.id",
    )


async def test_remove_does_not_exist(sql_gateway):
    """``remove`` returns False when the DELETE returned no rows."""
    sql_gateway.provider.result.return_value = []
    assert (await sql_gateway.remove(2)) is False
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        "DELETE FROM writer WHERE writer.id = 2 RETURNING writer.id",
    )


async def test_upsert(sql_gateway):
    """``upsert`` with an id issues INSERT ... ON CONFLICT (id) DO UPDATE."""
    record = {"id": 2, "value": "foo"}
    sql_gateway.provider.result.return_value = [record]
    assert await sql_gateway.upsert(record) == record
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        (
            "INSERT INTO writer (id, value) VALUES (2, 'foo') "
            "ON CONFLICT (id) DO UPDATE SET "
            "id = %(param_1)s, value = %(param_2)s "
            f"RETURNING {ALL_FIELDS}"
        ),
    )


@mock.patch.object(SQLGateway, "add")
async def test_upsert_no_id(add_m, sql_gateway):
    """``upsert`` without an id delegates to ``add``."""
    add_m.return_value = {"id": 5, "value": "foo"}
    assert await sql_gateway.upsert({"value": "foo"}) == add_m.return_value
    add_m.assert_awaited_once_with({"value": "foo"})
    assert len(sql_gateway.provider.queries) == 0


async def test_get_related_one_to_many(related_sql_gateway: SQLGateway):
    """Related books are fetched in one query and attached per writer;
    writers without books get an empty list."""
    writers = [{"id": 2}, {"id": 3}]
    books = [
        {"id": 3, "title": "x", "writer_id": 2},
        {"id": 4, "title": "y", "writer_id": 2},
    ]
    related_sql_gateway.provider.result.return_value = books
    await related_sql_gateway._get_related_one_to_many(
        items=writers,
        field_name="books",
        fk_name="writer_id",
    )
    assert writers == [{"id": 2, "books": books}, {"id": 3, "books": []}]
    assert len(related_sql_gateway.provider.queries) == 1
    assert_query_equal(
        related_sql_gateway.provider.queries[0][0],
        f"SELECT {BOOK_FIELDS} FROM book WHERE book.writer_id IN (2, 3)",
    )


@pytest.mark.parametrize(
    "books,current_books,expected_queries,query_results",
    [
        # no change
        (
            [{"id": 3, "title": "x", "writer_id": 2}],
            [{"id": 3, "title": "x", "writer_id": 2}],
            [],
            [],
        ),
        # added a book (without an id)
        (
            [{"title": "x", "writer_id": 2}],
            [],
            [
                (
                    "INSERT INTO book (title, writer_id) VALUES ('x', 2) "
                    f"RETURNING {BOOK_FIELDS}"
                ),
            ],
            [[{"id": 3, "title": "x", "writer_id": 2}]],
        ),
        # added a book (with an id)
        (
            [{"id": 3, "title": "x", "writer_id": 2}],
            [],
            [
                (
                    "INSERT INTO book (id, title, writer_id) VALUES (3, 'x', 2) "
                    f"RETURNING {BOOK_FIELDS}"
                ),
            ],
            [[{"id": 3, "title": "x", "writer_id": 2}]],
        ),
        # updated a book
        (
            [{"id": 3, "title": "x", "writer_id": 2}],
            [{"id": 3, "title": "a", "writer_id": 2}],
            [
                (
                    "UPDATE book SET id=3, title='x', writer_id=2 "
                    f"WHERE book.id = 3 RETURNING {BOOK_FIELDS}"
                ),
            ],
            [[{"id": 3, "title": "x", "writer_id": 2}]],
        ),
        # replaced a book with a new one
        (
            [{"title": "x", "writer_id": 2}],
            [{"id": 15, "title": "a", "writer_id": 2}],
            [
                (
                    "INSERT INTO book (title, writer_id) VALUES ('x', 2) "
                    f"RETURNING {BOOK_FIELDS}"
                ),
                "DELETE FROM book WHERE book.id = 15 RETURNING book.id",
            ],
            [[{"id": 3, "title": "x", "writer_id": 2}], [{"id": 15}]],
        ),
    ],
)
async def test_set_related_one_to_many(
    related_sql_gateway: SQLGateway,
    books,
    current_books,
    expected_queries,
    query_results,
):
    """Syncing related books first SELECTs the current rows, then issues
    exactly the INSERT/UPDATE/DELETE queries needed."""
    writer = {"id": 2, "books": books}
    # First result answers the initial SELECT; the rest answer each write.
    related_sql_gateway.provider.result.side_effect = [current_books] + query_results
    result = writer.copy()
    await related_sql_gateway._set_related_one_to_many(
        item=writer,
        result=result,
        field_name="books",
        fk_name="writer_id",
    )
    assert result == {
        "id": 2,
        "books": [{"id": 3, "title": "x", "writer_id": 2}],
    }
    assert len(related_sql_gateway.provider.queries) == len(expected_queries) + 1
    assert_query_equal(
        related_sql_gateway.provider.queries[0][0],
        f"SELECT {BOOK_FIELDS} FROM book WHERE book.writer_id = 2",
    )
    # Lengths were checked above, so zip does not silently drop queries;
    # each write is expected to be its own single-query batch.
    for (actual_query,), expected_query in zip(
        related_sql_gateway.provider.queries[1:], expected_queries
    ):
        assert_query_equal(actual_query, expected_query)


async def test_update_transactional(sql_gateway):
    """``update_transactional`` locks the row with SELECT ... FOR UPDATE and
    then writes the transformed record, both within one batch."""
    existing = {"id": 2, "value": "foo"}
    expected = {"id": 2, "value": "bar"}
    sql_gateway.provider.result.side_effect = ([existing], [expected])
    actual = await sql_gateway.update_transactional(
        2, lambda x: {"id": x["id"], "value": "bar"}
    )
    assert actual == expected
    (queries,) = sql_gateway.provider.queries
    assert len(queries) == 2
    assert_query_equal(
        queries[0],
        f"SELECT {ALL_FIELDS} FROM writer WHERE writer.id = 2 FOR UPDATE",
    )
    assert_query_equal(
        queries[1],
        (
            "UPDATE writer SET id=2, value='bar' WHERE writer.id = 2 "
            f"RETURNING {ALL_FIELDS}"
        ),
    )


@pytest.mark.parametrize("filters,sql", FILTER_CASES)
async def test_exists(sql_gateway, filters, sql):
    """``exists`` issues a ``LIMIT 1`` probe and returns the ``exists`` column."""
    sql_gateway.provider.result.return_value = [{"exists": True}]
    assert await sql_gateway.exists(filters) is True
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        f"SELECT true AS exists FROM writer{sql} LIMIT 1",
    )