|
|
@@ -0,0 +1,453 @@
|
|
|
+from datetime import datetime, timezone
|
|
|
+from unittest import mock
|
|
|
+
|
|
|
+import pytest
|
|
|
+from sqlalchemy import Column, DateTime, ForeignKey, Integer, MetaData, Table, Text
|
|
|
+
|
|
|
+from base_lib import (
|
|
|
+ assert_query_equal,
|
|
|
+ Conflict,
|
|
|
+ DoesNotExist,
|
|
|
+ FakeSQLDatabase,
|
|
|
+ Filter,
|
|
|
+ PageOptions,
|
|
|
+ SQLGateway,
|
|
|
+)
|
|
|
+
|
|
|
+writer = Table(
|
|
|
+ "writer",
|
|
|
+ MetaData(),
|
|
|
+ Column("id", Integer, primary_key=True, autoincrement=True),
|
|
|
+ Column("value", Text, nullable=False),
|
|
|
+ Column("updated_at", DateTime(timezone=True), nullable=False),
|
|
|
+)
|
|
|
+
|
|
|
+
|
|
|
+book = Table(
|
|
|
+ "book",
|
|
|
+ MetaData(),
|
|
|
+ Column("id", Integer, primary_key=True, autoincrement=True),
|
|
|
+ Column("title", Text, nullable=False),
|
|
|
+ Column(
|
|
|
+ "writer_id",
|
|
|
+ Integer,
|
|
|
+ ForeignKey("writer.id", ondelete="CASCADE", name="book_writer_id_fkey"),
|
|
|
+ nullable=False,
|
|
|
+ ),
|
|
|
+)
|
|
|
+
|
|
|
+
|
|
|
+ALL_FIELDS = "writer.id, writer.value, writer.updated_at"
|
|
|
+BOOK_FIELDS = "book.id, book.title, book.writer_id"
|
|
|
+
|
|
|
+
|
|
|
+class TstSQLGateway(SQLGateway, table=writer):
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
+class TstRelatedSQLGateway(SQLGateway, table=book):
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
+@pytest.fixture
|
|
|
+def sql_gateway():
|
|
|
+ return TstSQLGateway(FakeSQLDatabase())
|
|
|
+
|
|
|
+
|
|
|
+@pytest.fixture
|
|
|
+def related_sql_gateway():
|
|
|
+ return TstRelatedSQLGateway(FakeSQLDatabase())
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.parametrize(
|
|
|
+ "filters,sql",
|
|
|
+ [
|
|
|
+ ([], ""),
|
|
|
+ ([Filter(field="value", values=[])], " WHERE false"),
|
|
|
+ ([Filter(field="value", values=["foo"])], " WHERE writer.value = 'foo'"),
|
|
|
+ (
|
|
|
+ [Filter(field="value", values=["foo", "bar"])],
|
|
|
+ " WHERE writer.value IN ('foo', 'bar')",
|
|
|
+ ),
|
|
|
+ ([Filter(field="nonexisting", values=["foo"])], " WHERE false"),
|
|
|
+ (
|
|
|
+ [Filter(field="id", values=[1]), Filter(field="value", values=["foo"])],
|
|
|
+ " WHERE writer.id = 1 AND writer.value = 'foo'",
|
|
|
+ ),
|
|
|
+ ],
|
|
|
+)
|
|
|
+async def test_filter(sql_gateway, filters, sql):
|
|
|
+ sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
|
|
|
+ assert await sql_gateway.filter(filters) == [{"id": 2, "value": "foo"}]
|
|
|
+ assert len(sql_gateway.provider.queries) == 1
|
|
|
+ assert_query_equal(
|
|
|
+ sql_gateway.provider.queries[0][0],
|
|
|
+ f"SELECT {ALL_FIELDS} FROM writer{sql}",
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.parametrize(
|
|
|
+ "page_options,sql",
|
|
|
+ [
|
|
|
+ (None, ""),
|
|
|
+ (
|
|
|
+ PageOptions(limit=5, order_by="id"),
|
|
|
+ " ORDER BY writer.id ASC LIMIT 5 OFFSET 0",
|
|
|
+ ),
|
|
|
+ (
|
|
|
+ PageOptions(limit=5, offset=2, order_by="id", ascending=False),
|
|
|
+ " ORDER BY writer.id DESC LIMIT 5 OFFSET 2",
|
|
|
+ ),
|
|
|
+ ],
|
|
|
+)
|
|
|
+async def test_filter_with_pagination(sql_gateway, page_options, sql):
|
|
|
+ sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
|
|
|
+ assert await sql_gateway.filter([], params=page_options) == [
|
|
|
+ {"id": 2, "value": "foo"}
|
|
|
+ ]
|
|
|
+ assert len(sql_gateway.provider.queries) == 1
|
|
|
+ assert_query_equal(
|
|
|
+ sql_gateway.provider.queries[0][0],
|
|
|
+ f"SELECT {ALL_FIELDS} FROM writer{sql}",
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+async def test_filter_with_pagination_and_filter(sql_gateway):
|
|
|
+ sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
|
|
|
+ assert await sql_gateway.filter(
|
|
|
+ [Filter(field="value", values=["foo"])],
|
|
|
+ params=PageOptions(limit=5, order_by="id"),
|
|
|
+ ) == [{"id": 2, "value": "foo"}]
|
|
|
+ assert len(sql_gateway.provider.queries) == 1
|
|
|
+ assert_query_equal(
|
|
|
+ sql_gateway.provider.queries[0][0],
|
|
|
+ (
|
|
|
+ f"SELECT {ALL_FIELDS} FROM writer "
|
|
|
+ f"WHERE writer.value = 'foo' "
|
|
|
+ f"ORDER BY writer.id ASC LIMIT 5 OFFSET 0"
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.parametrize(
|
|
|
+ "filters,sql",
|
|
|
+ [
|
|
|
+ ([], ""),
|
|
|
+ ([Filter(field="value", values=[])], " WHERE false"),
|
|
|
+ ([Filter(field="value", values=["foo"])], " WHERE writer.value = 'foo'"),
|
|
|
+ (
|
|
|
+ [Filter(field="value", values=["foo", "bar"])],
|
|
|
+ " WHERE writer.value IN ('foo', 'bar')",
|
|
|
+ ),
|
|
|
+ ([Filter(field="nonexisting", values=["foo"])], " WHERE false"),
|
|
|
+ (
|
|
|
+ [Filter(field="id", values=[1]), Filter(field="value", values=["foo"])],
|
|
|
+ " WHERE writer.id = 1 AND writer.value = 'foo'",
|
|
|
+ ),
|
|
|
+ ],
|
|
|
+)
|
|
|
+async def test_count(sql_gateway, filters, sql):
|
|
|
+ sql_gateway.provider.result.return_value = [{"count": 4}]
|
|
|
+ assert await sql_gateway.count(filters) == 4
|
|
|
+ assert len(sql_gateway.provider.queries) == 1
|
|
|
+ assert_query_equal(
|
|
|
+ sql_gateway.provider.queries[0][0],
|
|
|
+ f"SELECT count(*) AS count FROM writer{sql}",
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+@mock.patch.object(SQLGateway, "filter")
|
|
|
+async def test_get(filter_m, sql_gateway):
|
|
|
+ filter_m.return_value = [{"id": 2, "value": "foo"}]
|
|
|
+ assert await sql_gateway.get(2) == filter_m.return_value[0]
|
|
|
+ assert len(sql_gateway.provider.queries) == 0
|
|
|
+ filter_m.assert_awaited_once_with([Filter(field="id", values=[2])], params=None)
|
|
|
+
|
|
|
+
|
|
|
+@mock.patch.object(SQLGateway, "filter")
|
|
|
+async def test_get_does_not_exist(filter_m, sql_gateway):
|
|
|
+ filter_m.return_value = []
|
|
|
+ assert await sql_gateway.get(2) is None
|
|
|
+ assert len(sql_gateway.provider.queries) == 0
|
|
|
+ filter_m.assert_awaited_once_with([Filter(field="id", values=[2])], params=None)
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.parametrize(
|
|
|
+ "record,sql",
|
|
|
+ [
|
|
|
+ ({}, "DEFAULT VALUES"),
|
|
|
+ ({"value": "foo"}, "(value) VALUES ('foo')"),
|
|
|
+ ({"id": None, "value": "foo"}, "(value) VALUES ('foo')"),
|
|
|
+ ({"id": 2, "value": "foo"}, "(id, value) VALUES (2, 'foo')"),
|
|
|
+ ({"value": "foo", "nonexisting": 2}, "(value) VALUES ('foo')"),
|
|
|
+ ],
|
|
|
+)
|
|
|
+async def test_add(sql_gateway, record, sql):
|
|
|
+ records = [{"id": 2, "value": "foo"}]
|
|
|
+ sql_gateway.provider.result.return_value = records
|
|
|
+ assert await sql_gateway.add(record) == records[0]
|
|
|
+ assert len(sql_gateway.provider.queries) == 1
|
|
|
+ assert_query_equal(
|
|
|
+ sql_gateway.provider.queries[0][0],
|
|
|
+ (f"INSERT INTO writer {sql} RETURNING {ALL_FIELDS}"),
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.parametrize(
|
|
|
+ "record,if_unmodified_since,sql",
|
|
|
+ [
|
|
|
+ (
|
|
|
+ {"id": 2, "value": "foo"},
|
|
|
+ None,
|
|
|
+ "SET id=2, value='foo' WHERE writer.id = 2",
|
|
|
+ ),
|
|
|
+ ({"id": 2, "other": "foo"}, None, "SET id=2 WHERE writer.id = 2"),
|
|
|
+ (
|
|
|
+ {"id": 2, "value": "foo"},
|
|
|
+ datetime(2010, 1, 1, tzinfo=timezone.utc),
|
|
|
+ (
|
|
|
+ "SET id=2, value='foo' WHERE writer.id = 2 "
|
|
|
+ "AND writer.updated_at = '2010-01-01 00:00:00+00:00'"
|
|
|
+ ),
|
|
|
+ ),
|
|
|
+ ],
|
|
|
+)
|
|
|
+async def test_update(sql_gateway, record, if_unmodified_since, sql):
|
|
|
+ records = [{"id": 2, "value": "foo"}]
|
|
|
+ sql_gateway.provider.result.return_value = records
|
|
|
+ assert await sql_gateway.update(record, if_unmodified_since) == records[0]
|
|
|
+ assert len(sql_gateway.provider.queries) == 1
|
|
|
+ assert_query_equal(
|
|
|
+ sql_gateway.provider.queries[0][0],
|
|
|
+ (f"UPDATE writer {sql} RETURNING {ALL_FIELDS}"),
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+async def test_update_does_not_exist(sql_gateway):
|
|
|
+ sql_gateway.provider.result.return_value = []
|
|
|
+ with pytest.raises(DoesNotExist):
|
|
|
+ await sql_gateway.update({"id": 2})
|
|
|
+ assert len(sql_gateway.provider.queries) == 1
|
|
|
+
|
|
|
+
|
|
|
+@mock.patch.object(SQLGateway, "get")
|
|
|
+async def test_update_if_unmodified_since_does_not_exist(get_m, sql_gateway):
|
|
|
+ get_m.return_value = None
|
|
|
+ sql_gateway.provider.result.return_value = []
|
|
|
+ with pytest.raises(DoesNotExist):
|
|
|
+ await sql_gateway.update(
|
|
|
+ {"id": 2}, if_unmodified_since=datetime(2010, 1, 1, tzinfo=timezone.utc)
|
|
|
+ )
|
|
|
+ assert len(sql_gateway.provider.queries) == 1
|
|
|
+ get_m.assert_awaited_once_with(2)
|
|
|
+
|
|
|
+
|
|
|
+@mock.patch.object(SQLGateway, "get")
|
|
|
+async def test_update_if_unmodified_since_conflict(get_m, sql_gateway):
|
|
|
+ get_m.return_value = {"id": 2, "value": "foo"}
|
|
|
+ sql_gateway.provider.result.return_value = []
|
|
|
+ with pytest.raises(Conflict):
|
|
|
+ await sql_gateway.update(
|
|
|
+ {"id": 2}, if_unmodified_since=datetime(2010, 1, 1, tzinfo=timezone.utc)
|
|
|
+ )
|
|
|
+ assert len(sql_gateway.provider.queries) == 1
|
|
|
+ get_m.assert_awaited_once_with(2)
|
|
|
+
|
|
|
+
|
|
|
+async def test_remove(sql_gateway):
|
|
|
+ sql_gateway.provider.result.return_value = [{"id": 2}]
|
|
|
+ assert (await sql_gateway.remove(2)) is True
|
|
|
+ assert len(sql_gateway.provider.queries) == 1
|
|
|
+ assert_query_equal(
|
|
|
+ sql_gateway.provider.queries[0][0],
|
|
|
+ ("DELETE FROM writer WHERE writer.id = 2 RETURNING writer.id"),
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+async def test_remove_does_not_exist(sql_gateway):
|
|
|
+ sql_gateway.provider.result.return_value = []
|
|
|
+ assert (await sql_gateway.remove(2)) is False
|
|
|
+ assert len(sql_gateway.provider.queries) == 1
|
|
|
+ assert_query_equal(
|
|
|
+ sql_gateway.provider.queries[0][0],
|
|
|
+ ("DELETE FROM writer WHERE writer.id = 2 RETURNING writer.id"),
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+async def test_upsert(sql_gateway):
|
|
|
+ record = {"id": 2, "value": "foo"}
|
|
|
+ sql_gateway.provider.result.return_value = [record]
|
|
|
+ assert await sql_gateway.upsert(record) == record
|
|
|
+ assert len(sql_gateway.provider.queries) == 1
|
|
|
+ assert_query_equal(
|
|
|
+ sql_gateway.provider.queries[0][0],
|
|
|
+ (
|
|
|
+ f"INSERT INTO writer (id, value) VALUES (2, 'foo') "
|
|
|
+ f"ON CONFLICT (id) DO UPDATE SET "
|
|
|
+ f"id = %(param_1)s, value = %(param_2)s "
|
|
|
+ f"RETURNING {ALL_FIELDS}"
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+@mock.patch.object(SQLGateway, "add")
|
|
|
+async def test_upsert_no_id(add_m, sql_gateway):
|
|
|
+ add_m.return_value = {"id": 5, "value": "foo"}
|
|
|
+ assert await sql_gateway.upsert({"value": "foo"}) == add_m.return_value
|
|
|
+
|
|
|
+ add_m.assert_awaited_once_with({"value": "foo"})
|
|
|
+ assert len(sql_gateway.provider.queries) == 0
|
|
|
+
|
|
|
+
|
|
|
+async def test_get_related_one_to_many(related_sql_gateway: SQLGateway):
|
|
|
+ writers = [{"id": 2}, {"id": 3}]
|
|
|
+ books = [
|
|
|
+ {"id": 3, "title": "x", "writer_id": 2},
|
|
|
+ {"id": 4, "title": "y", "writer_id": 2},
|
|
|
+ ]
|
|
|
+ related_sql_gateway.provider.result.return_value = books
|
|
|
+ await related_sql_gateway._get_related_one_to_many(
|
|
|
+ items=writers,
|
|
|
+ field_name="books",
|
|
|
+ fk_name="writer_id",
|
|
|
+ )
|
|
|
+
|
|
|
+ assert writers == [{"id": 2, "books": books}, {"id": 3, "books": []}]
|
|
|
+ assert len(related_sql_gateway.provider.queries) == 1
|
|
|
+ assert_query_equal(
|
|
|
+ related_sql_gateway.provider.queries[0][0],
|
|
|
+ (
|
|
|
+ "SELECT book.id, book.title, book.writer_id FROM book WHERE book.writer_id IN (2, 3)"
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.parametrize(
|
|
|
+ "books,current_books,expected_queries,query_results",
|
|
|
+ [
|
|
|
+ # no change
|
|
|
+ (
|
|
|
+ [{"id": 3, "title": "x", "writer_id": 2}],
|
|
|
+ [{"id": 3, "title": "x", "writer_id": 2}],
|
|
|
+ [],
|
|
|
+ [],
|
|
|
+ ),
|
|
|
+ # added a book (without an id)
|
|
|
+ (
|
|
|
+ [{"title": "x", "writer_id": 2}],
|
|
|
+ [],
|
|
|
+ [
|
|
|
+ f"INSERT INTO book (title, writer_id) VALUES ('x', 2) RETURNING {BOOK_FIELDS}"
|
|
|
+ ],
|
|
|
+ [[{"id": 3, "title": "x", "writer_id": 2}]],
|
|
|
+ ),
|
|
|
+ # added a book (with an id)
|
|
|
+ (
|
|
|
+ [{"id": 3, "title": "x", "writer_id": 2}],
|
|
|
+ [],
|
|
|
+ [
|
|
|
+ f"INSERT INTO book (id, title, writer_id) VALUES (3, 'x', 2) RETURNING {BOOK_FIELDS}"
|
|
|
+ ],
|
|
|
+ [[{"id": 3, "title": "x", "writer_id": 2}]],
|
|
|
+ ),
|
|
|
+ # updated a book
|
|
|
+ (
|
|
|
+ [{"id": 3, "title": "x", "writer_id": 2}],
|
|
|
+ [{"id": 3, "title": "a", "writer_id": 2}],
|
|
|
+ [
|
|
|
+ f"UPDATE book SET id=3, title='x', writer_id=2 WHERE book.id = 3 RETURNING {BOOK_FIELDS}"
|
|
|
+ ],
|
|
|
+ [[{"id": 3, "title": "x", "writer_id": 2}]],
|
|
|
+ ),
|
|
|
+ # replaced a book with a new one
|
|
|
+ (
|
|
|
+ [{"title": "x", "writer_id": 2}],
|
|
|
+ [{"id": 15, "title": "a", "writer_id": 2}],
|
|
|
+ [
|
|
|
+ f"INSERT INTO book (title, writer_id) VALUES ('x', 2) RETURNING {BOOK_FIELDS}",
|
|
|
+ "DELETE FROM book WHERE book.id = 15 RETURNING book.id",
|
|
|
+ ],
|
|
|
+ [[{"id": 3, "title": "x", "writer_id": 2}], [{"id": 15}]],
|
|
|
+ ),
|
|
|
+ ],
|
|
|
+)
|
|
|
+async def test_set_related_one_to_many(
|
|
|
+ related_sql_gateway: SQLGateway,
|
|
|
+ books,
|
|
|
+ current_books,
|
|
|
+ expected_queries,
|
|
|
+ query_results,
|
|
|
+):
|
|
|
+ writer = {"id": 2, "books": books}
|
|
|
+ related_sql_gateway.provider.result.side_effect = [current_books] + query_results
|
|
|
+ result = writer.copy()
|
|
|
+ await related_sql_gateway._set_related_one_to_many(
|
|
|
+ item=writer,
|
|
|
+ result=result,
|
|
|
+ field_name="books",
|
|
|
+ fk_name="writer_id",
|
|
|
+ )
|
|
|
+
|
|
|
+ assert result == {
|
|
|
+ "id": 2,
|
|
|
+ "books": [{"id": 3, "title": "x", "writer_id": 2}],
|
|
|
+ }
|
|
|
+ assert len(related_sql_gateway.provider.queries) == len(expected_queries) + 1
|
|
|
+ assert_query_equal(
|
|
|
+ related_sql_gateway.provider.queries[0][0],
|
|
|
+ f"SELECT {BOOK_FIELDS} FROM book WHERE book.writer_id = 2",
|
|
|
+ )
|
|
|
+ for (actual_query,), expected_query in zip(
|
|
|
+ related_sql_gateway.provider.queries[1:], expected_queries
|
|
|
+ ):
|
|
|
+ assert_query_equal(actual_query, expected_query)
|
|
|
+
|
|
|
+
|
|
|
+async def test_update_transactional(sql_gateway):
|
|
|
+ existing = {"id": 2, "value": "foo"}
|
|
|
+ expected = {"id": 2, "value": "bar"}
|
|
|
+ sql_gateway.provider.result.side_effect = ([existing], [expected])
|
|
|
+ actual = await sql_gateway.update_transactional(
|
|
|
+ 2, lambda x: {"id": x["id"], "value": "bar"}
|
|
|
+ )
|
|
|
+ assert actual == expected
|
|
|
+
|
|
|
+ (queries,) = sql_gateway.provider.queries
|
|
|
+ assert len(queries) == 2
|
|
|
+ assert_query_equal(
|
|
|
+ queries[0],
|
|
|
+ f"SELECT {ALL_FIELDS} FROM writer WHERE writer.id = 2 FOR UPDATE",
|
|
|
+ )
|
|
|
+ assert_query_equal(
|
|
|
+ queries[1],
|
|
|
+ (
|
|
|
+ f"UPDATE writer SET id=2, value='bar' WHERE writer.id = 2 RETURNING {ALL_FIELDS}"
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.parametrize(
|
|
|
+ "filters,sql",
|
|
|
+ [
|
|
|
+ ([], ""),
|
|
|
+ ([Filter(field="value", values=[])], " WHERE false"),
|
|
|
+ ([Filter(field="value", values=["foo"])], " WHERE writer.value = 'foo'"),
|
|
|
+ (
|
|
|
+ [Filter(field="value", values=["foo", "bar"])],
|
|
|
+ " WHERE writer.value IN ('foo', 'bar')",
|
|
|
+ ),
|
|
|
+ ([Filter(field="nonexisting", values=["foo"])], " WHERE false"),
|
|
|
+ (
|
|
|
+ [Filter(field="id", values=[1]), Filter(field="value", values=["foo"])],
|
|
|
+ " WHERE writer.id = 1 AND writer.value = 'foo'",
|
|
|
+ ),
|
|
|
+ ],
|
|
|
+)
|
|
|
+async def test_exists(sql_gateway, filters, sql):
|
|
|
+ sql_gateway.provider.result.return_value = [{"exists": True}]
|
|
|
+ assert await sql_gateway.exists(filters) is True
|
|
|
+ assert len(sql_gateway.provider.queries) == 1
|
|
|
+ assert_query_equal(
|
|
|
+ sql_gateway.provider.queries[0][0],
|
|
|
+ f"SELECT true AS exists FROM writer{sql} LIMIT 1",
|
|
|
+ )
|