"""Tests for ``SQLGateway`` run against ``FakeSQLDatabase``.

Each test primes ``provider.result`` with canned rows, calls one gateway
method and then checks the SQL recorded in ``provider.queries`` using
``assert_query_equal``.

NOTE(review): the coroutine tests carry no explicit asyncio marker;
presumably the project's pytest configuration enables an "auto" asyncio
mode -- confirm before moving this file.
"""
from datetime import datetime, timezone
from unittest import mock

import pytest
from sqlalchemy import Column, DateTime, ForeignKey, Integer, MetaData, Table, Text

from base_lib import (
    assert_query_equal,
    Conflict,
    DoesNotExist,
    FakeSQLDatabase,
    Filter,
    PageOptions,
    SQLGateway,
)

# Table under test for the main gateway.
writer = Table(
    "writer",
    MetaData(),
    Column("id", Integer, primary_key=True, autoincrement=True),
    Column("value", Text, nullable=False),
    Column("updated_at", DateTime(timezone=True), nullable=False),
)

# Child table pointing at ``writer``; used by the one-to-many tests.
book = Table(
    "book",
    MetaData(),
    Column("id", Integer, primary_key=True, autoincrement=True),
    Column("title", Text, nullable=False),
    Column(
        "writer_id",
        Integer,
        ForeignKey("writer.id", ondelete="CASCADE", name="book_writer_id_fkey"),
        nullable=False,
    ),
)

# Column lists as they appear in the expected SELECT / RETURNING clauses.
ALL_FIELDS = "writer.id, writer.value, writer.updated_at"
BOOK_FIELDS = "book.id, book.title, book.writer_id"

# (filters, expected WHERE suffix) pairs shared by the filter, count and
# exists tests, which all expect the same WHERE clause for the same filters.
FILTER_CASES = [
    ([], ""),
    ([Filter(field="value", values=[])], " WHERE false"),
    ([Filter(field="value", values=["foo"])], " WHERE writer.value = 'foo'"),
    (
        [Filter(field="value", values=["foo", "bar"])],
        " WHERE writer.value IN ('foo', 'bar')",
    ),
    ([Filter(field="nonexisting", values=["foo"])], " WHERE false"),
    (
        [Filter(field="id", values=[1]), Filter(field="value", values=["foo"])],
        " WHERE writer.id = 1 AND writer.value = 'foo'",
    ),
]


class TstSQLGateway(SQLGateway, table=writer):
    """Gateway bound to the ``writer`` table."""


class TstRelatedSQLGateway(SQLGateway, table=book):
    """Gateway bound to the ``book`` table."""


@pytest.fixture
def sql_gateway():
    """Return a ``writer`` gateway backed by a fresh ``FakeSQLDatabase``."""
    return TstSQLGateway(FakeSQLDatabase())


@pytest.fixture
def related_sql_gateway():
    """Return a ``book`` gateway backed by a fresh ``FakeSQLDatabase``."""
    return TstRelatedSQLGateway(FakeSQLDatabase())


@pytest.mark.parametrize("filters,sql", FILTER_CASES)
async def test_filter(sql_gateway, filters, sql):
    """filter() returns the provider rows and issues one SELECT with ``sql``."""
    sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
    assert await sql_gateway.filter(filters) == [{"id": 2, "value": "foo"}]
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        f"SELECT {ALL_FIELDS} FROM writer{sql}",
    )


@pytest.mark.parametrize(
    "page_options,sql",
    [
        (None, ""),
        (
            PageOptions(limit=5, order_by="id"),
            " ORDER BY writer.id ASC LIMIT 5 OFFSET 0",
        ),
        (
            PageOptions(limit=5, offset=2, order_by="id", ascending=False),
            " ORDER BY writer.id DESC LIMIT 5 OFFSET 2",
        ),
    ],
)
async def test_filter_with_pagination(sql_gateway, page_options, sql):
    """filter() appends ORDER BY / LIMIT / OFFSET derived from ``params``."""
    sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
    assert await sql_gateway.filter([], params=page_options) == [
        {"id": 2, "value": "foo"}
    ]
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        f"SELECT {ALL_FIELDS} FROM writer{sql}",
    )


async def test_filter_with_pagination_and_filter(sql_gateway):
    """filter() emits the WHERE clause before the pagination clauses."""
    sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
    assert await sql_gateway.filter(
        [Filter(field="value", values=["foo"])],
        params=PageOptions(limit=5, order_by="id"),
    ) == [{"id": 2, "value": "foo"}]
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        (
            f"SELECT {ALL_FIELDS} FROM writer "
            "WHERE writer.value = 'foo' "
            "ORDER BY writer.id ASC LIMIT 5 OFFSET 0"
        ),
    )


@pytest.mark.parametrize("filters,sql", FILTER_CASES)
async def test_count(sql_gateway, filters, sql):
    """count() returns the ``count`` column of a single SELECT count(*)."""
    sql_gateway.provider.result.return_value = [{"count": 4}]
    assert await sql_gateway.count(filters) == 4
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        f"SELECT count(*) AS count FROM writer{sql}",
    )


@mock.patch.object(SQLGateway, "filter")
async def test_get(filter_m, sql_gateway):
    """get() delegates to filter() on ``id`` and returns the first row."""
    filter_m.return_value = [{"id": 2, "value": "foo"}]
    assert await sql_gateway.get(2) == filter_m.return_value[0]
    # filter() is mocked, so nothing should reach the provider.
    assert len(sql_gateway.provider.queries) == 0
    filter_m.assert_awaited_once_with([Filter(field="id", values=[2])], params=None)


@mock.patch.object(SQLGateway, "filter")
async def test_get_does_not_exist(filter_m, sql_gateway):
    """get() returns None when filter() yields no rows."""
    filter_m.return_value = []
    assert await sql_gateway.get(2) is None
    assert len(sql_gateway.provider.queries) == 0
    filter_m.assert_awaited_once_with([Filter(field="id", values=[2])], params=None)


@pytest.mark.parametrize(
    "record,sql",
    [
        ({}, "DEFAULT VALUES"),
        ({"value": "foo"}, "(value) VALUES ('foo')"),
        ({"id": None, "value": "foo"}, "(value) VALUES ('foo')"),
        ({"id": 2, "value": "foo"}, "(id, value) VALUES (2, 'foo')"),
        ({"value": "foo", "nonexisting": 2}, "(value) VALUES ('foo')"),
    ],
)
async def test_add(sql_gateway, record, sql):
    """add() issues one INSERT ... RETURNING and returns the first row."""
    records = [{"id": 2, "value": "foo"}]
    sql_gateway.provider.result.return_value = records
    assert await sql_gateway.add(record) == records[0]
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        f"INSERT INTO writer {sql} RETURNING {ALL_FIELDS}",
    )


@pytest.mark.parametrize(
    "record,if_unmodified_since,sql",
    [
        (
            {"id": 2, "value": "foo"},
            None,
            "SET id=2, value='foo' WHERE writer.id = 2",
        ),
        ({"id": 2, "other": "foo"}, None, "SET id=2 WHERE writer.id = 2"),
        (
            {"id": 2, "value": "foo"},
            datetime(2010, 1, 1, tzinfo=timezone.utc),
            (
                "SET id=2, value='foo' WHERE writer.id = 2 "
                "AND writer.updated_at = '2010-01-01 00:00:00+00:00'"
            ),
        ),
    ],
)
async def test_update(sql_gateway, record, if_unmodified_since, sql):
    """update() issues one UPDATE ... RETURNING and returns the first row."""
    records = [{"id": 2, "value": "foo"}]
    sql_gateway.provider.result.return_value = records
    assert await sql_gateway.update(record, if_unmodified_since) == records[0]
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        f"UPDATE writer {sql} RETURNING {ALL_FIELDS}",
    )


async def test_update_does_not_exist(sql_gateway):
    """update() raises DoesNotExist when the UPDATE returns no rows."""
    sql_gateway.provider.result.return_value = []
    with pytest.raises(DoesNotExist):
        await sql_gateway.update({"id": 2})
    assert len(sql_gateway.provider.queries) == 1


@mock.patch.object(SQLGateway, "get")
async def test_update_if_unmodified_since_does_not_exist(get_m, sql_gateway):
    """Empty UPDATE result plus get() -> None raises DoesNotExist."""
    get_m.return_value = None
    sql_gateway.provider.result.return_value = []
    with pytest.raises(DoesNotExist):
        await sql_gateway.update(
            {"id": 2}, if_unmodified_since=datetime(2010, 1, 1, tzinfo=timezone.utc)
        )
    assert len(sql_gateway.provider.queries) == 1
    get_m.assert_awaited_once_with(2)


@mock.patch.object(SQLGateway, "get")
async def test_update_if_unmodified_since_conflict(get_m, sql_gateway):
    """Empty UPDATE result while get() still finds the row raises Conflict."""
    get_m.return_value = {"id": 2, "value": "foo"}
    sql_gateway.provider.result.return_value = []
    with pytest.raises(Conflict):
        await sql_gateway.update(
            {"id": 2}, if_unmodified_since=datetime(2010, 1, 1, tzinfo=timezone.utc)
        )
    assert len(sql_gateway.provider.queries) == 1
    get_m.assert_awaited_once_with(2)


async def test_remove(sql_gateway):
    """remove() returns True when the DELETE returns a row."""
    sql_gateway.provider.result.return_value = [{"id": 2}]
    assert (await sql_gateway.remove(2)) is True
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        "DELETE FROM writer WHERE writer.id = 2 RETURNING writer.id",
    )


async def test_remove_does_not_exist(sql_gateway):
    """remove() returns False when the DELETE returns no rows."""
    sql_gateway.provider.result.return_value = []
    assert (await sql_gateway.remove(2)) is False
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        "DELETE FROM writer WHERE writer.id = 2 RETURNING writer.id",
    )


async def test_upsert(sql_gateway):
    """upsert() with an id issues one INSERT ... ON CONFLICT DO UPDATE."""
    record = {"id": 2, "value": "foo"}
    sql_gateway.provider.result.return_value = [record]
    assert await sql_gateway.upsert(record) == record
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        (
            "INSERT INTO writer (id, value) VALUES (2, 'foo') "
            "ON CONFLICT (id) DO UPDATE SET "
            "id = %(param_1)s, value = %(param_2)s "
            f"RETURNING {ALL_FIELDS}"
        ),
    )


@mock.patch.object(SQLGateway, "add")
async def test_upsert_no_id(add_m, sql_gateway):
    """upsert() without an id delegates to add() and runs no query itself."""
    add_m.return_value = {"id": 5, "value": "foo"}
    assert await sql_gateway.upsert({"value": "foo"}) == add_m.return_value
    add_m.assert_awaited_once_with({"value": "foo"})
    assert len(sql_gateway.provider.queries) == 0


async def test_get_related_one_to_many(related_sql_gateway: SQLGateway):
    """_get_related_one_to_many() attaches children to each parent in place."""
    writers = [{"id": 2}, {"id": 3}]
    books = [
        {"id": 3, "title": "x", "writer_id": 2},
        {"id": 4, "title": "y", "writer_id": 2},
    ]
    related_sql_gateway.provider.result.return_value = books
    await related_sql_gateway._get_related_one_to_many(
        items=writers,
        field_name="books",
        fk_name="writer_id",
    )
    # Writer 3 has no books and must receive an empty list, not be skipped.
    assert writers == [{"id": 2, "books": books}, {"id": 3, "books": []}]
    assert len(related_sql_gateway.provider.queries) == 1
    assert_query_equal(
        related_sql_gateway.provider.queries[0][0],
        f"SELECT {BOOK_FIELDS} FROM book WHERE book.writer_id IN (2, 3)",
    )


@pytest.mark.parametrize(
    "books,current_books,expected_queries,query_results",
    [
        # no change
        (
            [{"id": 3, "title": "x", "writer_id": 2}],
            [{"id": 3, "title": "x", "writer_id": 2}],
            [],
            [],
        ),
        # added a book (without an id)
        (
            [{"title": "x", "writer_id": 2}],
            [],
            [
                "INSERT INTO book (title, writer_id) VALUES ('x', 2) "
                f"RETURNING {BOOK_FIELDS}"
            ],
            [[{"id": 3, "title": "x", "writer_id": 2}]],
        ),
        # added a book (with an id)
        (
            [{"id": 3, "title": "x", "writer_id": 2}],
            [],
            [
                "INSERT INTO book (id, title, writer_id) VALUES (3, 'x', 2) "
                f"RETURNING {BOOK_FIELDS}"
            ],
            [[{"id": 3, "title": "x", "writer_id": 2}]],
        ),
        # updated a book
        (
            [{"id": 3, "title": "x", "writer_id": 2}],
            [{"id": 3, "title": "a", "writer_id": 2}],
            [
                "UPDATE book SET id=3, title='x', writer_id=2 "
                f"WHERE book.id = 3 RETURNING {BOOK_FIELDS}"
            ],
            [[{"id": 3, "title": "x", "writer_id": 2}]],
        ),
        # replaced a book with a new one
        (
            [{"title": "x", "writer_id": 2}],
            [{"id": 15, "title": "a", "writer_id": 2}],
            [
                "INSERT INTO book (title, writer_id) VALUES ('x', 2) "
                f"RETURNING {BOOK_FIELDS}",
                "DELETE FROM book WHERE book.id = 15 RETURNING book.id",
            ],
            [[{"id": 3, "title": "x", "writer_id": 2}], [{"id": 15}]],
        ),
    ],
)
async def test_set_related_one_to_many(
    related_sql_gateway: SQLGateway,
    books,
    current_books,
    expected_queries,
    query_results,
):
    """_set_related_one_to_many() runs the expected INSERT/UPDATE/DELETE queries.

    The first provider result is the current set of books (the initial
    SELECT); ``query_results`` supplies one result per expected write query.
    """
    # Named ``writer_item`` so it does not shadow the module-level table.
    writer_item = {"id": 2, "books": books}
    related_sql_gateway.provider.result.side_effect = [current_books] + query_results
    result = writer_item.copy()
    await related_sql_gateway._set_related_one_to_many(
        item=writer_item,
        result=result,
        field_name="books",
        fk_name="writer_id",
    )
    assert result == {
        "id": 2,
        "books": [{"id": 3, "title": "x", "writer_id": 2}],
    }
    # +1 for the initial SELECT; this also guards the zip() below against
    # silently truncating when the number of queries differs.
    assert len(related_sql_gateway.provider.queries) == len(expected_queries) + 1
    assert_query_equal(
        related_sql_gateway.provider.queries[0][0],
        f"SELECT {BOOK_FIELDS} FROM book WHERE book.writer_id = 2",
    )
    for (actual_query,), expected_query in zip(
        related_sql_gateway.provider.queries[1:], expected_queries
    ):
        assert_query_equal(actual_query, expected_query)


async def test_update_transactional(sql_gateway):
    """update_transactional() does SELECT ... FOR UPDATE, then UPDATE.

    Both statements are expected inside a single ``provider.queries`` entry
    (presumably one entry per transaction -- confirm against FakeSQLDatabase).
    """
    existing = {"id": 2, "value": "foo"}
    expected = {"id": 2, "value": "bar"}
    sql_gateway.provider.result.side_effect = ([existing], [expected])
    actual = await sql_gateway.update_transactional(
        2, lambda x: {"id": x["id"], "value": "bar"}
    )
    assert actual == expected
    # Unpacking fails unless exactly one entry was recorded.
    (queries,) = sql_gateway.provider.queries
    assert len(queries) == 2
    assert_query_equal(
        queries[0],
        f"SELECT {ALL_FIELDS} FROM writer WHERE writer.id = 2 FOR UPDATE",
    )
    assert_query_equal(
        queries[1],
        (
            "UPDATE writer SET id=2, value='bar' "
            f"WHERE writer.id = 2 RETURNING {ALL_FIELDS}"
        ),
    )


@pytest.mark.parametrize("filters,sql", FILTER_CASES)
async def test_exists(sql_gateway, filters, sql):
    """exists() issues one ``SELECT true ... LIMIT 1`` and returns True."""
    sql_gateway.provider.result.return_value = [{"exists": True}]
    assert await sql_gateway.exists(filters) is True
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        f"SELECT true AS exists FROM writer{sql} LIMIT 1",
    )