123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458 |
- from datetime import datetime
- from datetime import timezone
- from unittest import mock
- import pytest
- from sqlalchemy import Column
- from sqlalchemy import DateTime
- from sqlalchemy import ForeignKey
- from sqlalchemy import Integer
- from sqlalchemy import MetaData
- from sqlalchemy import Table
- from sqlalchemy import Text
- from clean_python import Conflict
- from clean_python import DoesNotExist
- from clean_python import Filter
- from clean_python import PageOptions
- from clean_python.sql import SQLGateway
- from clean_python.sql.testing import assert_query_equal
- from clean_python.sql.testing import FakeSQLDatabase
# Backing table for TstSQLGateway. ALL_FIELDS lists these same columns in the
# same order, as they appear in the expected SELECT/RETURNING clauses.
writer = Table(
    "writer",
    MetaData(),
    Column("id", Integer, autoincrement=True, primary_key=True),
    Column("value", Text, nullable=False),
    # Timezone-aware timestamp; used by the if_unmodified_since update tests.
    Column("updated_at", DateTime(timezone=True), nullable=False),
)
# Backing table for TstRelatedSQLGateway: each book points at one writer.
# BOOK_FIELDS lists these columns in the same order.
book = Table(
    "book",
    MetaData(),
    Column("id", Integer, autoincrement=True, primary_key=True),
    Column("title", Text, nullable=False),
    Column(
        "writer_id",
        Integer,
        # FK to writer.id; rows are removed when the referenced writer is.
        ForeignKey("writer.id", name="book_writer_id_fkey", ondelete="CASCADE"),
        nullable=False,
    ),
)
# Fully-qualified, comma-separated column lists as they appear in the SQL the
# gateways are expected to emit (SELECT lists and RETURNING clauses).
ALL_FIELDS = ", ".join(
    f"writer.{column}" for column in ("id", "value", "updated_at")
)
BOOK_FIELDS = ", ".join(f"book.{column}" for column in ("id", "title", "writer_id"))
class TstSQLGateway(SQLGateway, table=writer):
    """Concrete SQLGateway bound to the ``writer`` test table."""
class TstRelatedSQLGateway(SQLGateway, table=book):
    """Concrete SQLGateway bound to the ``book`` test table (related rows)."""
@pytest.fixture
def sql_gateway():
    """Writer gateway on a FakeSQLDatabase, which records issued queries
    (``provider.queries``) and serves canned results (``provider.result``)."""
    fake_db = FakeSQLDatabase()
    return TstSQLGateway(fake_db)
@pytest.fixture
def related_sql_gateway():
    """Book gateway on a FakeSQLDatabase that records queries and serves
    canned results."""
    fake_db = FakeSQLDatabase()
    return TstRelatedSQLGateway(fake_db)
@pytest.mark.parametrize(
    ("filters", "sql"),
    [
        ([], ""),
        ([Filter(field="value", values=[])], " WHERE false"),
        ([Filter(field="value", values=["foo"])], " WHERE writer.value = 'foo'"),
        (
            [Filter(field="value", values=["foo", "bar"])],
            " WHERE writer.value IN ('foo', 'bar')",
        ),
        ([Filter(field="nonexisting", values=["foo"])], " WHERE false"),
        (
            [Filter(field="id", values=[1]), Filter(field="value", values=["foo"])],
            " WHERE writer.id = 1 AND writer.value = 'foo'",
        ),
    ],
)
async def test_filter(sql_gateway, filters, sql):
    """filter() renders Filters as a WHERE clause and returns the fetched rows.

    An empty value list or an unknown field collapses to ``WHERE false``.
    """
    sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
    expected_rows = [{"id": 2, "value": "foo"}]

    rows = await sql_gateway.filter(filters)
    assert rows == expected_rows

    recorded = sql_gateway.provider.queries
    assert len(recorded) == 1
    assert_query_equal(recorded[0][0], f"SELECT {ALL_FIELDS} FROM writer{sql}")
@pytest.mark.parametrize(
    ("page_options", "sql"),
    [
        (None, ""),
        (
            PageOptions(limit=5, order_by="id"),
            " ORDER BY writer.id ASC LIMIT 5 OFFSET 0",
        ),
        (
            PageOptions(limit=5, offset=2, order_by="id", ascending=False),
            " ORDER BY writer.id DESC LIMIT 5 OFFSET 2",
        ),
    ],
)
async def test_filter_with_pagination(sql_gateway, page_options, sql):
    """PageOptions become ORDER BY / LIMIT / OFFSET; None adds nothing."""
    sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
    expected_rows = [{"id": 2, "value": "foo"}]

    rows = await sql_gateway.filter([], params=page_options)
    assert rows == expected_rows

    recorded = sql_gateway.provider.queries
    assert len(recorded) == 1
    assert_query_equal(recorded[0][0], f"SELECT {ALL_FIELDS} FROM writer{sql}")
async def test_filter_with_pagination_and_filter(sql_gateway):
    """With both a filter and PageOptions, WHERE precedes ORDER BY/LIMIT/OFFSET."""
    sql_gateway.provider.result.return_value = [{"id": 2, "value": "foo"}]
    assert await sql_gateway.filter(
        [Filter(field="value", values=["foo"])],
        params=PageOptions(limit=5, order_by="id"),
    ) == [{"id": 2, "value": "foo"}]
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        (
            f"SELECT {ALL_FIELDS} FROM writer "
            # Plain literals: these fragments interpolate nothing (ruff F541).
            "WHERE writer.value = 'foo' "
            "ORDER BY writer.id ASC LIMIT 5 OFFSET 0"
        ),
    )
@pytest.mark.parametrize(
    ("filters", "sql"),
    [
        ([], ""),
        ([Filter(field="value", values=[])], " WHERE false"),
        ([Filter(field="value", values=["foo"])], " WHERE writer.value = 'foo'"),
        (
            [Filter(field="value", values=["foo", "bar"])],
            " WHERE writer.value IN ('foo', 'bar')",
        ),
        ([Filter(field="nonexisting", values=["foo"])], " WHERE false"),
        (
            [Filter(field="id", values=[1]), Filter(field="value", values=["foo"])],
            " WHERE writer.id = 1 AND writer.value = 'foo'",
        ),
    ],
)
async def test_count(sql_gateway, filters, sql):
    """count() issues ``SELECT count(*)`` with the same WHERE rendering as filter()
    and returns the ``count`` column of the single result row."""
    sql_gateway.provider.result.return_value = [{"count": 4}]

    total = await sql_gateway.count(filters)
    assert total == 4

    recorded = sql_gateway.provider.queries
    assert len(recorded) == 1
    assert_query_equal(
        recorded[0][0], f"SELECT count(*) AS count FROM writer{sql}"
    )
@mock.patch.object(SQLGateway, "filter")
async def test_get(filter_m, sql_gateway):
    """get(id) delegates to filter() with an id Filter and returns the first row."""
    filter_m.return_value = [{"id": 2, "value": "foo"}]

    found = await sql_gateway.get(2)
    assert found == filter_m.return_value[0]

    # filter() is mocked, so no SQL reaches the fake provider.
    assert len(sql_gateway.provider.queries) == 0
    filter_m.assert_awaited_once_with([Filter(field="id", values=[2])], params=None)
@mock.patch.object(SQLGateway, "filter")
async def test_get_does_not_exist(filter_m, sql_gateway):
    """get(id) returns None when filter() yields no rows."""
    filter_m.return_value = []

    found = await sql_gateway.get(2)
    assert found is None

    # filter() is mocked, so no SQL reaches the fake provider.
    assert len(sql_gateway.provider.queries) == 0
    filter_m.assert_awaited_once_with([Filter(field="id", values=[2])], params=None)
@pytest.mark.parametrize(
    ("record", "sql"),
    [
        ({}, "DEFAULT VALUES"),
        ({"value": "foo"}, "(value) VALUES ('foo')"),
        ({"id": None, "value": "foo"}, "(value) VALUES ('foo')"),
        ({"id": 2, "value": "foo"}, "(id, value) VALUES (2, 'foo')"),
        ({"value": "foo", "nonexisting": 2}, "(value) VALUES ('foo')"),
    ],
)
async def test_add(sql_gateway, record, sql):
    """add() INSERTs known columns only (a None id and unknown keys are dropped)
    and returns the first row of the RETURNING result."""
    returned_rows = [{"id": 2, "value": "foo"}]
    sql_gateway.provider.result.return_value = returned_rows

    inserted = await sql_gateway.add(record)
    assert inserted == returned_rows[0]

    recorded = sql_gateway.provider.queries
    assert len(recorded) == 1
    assert_query_equal(
        recorded[0][0], f"INSERT INTO writer {sql} RETURNING {ALL_FIELDS}"
    )
@pytest.mark.parametrize(
    ("record", "if_unmodified_since", "sql"),
    [
        (
            {"id": 2, "value": "foo"},
            None,
            "SET id=2, value='foo' WHERE writer.id = 2",
        ),
        ({"id": 2, "other": "foo"}, None, "SET id=2 WHERE writer.id = 2"),
        (
            {"id": 2, "value": "foo"},
            datetime(2010, 1, 1, tzinfo=timezone.utc),
            (
                "SET id=2, value='foo' WHERE writer.id = 2 "
                "AND writer.updated_at = '2010-01-01 00:00:00+00:00'"
            ),
        ),
    ],
)
async def test_update(sql_gateway, record, if_unmodified_since, sql):
    """update() SETs known columns for the row's id; with if_unmodified_since
    it additionally requires ``updated_at`` to equal that timestamp."""
    returned_rows = [{"id": 2, "value": "foo"}]
    sql_gateway.provider.result.return_value = returned_rows

    updated = await sql_gateway.update(record, if_unmodified_since)
    assert updated == returned_rows[0]

    recorded = sql_gateway.provider.queries
    assert len(recorded) == 1
    assert_query_equal(
        recorded[0][0], f"UPDATE writer {sql} RETURNING {ALL_FIELDS}"
    )
async def test_update_does_not_exist(sql_gateway):
    """An UPDATE that returns no rows raises DoesNotExist after one query."""
    sql_gateway.provider.result.return_value = []

    with pytest.raises(DoesNotExist):
        await sql_gateway.update({"id": 2})

    assert len(sql_gateway.provider.queries) == 1
@mock.patch.object(SQLGateway, "get")
async def test_update_if_unmodified_since_does_not_exist(get_m, sql_gateway):
    """A conditional UPDATE matching nothing, for a row get() cannot find,
    raises DoesNotExist."""
    get_m.return_value = None
    sql_gateway.provider.result.return_value = []
    cutoff = datetime(2010, 1, 1, tzinfo=timezone.utc)

    with pytest.raises(DoesNotExist):
        await sql_gateway.update({"id": 2}, if_unmodified_since=cutoff)

    # Only the UPDATE hits the provider; the follow-up get() is mocked.
    assert len(sql_gateway.provider.queries) == 1
    get_m.assert_awaited_once_with(2)
@mock.patch.object(SQLGateway, "get")
async def test_update_if_unmodified_since_conflict(get_m, sql_gateway):
    """A conditional UPDATE matching nothing, for a row that does exist
    according to get(), raises Conflict."""
    get_m.return_value = {"id": 2, "value": "foo"}
    sql_gateway.provider.result.return_value = []
    cutoff = datetime(2010, 1, 1, tzinfo=timezone.utc)

    with pytest.raises(Conflict):
        await sql_gateway.update({"id": 2}, if_unmodified_since=cutoff)

    # Only the UPDATE hits the provider; the follow-up get() is mocked.
    assert len(sql_gateway.provider.queries) == 1
    get_m.assert_awaited_once_with(2)
async def test_remove(sql_gateway):
    """remove(id) returns True when the DELETE ... RETURNING yields a row."""
    sql_gateway.provider.result.return_value = [{"id": 2}]

    removed = await sql_gateway.remove(2)
    assert removed is True

    recorded = sql_gateway.provider.queries
    assert len(recorded) == 1
    assert_query_equal(
        recorded[0][0],
        "DELETE FROM writer WHERE writer.id = 2 RETURNING writer.id",
    )
async def test_remove_does_not_exist(sql_gateway):
    """remove(id) returns False when the DELETE ... RETURNING yields nothing."""
    sql_gateway.provider.result.return_value = []

    removed = await sql_gateway.remove(2)
    assert removed is False

    recorded = sql_gateway.provider.queries
    assert len(recorded) == 1
    assert_query_equal(
        recorded[0][0],
        "DELETE FROM writer WHERE writer.id = 2 RETURNING writer.id",
    )
async def test_upsert(sql_gateway):
    """upsert() with an id issues INSERT ... ON CONFLICT (id) DO UPDATE."""
    record = {"id": 2, "value": "foo"}
    sql_gateway.provider.result.return_value = [record]
    assert await sql_gateway.upsert(record) == record
    assert len(sql_gateway.provider.queries) == 1
    assert_query_equal(
        sql_gateway.provider.queries[0][0],
        (
            # Only the last fragment interpolates; the rest are plain literals
            # (an f-prefix without placeholders is ruff F541).
            "INSERT INTO writer (id, value) VALUES (2, 'foo') "
            "ON CONFLICT (id) DO UPDATE SET "
            "id = %(param_1)s, value = %(param_2)s "
            f"RETURNING {ALL_FIELDS}"
        ),
    )
@mock.patch.object(SQLGateway, "add")
async def test_upsert_no_id(add_m, sql_gateway):
    """upsert() without an id falls back to add() and returns its result."""
    add_m.return_value = {"id": 5, "value": "foo"}

    upserted = await sql_gateway.upsert({"value": "foo"})
    assert upserted == add_m.return_value

    add_m.assert_awaited_once_with({"value": "foo"})
    # add() is mocked, so nothing reaches the fake provider.
    assert len(sql_gateway.provider.queries) == 0
async def test_get_related_one_to_many(related_sql_gateway: SQLGateway):
    """_get_related_one_to_many fetches all related books in one IN query and
    attaches them under ``field_name``; writers without books get ``[]``."""
    writers = [{"id": 2}, {"id": 3}]
    books = [
        {"id": 3, "title": "x", "writer_id": 2},
        {"id": 4, "title": "y", "writer_id": 2},
    ]
    related_sql_gateway.provider.result.return_value = books
    await related_sql_gateway._get_related_one_to_many(
        items=writers,
        field_name="books",
        fk_name="writer_id",
    )
    # The items are mutated in place.
    assert writers == [{"id": 2, "books": books}, {"id": 3, "books": []}]
    assert len(related_sql_gateway.provider.queries) == 1
    assert_query_equal(
        related_sql_gateway.provider.queries[0][0],
        # Reuse the shared column list, as test_set_related_one_to_many does.
        f"SELECT {BOOK_FIELDS} FROM book WHERE book.writer_id IN (2, 3)",
    )
@pytest.mark.parametrize(
    "books,current_books,expected_queries,query_results",
    [
        # no change
        (
            [{"id": 3, "title": "x", "writer_id": 2}],
            [{"id": 3, "title": "x", "writer_id": 2}],
            [],
            [],
        ),
        # added a book (without an id)
        (
            [{"title": "x", "writer_id": 2}],
            [],
            [
                f"INSERT INTO book (title, writer_id) VALUES ('x', 2) RETURNING {BOOK_FIELDS}"
            ],
            [[{"id": 3, "title": "x", "writer_id": 2}]],
        ),
        # added a book (with an id)
        (
            [{"id": 3, "title": "x", "writer_id": 2}],
            [],
            [
                f"INSERT INTO book (id, title, writer_id) VALUES (3, 'x', 2) RETURNING {BOOK_FIELDS}"
            ],
            [[{"id": 3, "title": "x", "writer_id": 2}]],
        ),
        # updated a book
        (
            [{"id": 3, "title": "x", "writer_id": 2}],
            [{"id": 3, "title": "a", "writer_id": 2}],
            [
                f"UPDATE book SET id=3, title='x', writer_id=2 WHERE book.id = 3 RETURNING {BOOK_FIELDS}"
            ],
            [[{"id": 3, "title": "x", "writer_id": 2}]],
        ),
        # replaced a book with a new one
        (
            [{"title": "x", "writer_id": 2}],
            [{"id": 15, "title": "a", "writer_id": 2}],
            [
                f"INSERT INTO book (title, writer_id) VALUES ('x', 2) RETURNING {BOOK_FIELDS}",
                "DELETE FROM book WHERE book.id = 15 RETURNING book.id",
            ],
            [[{"id": 3, "title": "x", "writer_id": 2}], [{"id": 15}]],
        ),
    ],
)
async def test_set_related_one_to_many(
    related_sql_gateway: SQLGateway,
    books,
    current_books,
    expected_queries,
    query_results,
):
    """_set_related_one_to_many first SELECTs the current related rows, then
    issues the INSERT/UPDATE/DELETE statements needed to reach ``books``.

    Args:
        books: desired related rows on the parent record.
        current_books: rows the initial SELECT returns (current DB state).
        expected_queries: statements expected after the initial SELECT.
        query_results: canned results for those statements, in order.
    """
    # Named writer_record (not ``writer``) so it does not shadow the
    # module-level ``writer`` Table.
    writer_record = {"id": 2, "books": books}
    # The first provider call is the SELECT of current rows; the rest answer
    # the mutation statements in order.
    related_sql_gateway.provider.result.side_effect = [current_books] + query_results
    result = writer_record.copy()
    await related_sql_gateway._set_related_one_to_many(
        item=writer_record,
        result=result,
        field_name="books",
        fk_name="writer_id",
    )
    assert result == {
        "id": 2,
        "books": [{"id": 3, "title": "x", "writer_id": 2}],
    }
    assert len(related_sql_gateway.provider.queries) == len(expected_queries) + 1
    assert_query_equal(
        related_sql_gateway.provider.queries[0][0],
        f"SELECT {BOOK_FIELDS} FROM book WHERE book.writer_id = 2",
    )
    # Lengths were checked above, so zip cannot silently drop items.
    for (actual_query,), expected_query in zip(
        related_sql_gateway.provider.queries[1:], expected_queries
    ):
        assert_query_equal(actual_query, expected_query)
async def test_update_transactional(sql_gateway):
    """update_transactional() SELECTs the row FOR UPDATE, applies the callback
    and UPDATEs with its result; both statements are recorded as one batch."""
    current_row = {"id": 2, "value": "foo"}
    updated_row = {"id": 2, "value": "bar"}
    sql_gateway.provider.result.side_effect = ([current_row], [updated_row])

    def set_value_to_bar(row):
        return {"id": row["id"], "value": "bar"}

    outcome = await sql_gateway.update_transactional(2, set_value_to_bar)
    assert outcome == updated_row

    # A single provider entry holding both statements.
    (batch,) = sql_gateway.provider.queries
    assert len(batch) == 2
    select_query = batch[0]
    update_query = batch[1]
    assert_query_equal(
        select_query,
        f"SELECT {ALL_FIELDS} FROM writer WHERE writer.id = 2 FOR UPDATE",
    )
    assert_query_equal(
        update_query,
        f"UPDATE writer SET id=2, value='bar' WHERE writer.id = 2 RETURNING {ALL_FIELDS}",
    )
@pytest.mark.parametrize(
    ("filters", "sql"),
    [
        ([], ""),
        ([Filter(field="value", values=[])], " WHERE false"),
        ([Filter(field="value", values=["foo"])], " WHERE writer.value = 'foo'"),
        (
            [Filter(field="value", values=["foo", "bar"])],
            " WHERE writer.value IN ('foo', 'bar')",
        ),
        ([Filter(field="nonexisting", values=["foo"])], " WHERE false"),
        (
            [Filter(field="id", values=[1]), Filter(field="value", values=["foo"])],
            " WHERE writer.id = 1 AND writer.value = 'foo'",
        ),
    ],
)
async def test_exists(sql_gateway, filters, sql):
    """exists() issues ``SELECT true ... LIMIT 1`` with filter()'s WHERE
    rendering and returns True when a row comes back."""
    sql_gateway.provider.result.return_value = [{"exists": True}]

    found = await sql_gateway.exists(filters)
    assert found is True

    recorded = sql_gateway.provider.queries
    assert len(recorded) == 1
    assert_query_equal(
        recorded[0][0], f"SELECT true AS exists FROM writer{sql} LIMIT 1"
    )
|