فهرست منبع

Port SQL integration tests (#2)

Casper van der Wel 1 سال پیش
والد
کامیت
f4897247e4

+ 22 - 3
.github/workflows/test.yml

@@ -8,8 +8,8 @@ on:
   pull_request:
 
 jobs:
-  TestLinux:
-    name: Linux, Python ${{ matrix.python }}
+  test:
+    name: Test, Linux, Python ${{ matrix.python }}
     runs-on: ubuntu-latest
     strategy:
       fail-fast: false
@@ -28,6 +28,14 @@ jobs:
           - python: "3.11"
             pins: ""
 
+    services:
+      postgres:
+        image: postgres:14-alpine
+        env:
+          POSTGRES_PASSWORD: postgres
+        ports:
+          - 5432:5432
+
     steps:
       - uses: actions/checkout@v3
 
@@ -43,4 +51,15 @@ jobs:
           pip list
 
       - name: Run tests
-        run: pytest --cov
+        run: pytest tests --cov
+
+      - name: Wait for postgres
+        run: scripts/wait-for-postgres.sh
+        env:
+          POSTGRES_URL: 'postgres:postgres@localhost:5432'
+        timeout-minutes: 1
+
+      - name: Run integration tests
+        run: pytest integration_tests
+        env:
+          POSTGRES_URL: 'postgres:postgres@localhost:5432'

+ 13 - 0
clean_python/base/domain/filter.py

@@ -0,0 +1,13 @@
+# (c) Nelen & Schuurmans
+
+from typing import Any
+from typing import List
+
+from .value_object import ValueObject
+
+__all__ = ["Filter"]
+
+
class Filter(ValueObject):
    """A field/values pair for Gateway.filter: a record matches when its
    ``field`` value equals any one of ``values``."""

    field: str
    values: List[Any]

+ 56 - 0
clean_python/base/domain/gateway.py

@@ -0,0 +1,56 @@
+# (c) Nelen & Schuurmans
+
+from abc import ABC
+from datetime import datetime
+from typing import Callable
+from typing import List
+from typing import Optional
+
+from .exceptions import DoesNotExist
+from .filter import Filter
+from .json import Json
+from .pagination import PageOptions
+
+__all__ = ["Gateway"]
+
+
class Gateway(ABC):
    """Abstract persistence gateway operating on ``Json`` (dict) records.

    Subclasses implement the primitives ``filter``, ``add``, ``update`` and
    ``remove``; the remaining methods are generic and build on those.
    """

    async def filter(
        self, filters: List[Filter], params: Optional[PageOptions] = None
    ) -> List[Json]:
        # Primitive: all records matching every filter, optionally paginated.
        raise NotImplementedError()

    async def count(self, filters: List[Filter]) -> int:
        # Naive default: fetches every matching record; override for efficiency.
        return len(await self.filter(filters, params=None))

    async def exists(self, filters: List[Filter]) -> bool:
        # Fetch at most one record to decide existence.
        return len(await self.filter(filters, params=PageOptions(limit=1))) > 0

    async def get(self, id: int) -> Optional[Json]:
        # Single-record lookup on the "id" field; None when absent.
        result = await self.filter([Filter(field="id", values=[id])], params=None)
        return result[0] if result else None

    async def add(self, item: Json) -> Json:
        # Primitive: insert a record and return it (with generated fields).
        raise NotImplementedError()

    async def update(
        self, item: Json, if_unmodified_since: Optional[datetime] = None
    ) -> Json:
        # Primitive: update an existing record. Implementations raise
        # DoesNotExist when missing (see upsert below).
        raise NotImplementedError()

    async def update_transactional(self, id: int, func: Callable[[Json], Json]) -> Json:
        """Read-modify-write with optimistic locking on ``updated_at``.

        Raises DoesNotExist when the record is absent. Assumes records carry
        an "updated_at" key (KeyError otherwise).
        """
        existing = await self.get(id)
        if existing is None:
            raise DoesNotExist("record", id)
        return await self.update(
            func(existing), if_unmodified_since=existing["updated_at"]
        )

    async def upsert(self, item: Json) -> Json:
        # Try update first; fall back to add when the record does not exist.
        try:
            return await self.update(item)
        except DoesNotExist:
            return await self.add(item)

    async def remove(self, id: int) -> bool:
        # Primitive: delete by id; True when something was actually removed.
        raise NotImplementedError()

+ 9 - 0
clean_python/base/domain/json.py

@@ -0,0 +1,9 @@
+# (c) Nelen & Schuurmans
+
+from typing import Any
+from typing import Dict
+
__all__ = ["Json"]


# Alias for a JSON-style record: string keys, arbitrary values.
Json = Dict[str, Any]

+ 81 - 0
clean_python/base/infrastructure/in_memory_gateway.py

@@ -0,0 +1,81 @@
+# (c) Nelen & Schuurmans
+
+from copy import deepcopy
+from datetime import datetime
+from typing import List
+from typing import Optional
+
+from clean_python.base.domain import AlreadyExists
+from clean_python.base.domain import Conflict
+from clean_python.base.domain import DoesNotExist
+from clean_python.base.domain import Filter
+from clean_python.base.domain import Gateway
+from clean_python.base.domain import Json
+from clean_python.base.domain import PageOptions
+
+__all__ = ["InMemoryGateway"]
+
+
class InMemoryGateway(Gateway):
    """Dict-backed Gateway implementation, intended for tests."""

    def __init__(self, data: List[Json]):
        # Index records by their "id"; deep-copy to detach from caller state.
        self.data = {record["id"]: deepcopy(record) for record in data}

    def _get_next_id(self) -> int:
        # Emulate SQL autoincrement: one past the highest existing id.
        return max(self.data) + 1 if self.data else 1

    def _paginate(self, objs: List[Json], params: PageOptions) -> List[Json]:
        # Sort placing missing/None order_by values last (when ascending),
        # then slice out the requested page window.
        def sort_key(obj):
            value = obj.get(params.order_by)
            return (value is None, value)

        ordered = sorted(objs, key=sort_key, reverse=not params.ascending)
        return ordered[params.offset : params.offset + params.limit]

    async def filter(
        self, filters: List[Filter], params: Optional[PageOptions] = None
    ) -> List[Json]:
        # A record matches when, for every filter, its field value is listed
        # in that filter's values (so an empty values list matches nothing).
        matched = [
            deepcopy(record)
            for record in self.data.values()
            if all(record.get(spec.field) in spec.values for spec in filters)
        ]
        if params is None:
            return matched
        return self._paginate(matched, params)

    async def add(self, item: Json) -> Json:
        item = item.copy()
        id_ = item.pop("id", None)
        # autoincrement (like SQL does)
        if id_ is None:
            id_ = self._get_next_id()
        elif id_ in self.data:
            raise AlreadyExists(id_)

        self.data[id_] = {"id": id_, **item}
        return deepcopy(self.data[id_])

    async def update(
        self, item: Json, if_unmodified_since: Optional[datetime] = None
    ) -> Json:
        record_id = item.get("id")
        if record_id is None or record_id not in self.data:
            raise DoesNotExist("item", record_id)
        current = self.data[record_id]
        # Optimistic locking: reject when the stored timestamp has moved on.
        if if_unmodified_since and current.get("updated_at") != if_unmodified_since:
            raise Conflict()
        current.update(item)
        return deepcopy(current)

    async def remove(self, id: int) -> bool:
        # Unknown ids yield False instead of raising.
        if id not in self.data:
            return False
        del self.data[id]
        return True

+ 19 - 0
clean_python/sql/sql_provider.py

@@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
 from typing import AsyncIterator
 from typing import List
 
+from sqlalchemy import text
 from sqlalchemy.exc import DBAPIError
 from sqlalchemy.ext.asyncio import AsyncConnection
 from sqlalchemy.ext.asyncio import AsyncEngine
@@ -54,6 +55,24 @@ class SQLDatabase(SQLProvider):
             async with connection.begin():
                 yield SQLTransaction(connection)
 
    @asynccontextmanager
    async def testing_transaction(self) -> AsyncIterator[SQLProvider]:
        """Yield a transaction that is ALWAYS rolled back (for tests).

        Statements executed through the yielded provider never become
        visible outside this context, even when no exception occurs.
        """
        async with self.engine.connect() as connection:
            async with connection.begin() as transaction:
                yield SQLTransaction(connection)
                await transaction.rollback()
+
+    async def _execute_autocommit(self, query: Executable) -> None:
+        engine = create_async_engine(self.engine.url, isolation_level="AUTOCOMMIT")
+        async with engine.connect() as connection:
+            await connection.execute(query)
+
+    async def create_database(self, name: str) -> None:
+        await self._execute_autocommit(text(f"CREATE DATABASE {name}"))
+
+    async def drop_database(self, name: str) -> None:
+        await self._execute_autocommit(text(f"DROP DATABASE IF EXISTS {name}"))
+
 
 class SQLTransaction(SQLProvider):
     def __init__(self, connection: AsyncConnection):

+ 57 - 0
clean_python/sql/testing.py

@@ -0,0 +1,57 @@
+from contextlib import asynccontextmanager
+from typing import AsyncIterator
+from typing import List
+from unittest import mock
+
+from sqlalchemy.dialects import postgresql
+from sqlalchemy.sql import Executable
+
+from clean_python import Json
+from clean_python.sql import SQLProvider
+
+__all__ = ["FakeSQLDatabase", "assert_query_equal"]
+
+
class FakeSQLDatabase(SQLProvider):
    """SQLProvider test double that records queries instead of executing them."""

    def __init__(self):
        # One inner list per execute()/transaction() call, in call order.
        self.queries: List[List[Executable]] = []
        # Configure .result.return_value to control what execute() returns.
        self.result = mock.Mock(return_value=[])

    async def execute(self, query: Executable) -> List[Json]:
        """Record the query as a one-element batch; return the canned result."""
        self.queries.append([query])
        return self.result()

    @asynccontextmanager
    async def transaction(self) -> AsyncIterator["SQLProvider"]:
        """Yield a fake transaction whose query list is appended to ours."""
        fake_transaction = FakeSQLTransaction(result=self.result)
        self.queries.append(fake_transaction.queries)
        yield fake_transaction
+
+
class FakeSQLTransaction(SQLProvider):
    """Companion to FakeSQLDatabase: records queries within one transaction."""

    def __init__(self, result: mock.Mock):
        # This list is shared with (and logged by) the owning FakeSQLDatabase.
        self.queries: List[Executable] = []
        self.result = result

    async def execute(self, query: Executable) -> List[Json]:
        """Record the query; return the shared canned result."""
        self.queries.append(query)
        return self.result()
+
+
def assert_query_equal(q: Executable, expected: str, literal_binds: bool = True):
    """There are two ways of 'binding' parameters (for testing!):

    literal_binds=True: use the built-in sqlalchemy way, which fails on some datatypes (Range)
    literal_binds=False: do it yourself using %, there is no 'mogrify' so don't expect quotes.
    """
    assert isinstance(q, Executable)
    rendered = q.compile(
        compile_kwargs={"literal_binds": literal_binds},
        dialect=postgresql.dialect(),
    )
    actual = str(rendered)
    if not literal_binds:
        # Manual %-interpolation of the bind parameters.
        actual = actual % rendered.params
    # Normalize whitespace before comparing (newlines and double spaces).
    actual = actual.replace("\n", "").replace("  ", " ")
    assert actual == expected

+ 1 - 1
clean_python/testing/__init__.py

@@ -1,2 +1,2 @@
 from .attr_dict import *  # NOQA
-from .profilers import *  # NOQA
+from .debugger import *  # NOQA

+ 19 - 0
clean_python/testing/debugger.py

@@ -0,0 +1,19 @@
+import os
+
+
def setup_debugger(*, host: str = "0.0.0.0", port: int = 5678):
    """Configure debugging via debugpy."""

    # Only to be used in development: debugpy ships via requirements-dev.txt,
    # so enabling DEBUG in staging or production makes the import below raise
    # a ModuleNotFoundError, which is caught and printed.
    wait_for_client = os.environ.get("DEBUG_WAIT_FOR_CLIENT")
    if not (os.environ.get("DEBUG") or wait_for_client):
        return
    try:
        import debugpy

        debugpy.listen((host, port))
        if wait_for_client:
            print("🔌 debugpy waiting for a client to attach 🔌", flush=True)
            debugpy.wait_for_client()
    except (ModuleNotFoundError, RuntimeError) as e:
        print(e, flush=True)

+ 11 - 0
docker-compose.yaml

@@ -0,0 +1,11 @@
+version: "3.8"
+
+services:
+
+  db:
+    image: postgres:14-alpine
+    environment:
+      POSTGRES_PASSWORD: "postgres"
+    # command: ["postgres", "-c", "log_connections=all", "-c", "log_disconnections=all", "-c", "log_statement=all", "-c", "log_destination=stderr"]
+    ports:
+      - "5432:5432"

+ 35 - 0
integration_tests/conftest.py

@@ -0,0 +1,35 @@
+# (c) Nelen & Schuurmans
+
+import asyncio
+import os
+
+import pytest
+
+from clean_python.testing import setup_debugger
+
+
def pytest_sessionstart(session):
    """
    Called after the Session object has been created and
    before performing collection and entering the run test loop.
    """
    # Attaches debugpy when DEBUG/DEBUG_WAIT_FOR_CLIENT is set; no-op otherwise.
    setup_debugger()
+
+
@pytest.fixture(scope="session")
def event_loop(request):
    """Create an instance of the default event loop per test session.

    Async fixtures need the event loop, and so must have the same or narrower scope than
    the event_loop fixture. Since we have async session-scoped fixtures, the default
    event_loop fixture, which has function scope, cannot be used. See:
    https://github.com/pytest-dev/pytest-asyncio#async-fixtures
    """
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    # Close explicitly so the loop does not leak beyond the session.
    loop.close()
+
+
@pytest.fixture(scope="session")
async def postgres_url():
    # "user:password@host:port" -- overridable in CI via the POSTGRES_URL env var.
    return os.environ.get("POSTGRES_URL", "postgres:postgres@localhost:5432")

+ 369 - 0
integration_tests/test_sql_database.py

@@ -0,0 +1,369 @@
+# -*- coding: utf-8 -*-
+# (c) Nelen & Schuurmans
+from datetime import datetime
+from datetime import timezone
+from unittest import mock
+
+import pytest
+from sqlalchemy import Boolean
+from sqlalchemy import Column
+from sqlalchemy import DateTime
+from sqlalchemy import Float
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import Table
+from sqlalchemy import Text
+from sqlalchemy.exc import IntegrityError
+from sqlalchemy.sql import text
+
+from clean_python import AlreadyExists
+from clean_python import Conflict
+from clean_python import DoesNotExist
+from clean_python import Filter
+from clean_python.sql import SQLDatabase
+from clean_python.sql import SQLGateway
+
# Table used by all tests below; (re)created by the ``database`` fixture.
test_model = Table(
    "test_model",
    MetaData(),
    Column("id", Integer, primary_key=True, autoincrement=True),
    Column("t", Text, nullable=False),
    Column("f", Float, nullable=False),
    Column("b", Boolean, nullable=False),
    Column("updated_at", DateTime(timezone=True), nullable=False),
    Column("n", Float, nullable=True),
)


### SQLProvider integration tests
count_query = text("SELECT COUNT(*) FROM test_model")
# Inserts one fixed row; RETURNING id so execute() yields the generated id.
insert_query = text(
    "INSERT INTO test_model (t, f, b, updated_at) "
    "VALUES ('foo', 1.23, TRUE, '2016-06-22 19:10:25-07') "
    "RETURNING id"
)
+
+
@pytest.fixture(scope="session")
async def database(postgres_url):
    """Session-scoped: recreate the test database and the test_model table."""
    dburl = f"postgresql+asyncpg://{postgres_url}"
    dbname = "cleanpython_test"
    # Connect to the server (no database) to (re)create the test database.
    root_provider = SQLDatabase(f"{dburl}/")
    await root_provider.drop_database(dbname)
    await root_provider.create_database(dbname)
    provider = SQLDatabase(f"{dburl}/{dbname}")
    async with provider.engine.begin() as conn:
        await conn.run_sync(test_model.metadata.drop_all)
        await conn.run_sync(test_model.metadata.create_all)
    yield SQLDatabase(f"{dburl}/{dbname}")


@pytest.fixture
async def database_with_cleanup(database):
    # Empty the table both before and after each test.
    await database.execute(text("DELETE FROM test_model WHERE TRUE RETURNING id"))
    yield database
    await database.execute(text("DELETE FROM test_model WHERE TRUE RETURNING id"))


@pytest.fixture
async def transaction_with_cleanup(database_with_cleanup):
    # A real (committing) transaction on a cleaned-up table.
    async with database_with_cleanup.transaction() as trans:
        yield trans
+
+
async def test_execute(database_with_cleanup):
    db = database_with_cleanup
    await db.execute(insert_query)
    assert await db.execute(count_query) == [{"count": 1}]


async def test_transaction_commits(database_with_cleanup):
    db = database_with_cleanup

    async with db.transaction() as trans:
        await trans.execute(insert_query)

    assert await db.execute(count_query) == [{"count": 1}]


async def test_transaction_err(database_with_cleanup):
    # An exception escaping the context manager must roll back the inner insert.
    db = database_with_cleanup
    await db.execute(insert_query)

    with pytest.raises(RuntimeError):
        async with db.transaction() as trans:
            await trans.execute(insert_query)

            raise RuntimeError()  # triggers rollback

    assert await db.execute(count_query) == [{"count": 1}]


async def test_nested_transaction_commits(transaction_with_cleanup):
    # A transaction started inside another transaction also commits.
    db = transaction_with_cleanup

    async with db.transaction() as trans:
        await trans.execute(insert_query)

    assert await db.execute(count_query) == [{"count": 1}]


async def test_nested_transaction_err(transaction_with_cleanup):
    # Rollback of the inner transaction keeps the outer transaction's row.
    db = transaction_with_cleanup
    await db.execute(insert_query)

    with pytest.raises(RuntimeError):
        async with db.transaction() as trans:
            await trans.execute(insert_query)

            raise RuntimeError()  # triggers rollback

    assert await db.execute(count_query) == [{"count": 1}]


async def test_testing_transaction_rollback(database_with_cleanup):
    # testing_transaction rolls back even when no exception occurred.
    async with database_with_cleanup.testing_transaction() as trans:
        await trans.execute(insert_query)

    assert await database_with_cleanup.execute(count_query) == [{"count": 0}]
+
+
+### SQLGateway integration tests
+
+
class TstSQLGateway(SQLGateway, table=test_model):
    # Concrete SQLGateway bound to the test table. NOTE(review): presumably
    # spelled "Tst" so pytest does not collect it as a test class -- confirm.
    pass


@pytest.fixture
async def test_transaction(database):
    # Everything a test does via this provider is rolled back afterwards.
    async with database.testing_transaction() as test_transaction:
        yield test_transaction


@pytest.fixture
def sql_gateway(test_transaction):
    return TstSQLGateway(test_transaction)


@pytest.fixture
def obj():
    # Python-side equivalent of the row inserted by ``obj_in_db``
    # (2016-06-22 19:10:25-07 == 2016-06-23 02:10:25 UTC).
    return {
        "t": "foo",
        "f": 1.23,
        "b": True,
        "updated_at": datetime(2016, 6, 23, 2, 10, 25, tzinfo=timezone.utc),
        "n": None,
    }


@pytest.fixture
async def obj_in_db(test_transaction, obj):
    res = await test_transaction.execute(
        text(
            "INSERT INTO test_model (t, f, b, updated_at) "
            "VALUES ('foo', 1.23, TRUE, '2016-06-22 19:10:25-07') "
            "RETURNING id"
        )
    )
    return {"id": res[0]["id"], **obj}
+
+
async def test_get(sql_gateway, obj_in_db):
    actual = await sql_gateway.get(obj_in_db["id"])

    assert isinstance(actual, dict)
    assert actual == obj_in_db
    # A fresh dict must be returned, not a shared reference.
    assert actual is not obj_in_db


async def test_get_not_found(sql_gateway, obj_in_db):
    assert await sql_gateway.get(obj_in_db["id"] + 1) is None


async def test_add(sql_gateway, test_transaction, obj):
    created = await sql_gateway.add(obj)

    # The gateway assigns the (autoincrement) id.
    id = created.pop("id")
    assert isinstance(id, int)
    assert created is not obj
    assert created == obj

    res = await test_transaction.execute(
        text(f"SELECT * FROM test_model WHERE id = {id}")
    )
    assert res[0]["t"] == obj["t"]


async def test_add_id_exists(sql_gateway, obj_in_db):
    with pytest.raises(AlreadyExists):
        await sql_gateway.add(obj_in_db)


@pytest.mark.parametrize("id", [10, None, "delete"])
async def test_add_integrity_error(sql_gateway, obj, id):
    # "delete" means: leave the "id" key out of the record entirely.
    obj.pop("t")  # will cause the IntegrityError
    if id != "delete":
        obj["id"] = id
    with pytest.raises(IntegrityError):
        await sql_gateway.add(obj)
+
+
async def test_add_unknown_column(sql_gateway, obj):
    # (typo fix: was test_add_unkown_column)
    # Keys without a matching table column are silently dropped on insert.
    created = await sql_gateway.add({"unknown": "foo", **obj})

    created.pop("id")
    assert created == obj
+
+
async def test_update(sql_gateway, test_transaction, obj_in_db):
    obj_in_db["t"] = "bar"

    updated = await sql_gateway.update(obj_in_db)

    assert updated is not obj_in_db
    assert updated == obj_in_db

    res = await test_transaction.execute(
        text(f"SELECT * FROM test_model WHERE id = {obj_in_db['id']}")
    )
    assert res[0]["t"] == "bar"


async def test_update_not_found(sql_gateway, obj):
    # id 42 was never inserted inside this (rolled-back) transaction.
    obj["id"] = 42

    with pytest.raises(DoesNotExist):
        await sql_gateway.update(obj)
+
+
async def test_update_unknown_column(sql_gateway, obj_in_db):
    # (typo fix: was test_update_unkown_column)
    # Keys without a matching table column are ignored by update as well.
    obj_in_db["t"] = "bar"
    updated = await sql_gateway.update({"unknown": "foo", **obj_in_db})

    assert updated == obj_in_db
+
+
async def test_upsert_does_add(sql_gateway, test_transaction, obj):
    # Upsert with an unknown id falls through to add, keeping the given id.
    obj["id"] = 42
    created = await sql_gateway.upsert(obj)

    assert created is not obj
    assert created == obj

    res = await test_transaction.execute(text("SELECT * FROM test_model WHERE id = 42"))
    assert res[0]["t"] == obj["t"]


async def test_upsert_does_update(sql_gateway, test_transaction, obj_in_db):
    obj_in_db["t"] = "bar"
    updated = await sql_gateway.upsert(obj_in_db)

    assert updated is not obj_in_db
    assert updated == obj_in_db

    res = await test_transaction.execute(
        text(f"SELECT * FROM test_model WHERE id = {obj_in_db['id']}")
    )
    assert res[0]["t"] == "bar"


async def test_upsert_no_id(sql_gateway, test_transaction, obj):
    # Without an id, upsert must go straight to add (no update attempt).
    with mock.patch.object(sql_gateway, "add", new_callable=mock.AsyncMock) as add_m:
        created = await sql_gateway.upsert(obj)
        add_m.assert_awaited_with(obj)
        assert created == add_m.return_value


async def test_remove(sql_gateway, test_transaction, obj_in_db):
    assert await sql_gateway.remove(obj_in_db["id"])

    res = await test_transaction.execute(
        text(f"SELECT COUNT(*) FROM test_model WHERE id = {obj_in_db['id']}")
    )
    assert res[0]["count"] == 0


async def test_remove_not_found(sql_gateway):
    assert not await sql_gateway.remove(42)
+
+
async def test_update_if_unmodified_since(sql_gateway, obj_in_db):
    # Matching timestamp: the optimistic-locking update succeeds.
    obj_in_db["t"] = "bar"

    updated = await sql_gateway.update(
        obj_in_db, if_unmodified_since=obj_in_db["updated_at"]
    )

    assert updated == obj_in_db


@pytest.mark.parametrize(
    "if_unmodified_since", [datetime.now(timezone.utc), datetime(2010, 1, 1)]
)
async def test_update_if_unmodified_since_not_ok(
    sql_gateway, obj_in_db, if_unmodified_since
):
    # Timestamp mismatch (later or earlier than stored) must raise Conflict.
    obj_in_db["t"] = "bar"

    with pytest.raises(Conflict):
        await sql_gateway.update(obj_in_db, if_unmodified_since=if_unmodified_since)
+
+
@pytest.mark.parametrize(
    "filters,match",
    [
        ([], True),  # no filters: everything matches
        ([Filter(field="t", values=["foo"])], True),
        ([Filter(field="t", values=["bar"])], False),
        ([Filter(field="t", values=["foo"]), Filter(field="f", values=[1.23])], True),
        ([Filter(field="t", values=["foo"]), Filter(field="f", values=[1.24])], False),
        ([Filter(field="nonexisting", values=["foo"])], False),
        ([Filter(field="t", values=[])], False),  # empty values matches nothing
        ([Filter(field="t", values=["foo", "bar"])], True),  # values are OR-ed
    ],
)
async def test_filter(filters, match, sql_gateway, obj_in_db):
    actual = await sql_gateway.filter(filters)

    assert actual == ([obj_in_db] if match else [])
+
+
@pytest.fixture
async def obj2_in_db(test_transaction, obj):
    """Insert a second, different row and return its actual field values.

    Fix: the original returned ``obj``'s values ('foo', 1.23, 2016-...)
    although the row inserted here contains 'bar' / 1.24 / 2018-... --
    misleading for any test that would compare fields. Current users
    (count/exists) only rely on the row's existence, so they are unaffected.
    """
    res = await test_transaction.execute(
        text(
            "INSERT INTO test_model (t, f, b, updated_at) "
            "VALUES ('bar', 1.24, TRUE, '2018-06-22 19:10:25-07') "
            "RETURNING id"
        )
    )
    return {
        "id": res[0]["id"],
        **obj,
        "t": "bar",
        "f": 1.24,
        "updated_at": datetime(2018, 6, 23, 2, 10, 25, tzinfo=timezone.utc),
    }
+
+
@pytest.mark.parametrize(
    "filters,expected",
    [
        ([], 2),
        ([Filter(field="t", values=["foo"])], 1),
        ([Filter(field="t", values=["bar"])], 1),
        ([Filter(field="t", values=["baz"])], 0),
    ],
)
async def test_count(filters, expected, sql_gateway, obj_in_db, obj2_in_db):
    # Two rows present: t='foo' (obj_in_db) and t='bar' (obj2_in_db).
    actual = await sql_gateway.count(filters)
    assert actual == expected


@pytest.mark.parametrize(
    "filters,expected",
    [
        ([], True),
        ([Filter(field="t", values=["foo"])], True),
        ([Filter(field="t", values=["bar"])], True),
        ([Filter(field="t", values=["baz"])], False),
    ],
)
async def test_exists(filters, expected, sql_gateway, obj_in_db, obj2_in_db):
    actual = await sql_gateway.exists(filters)
    assert actual == expected

+ 3 - 3
pyproject.toml

@@ -17,14 +17,15 @@ dynamic = ["version"]
 test = [
     "pytest",
     "pytest-cov",
-    "pytest-asyncio"
+    "pytest-asyncio",
+    "debugpy",
 ]
 dramatiq = ["dramatiq"]
 fastapi = ["fastapi"]
 auth = ["pyjwt[crypto]==2.6.0"]
 celery = ["pika"]
 fluentbit = ["fluent-logger"]
-sql = ["sqlalchemy==2.*"]
+sql = ["sqlalchemy==2.*", "asyncpg"]
 
 [project.urls]
 homepage = "https://github.com/nens/clean-python"
@@ -47,4 +48,3 @@ force_single_line = true
 [tool.pytest.ini_options]
 norecursedirs=".venv data doc etc *.egg-info misc var build lib include"
 python_files = "test_*.py"
-testpaths = "clean_python"

+ 0 - 1
pytest.ini

@@ -3,7 +3,6 @@
 
 [pytest]
 asyncio_mode = auto
-testpaths = tests
 ; Apparently, gevent is enabling "flush to zero":
 ; https://github.com/numpy/numpy/issues/20895
 ; Suppress the corresponding warnings:

+ 2 - 0
scripts/init-db.sql

@@ -0,0 +1,2 @@
-- Create the database used by the integration tests.
\set test_db `echo cleanpython_test`
CREATE DATABASE :test_db;

+ 18 - 0
scripts/wait-for-postgres.sh

@@ -0,0 +1,18 @@
#!/bin/sh
# wait-for-postgres.sh
# Block until the PostgreSQL at $POSTGRES_URL accepts connections, then exec
# the given command (if any).

# https://docs.docker.com/compose/startup-order/
# https://gist.github.com/mihow/9c7f559807069a03e302605691f85572?permalink_comment_id=3709779#gistcomment-3709779
# https://www.postgresql.org/docs/current/libpq-envars.html

set -e

# \q quits psql immediately; its exit status tells us whether connecting worked.
until psql "postgres://$POSTGRES_URL/" -c '\q'; do
  >&2 echo "Postgres is unavailable - sleeping"
  sleep 1
done

echo "$@"

>&2 echo "Postgres is up - executing command"
exec "$@"

+ 0 - 0
tests/__init__.py