# (c) Nelen & Schuurmans

from contextlib import asynccontextmanager
from datetime import datetime
from typing import AsyncIterator
from typing import Callable
from typing import List
from typing import Optional
from typing import TypeVar

import inject
from sqlalchemy import and_
from sqlalchemy import asc
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import Table
from sqlalchemy import true
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import IntegrityError
from sqlalchemy.sql import Executable
from sqlalchemy.sql.expression import ColumnElement
from sqlalchemy.sql.expression import false

from clean_python import AlreadyExists
from clean_python import Conflict
from clean_python import ctx
from clean_python import DoesNotExist
from clean_python import Filter
from clean_python import Gateway
from clean_python import Id
from clean_python import Json
from clean_python import PageOptions

from .sql_provider import SQLDatabase
from .sql_provider import SQLProvider

__all__ = ["SQLGateway"]


def _is_unique_violation_error_id(e: IntegrityError, id: Id) -> bool:
    """Return True if ``e`` is a unique-constraint violation on the given ``id``.

    The check inspects the database error message, which contains e.g.
    ``Key (id)=(5) already exists.``

    Args:
        e: The IntegrityError raised by SQLAlchemy.
        id: The id value that was being inserted.
    """
    # sqlalchemy wraps the asyncpg error; the driver message is its first arg.
    # If there is no usable wrapped error, report "not an id violation" so the
    # caller re-raises the original IntegrityError instead of an
    # AttributeError / IndexError that would hide it.
    args = getattr(e.orig, "args", None)
    if not args:
        return False
    msg = str(args[0])
    return ("duplicate key value violates unique constraint" in msg) and (
        f"Key (id)=({id}) already exists." in msg
    )


T = TypeVar("T", bound="SQLGateway")


class SQLGateway(Gateway):
    """Gateway that stores records as rows of a single SQLAlchemy ``Table``.

    Subclasses bind their table through class keyword arguments::

        class BookSQLGateway(SQLGateway, table=book_table):
            ...

    With ``multitenant=True`` the table must have a ``tenant`` column. All
    queries are then restricted to the id of ``ctx.tenant``, and all writes
    are stamped with it.

    Every operation runs inside ``self.transaction()``. That yields a copy of
    the gateway bound to one provider transaction (``nested=True``), and
    operations invoked on that copy join the same transaction.
    """

    # The table this gateway reads from and writes to (set per subclass).
    table: Table
    # True when this instance is already bound to an open transaction.
    nested: bool
    # Whether rows are scoped by a "tenant" column (set per subclass).
    multitenant: bool

    def __init__(
        self,
        provider_override: Optional[SQLProvider] = None,
        nested: bool = False,
    ):
        """
        Args:
            provider_override: Provider to use instead of the injected
                ``SQLDatabase``; ``transaction()`` passes the
                transaction-bound provider here.
            nested: Whether this gateway is already bound to a transaction.
        """
        self.provider_override = provider_override
        self.nested = nested

    @property
    def provider(self):
        """The override given to ``__init__``, else the injected ``SQLDatabase``."""
        return self.provider_override or inject.instance(SQLDatabase)

    def __init_subclass__(cls, table: Table, multitenant: bool = False) -> None:
        """Bind ``table`` (and the multitenancy flag) to a subclass.

        Raises:
            ValueError: ``multitenant`` is set but ``table`` has no ``tenant``
                column.
        """
        cls.table = table
        if multitenant and not hasattr(table.c, "tenant"):
            raise ValueError("Can't use a multitenant SQLGateway without tenant column")
        cls.multitenant = multitenant
        super().__init_subclass__()

    def rows_to_dict(self, rows: List[Json]) -> List[Json]:
        """Hook to convert rows returned by the provider; identity by default."""
        return rows

    def dict_to_row(self, obj: Json) -> Json:
        """Convert a record dict into column values for INSERT / UPDATE.

        The conversion does three things:

        - Keys that are not columns of ``self.table`` are dropped.
        - An ``id`` of None is dropped, presumably so that a database-side
          default can assign one.
        - For multitenant gateways, ``tenant`` is set to the current tenant.
        """
        known = {c.key for c in self.table.c}
        result = {k: obj[k] for k in obj.keys() if k in known}
        if "id" in result and result["id"] is None:
            del result["id"]
        if self.multitenant:
            result["tenant"] = self.current_tenant
        return result

    @asynccontextmanager
    async def transaction(self: T) -> AsyncIterator[T]:
        """Yield a gateway bound to a single transaction.

        An already-nested gateway yields itself, so calls compose into the
        outer transaction. Otherwise a new provider transaction is opened, and
        a new instance of this class bound to it is yielded.
        """
        if self.nested:
            yield self
        else:
            async with self.provider.transaction() as provider:
                yield self.__class__(provider, nested=True)

    @property
    def current_tenant(self) -> Optional[int]:
        """The id of ``ctx.tenant`` for multitenant gateways, else None.

        Raises:
            RuntimeError: The gateway is multitenant but no tenant is set.
        """
        if not self.multitenant:
            return None
        if ctx.tenant is None:
            raise RuntimeError(f"{self.__class__} requires a tenant in the context")
        return ctx.tenant.id

    async def get_related(self, items: List[Json]) -> None:
        """Hook to attach related objects to ``items`` in place; no-op here."""
        pass

    async def set_related(self, item: Json, result: Json) -> None:
        """Hook to persist related objects of ``item`` into ``result``; no-op here."""
        pass

    async def execute(self, query: Executable) -> List[Json]:
        """Execute ``query`` and return its rows, converted by ``rows_to_dict``.

        Must be called on a transaction-bound gateway (see ``transaction()``).
        """
        assert self.nested
        return self.rows_to_dict(await self.provider.execute(query))

    async def add(self, item: Json) -> Json:
        """Insert ``item`` and return the inserted row.

        Raises:
            AlreadyExists: ``item`` has an id that is already taken.
        """
        query = (
            insert(self.table).values(**self.dict_to_row(item)).returning(self.table)
        )
        async with self.transaction() as transaction:
            try:
                (result,) = await transaction.execute(query)
            except IntegrityError as e:
                id_ = item.get("id")
                if id_ is not None and _is_unique_violation_error_id(e, id_):
                    raise AlreadyExists(id_)
                raise
            await transaction.set_related(item, result)
        return result

    async def update(
        self, item: Json, if_unmodified_since: Optional[datetime] = None
    ) -> Json:
        """Update the record identified by ``item["id"]`` and return the new row.

        Args:
            item: The new values; must contain ``"id"``.
            if_unmodified_since: If given, only update when the stored
                ``updated_at`` equals this value (optimistic concurrency).

        Raises:
            DoesNotExist: ``item`` has no id, or no such record exists.
            Conflict: The record exists but its ``updated_at`` did not match.
        """
        id_ = item.get("id")
        if id_ is None:
            raise DoesNotExist("record", id_)
        q = self._id_filter_to_sql(id_)
        if if_unmodified_since is not None:
            q &= self.table.c.updated_at == if_unmodified_since
        query = (
            update(self.table)
            .where(q)
            .values(**self.dict_to_row(item))
            .returning(self.table)
        )
        async with self.transaction() as transaction:
            result = await transaction.execute(query)
            if not result:
                if if_unmodified_since is not None:
                    # Tell "record missing" apart from "record modified since".
                    # The lookup runs in the same transaction as the UPDATE.
                    # A missing record surfaces as DoesNotExist, either from
                    # get() or from the raise below.
                    if await transaction.get(id_):
                        raise Conflict()
                raise DoesNotExist("record", id_)
            await transaction.set_related(item, result[0])
        return result[0]

    async def _select_for_update(self, id: Id) -> Json:
        """Fetch a record with ``SELECT ... FOR UPDATE`` (related objects included).

        The row lock lasts only as long as the transaction. Call this on a
        transaction-bound gateway (as ``update_transactional`` does) for the
        lock to be useful.

        Raises:
            DoesNotExist: No record with this id.
        """
        async with self.transaction() as transaction:
            result = await transaction.execute(
                select(self.table).with_for_update().where(self._id_filter_to_sql(id)),
            )
            if not result:
                raise DoesNotExist("record", id)
            await transaction.get_related(result)
        return result[0]

    async def update_transactional(self, id: Id, func: Callable[[Json], Json]) -> Json:
        """Read-modify-write a record atomically.

        The record is locked, ``func`` is applied to it, and the returned dict
        is written back. All three steps happen in a single transaction.
        """
        async with self.transaction() as transaction:
            existing = await transaction._select_for_update(id)
            updated = func(existing)
            return await transaction.update(updated)

    async def upsert(self, item: Json) -> Json:
        """Insert ``item``, or update it if a row with its id already exists.

        Without an id this is the same as ``add()``. The conflict target is
        ``id`` or, for multitenant gateways, ``(id, tenant)``.
        """
        if item.get("id") is None:
            return await self.add(item)
        values = self.dict_to_row(item)
        query = (
            insert(self.table)
            .values(**values)
            .on_conflict_do_update(
                index_elements=["id", "tenant"] if self.multitenant else ["id"],
                set_=values,
            )
            .returning(self.table)
        )
        async with self.transaction() as transaction:
            result = await transaction.execute(query)
            await transaction.set_related(item, result[0])
        return result[0]

    async def remove(self, id: Id) -> bool:
        """Delete the record with ``id``; return True if a row was deleted."""
        query = (
            delete(self.table)
            .where(self._id_filter_to_sql(id))
            .returning(self.table.c.id)
        )
        async with self.transaction() as transaction:
            result = await transaction.execute(query)
        return bool(result)

    def _filter_to_sql(self, filter: Filter) -> ColumnElement:
        """Translate one ``Filter`` into a SQL condition.

        The condition depends on the filter:

        - Unknown field, or no values: ``false()``.
        - One value: ``column == value``.
        - Several values: ``column IN (...)``.
        """
        try:
            column = getattr(self.table.c, filter.field)
        except AttributeError:
            return false()
        if len(filter.values) == 0:
            return false()
        elif len(filter.values) == 1:
            return column == filter.values[0]
        else:
            return column.in_(filter.values)

    def _filters_to_sql(self, filters: List[Filter]) -> ColumnElement:
        """AND together all ``filters``, plus the tenant restriction if multitenant."""
        qs = [self._filter_to_sql(x) for x in filters]
        if self.multitenant:
            qs.append(self.table.c.tenant == self.current_tenant)
        return and_(*qs)

    def _id_filter_to_sql(self, id: Id) -> ColumnElement:
        """SQL condition selecting the record with ``id`` (tenant-scoped if needed)."""
        return self._filters_to_sql([Filter(field="id", values=[id])])

    async def filter(
        self, filters: List[Filter], params: Optional[PageOptions] = None
    ) -> List[Json]:
        """Return all records matching every filter, with related objects attached.

        Args:
            filters: Conditions that must all hold.
            params: Optional ordering (``order_by`` / ``ascending``) and
                pagination (``limit`` / ``offset``).
        """
        query = select(self.table).where(self._filters_to_sql(filters))
        if params is not None:
            sort = asc(params.order_by) if params.ascending else desc(params.order_by)
            query = query.order_by(sort).limit(params.limit).offset(params.offset)
        async with self.transaction() as transaction:
            result = await transaction.execute(query)
            await transaction.get_related(result)
        return result

    async def count(self, filters: List[Filter]) -> int:
        """Return the number of records matching every filter."""
        query = (
            select(func.count().label("count"))
            .select_from(self.table)
            .where(self._filters_to_sql(filters))
        )
        async with self.transaction() as transaction:
            return (await transaction.execute(query))[0]["count"]

    async def exists(self, filters: List[Filter]) -> bool:
        """Return whether at least one record matches every filter."""
        query = (
            select(true().label("exists"))
            .select_from(self.table)
            .where(self._filters_to_sql(filters))
            .limit(1)
        )
        async with self.transaction() as transaction:
            return len(await transaction.execute(query)) > 0

    async def _get_related_one_to_many(
        self,
        items: List[Json],
        field_name: str,
        fk_name: str,
    ) -> None:
        """Fetch related objects for ``items`` and add them in place.

        Call this on the gateway of the *related* objects. Afterwards every
        item has an extra field holding the list of its related objects. All
        of them are retrieved from ``self`` in a single SELECT query.

        Args:
            items: The items for which to fetch related objects. Changed inplace.
            field_name: The key in each item to put the fetched related objects into.
            fk_name: The column name on the related object that refers to item["id"]

        Example:
            Writer has a one-to-many relation to books; call on the book gateway:

            >>> writers = [{"id": 2, "name": "John Doe"}]
            >>> await book_gateway._get_related_one_to_many(
            ...     items=writers,
            ...     field_name="books",
            ...     fk_name="writer_id",
            ... )
            >>> writers[0]
            {
                "id": 2,
                "name": "John Doe",
                "books": [
                    {
                        "id": 1,
                        "title": "How to write an ORM",
                        "writer_id": 2
                    }
                ]
            }
        """
        # Related-object helpers do not support tenant scoping.
        assert not self.multitenant
        if not items:
            # Nothing to attach to; skip the (always empty) query.
            return
        for x in items:
            x[field_name] = []
        item_lut = {x["id"]: x for x in items}
        related_objs = await self.filter(
            [Filter(field=fk_name, values=list(item_lut.keys()))]
        )
        for related_obj in related_objs:
            item_lut[related_obj[fk_name]][field_name].append(related_obj)

    async def _set_related_one_to_many(
        self,
        item: Json,
        result: Json,
        field_name: str,
        fk_name: str,
    ) -> None:
        """Make the related objects of ``item`` match ``item[field_name]``.

        Call this on the gateway of the *related* objects. It first fetches
        the currently stored related objects, then reconciles them with the
        list in ``item[field_name]``:

        - Entries without a matching stored id are added.
        - Entries that differ from the stored version are updated.
        - Unchanged entries are kept as-is.
        - Stored objects absent from the list are removed.

        Args:
            item: The item for which to set related objects.
            result: The dictionary to put the resulting (added / updated) objects into.
                Its ``"id"`` is used as the foreign key value.
            field_name: The key in item / result holding the related objects.
            fk_name: The column name on the related object that refers to item["id"]

        Example:
            Writer has a one-to-many relation to books; call on the book gateway:

            >>> writer = {"id": 2, "name": "John Doe", "books": [{"title": "Foo"}]}
            >>> result = {"id": 2, "name": "John Doe"}
            >>> await book_gateway._set_related_one_to_many(
            ...     item=writer,
            ...     result=result,
            ...     field_name="books",
            ...     fk_name="writer_id",
            ... )
            >>> result
            {
                "id": 2,
                "name": "John Doe",
                "books": [
                    {
                        "id": 1,
                        "title": "Foo",
                        "writer_id": 2
                    }
                ]
            }
        """
        # Related-object helpers do not support tenant scoping.
        assert not self.multitenant
        # list existing related objects
        existing_lut = {
            x["id"]: x
            for x in await self.filter([Filter(field=fk_name, values=[result["id"]])])
        }

        # add / update them where necessary
        returned = []
        for new_value in item.get(field_name, []):
            new_value = {fk_name: result["id"], **new_value}
            existing = existing_lut.pop(new_value.get("id"), None)
            if existing is None:
                returned.append(await self.add(new_value))
            elif new_value == existing:
                returned.append(existing)
            else:
                returned.append(await self.update(new_value))
        result[field_name] = returned

        # remove remaining (stored related objects absent from the new list)
        for to_remove in existing_lut:
            # Keep the side effect out of the assert: under `python -O` assert
            # statements are stripped entirely and the delete would be skipped.
            removed = await self.remove(to_remove)
            assert removed