|
@@ -8,6 +8,7 @@ from typing import Optional
|
|
from typing import TypeVar
|
|
from typing import TypeVar
|
|
|
|
|
|
import inject
|
|
import inject
|
|
|
|
+from sqlalchemy import and_
|
|
from sqlalchemy import asc
|
|
from sqlalchemy import asc
|
|
from sqlalchemy import delete
|
|
from sqlalchemy import delete
|
|
from sqlalchemy import desc
|
|
from sqlalchemy import desc
|
|
@@ -24,6 +25,7 @@ from sqlalchemy.sql.expression import false
|
|
|
|
|
|
from clean_python import AlreadyExists
|
|
from clean_python import AlreadyExists
|
|
from clean_python import Conflict
|
|
from clean_python import Conflict
|
|
|
|
+from clean_python import ctx
|
|
from clean_python import DoesNotExist
|
|
from clean_python import DoesNotExist
|
|
from clean_python import Filter
|
|
from clean_python import Filter
|
|
from clean_python import Gateway
|
|
from clean_python import Gateway
|
|
@@ -50,9 +52,12 @@ T = TypeVar("T", bound="SQLGateway")
|
|
class SQLGateway(Gateway):
|
|
class SQLGateway(Gateway):
|
|
table: Table
|
|
table: Table
|
|
nested: bool
|
|
nested: bool
|
|
|
|
+ multitenant: bool
|
|
|
|
|
|
def __init__(
|
|
def __init__(
|
|
- self, provider_override: Optional[SQLProvider] = None, nested: bool = False
|
|
|
|
|
|
+ self,
|
|
|
|
+ provider_override: Optional[SQLProvider] = None,
|
|
|
|
+ nested: bool = False,
|
|
):
|
|
):
|
|
self.provider_override = provider_override
|
|
self.provider_override = provider_override
|
|
self.nested = nested
|
|
self.nested = nested
|
|
@@ -61,8 +66,11 @@ class SQLGateway(Gateway):
|
|
def provider(self):
|
|
def provider(self):
|
|
return self.provider_override or inject.instance(SQLDatabase)
|
|
return self.provider_override or inject.instance(SQLDatabase)
|
|
|
|
|
|
- def __init_subclass__(cls, table: Table) -> None:
|
|
|
|
|
|
+ def __init_subclass__(cls, table: Table, multitenant: bool = False) -> None:
|
|
cls.table = table
|
|
cls.table = table
|
|
|
|
+ if multitenant and not hasattr(table.c, "tenant"):
|
|
|
|
+ raise ValueError("Can't use a multitenant SQLGateway without tenant column")
|
|
|
|
+ cls.multitenant = multitenant
|
|
super().__init_subclass__()
|
|
super().__init_subclass__()
|
|
|
|
|
|
def rows_to_dict(self, rows: List[Json]) -> List[Json]:
|
|
def rows_to_dict(self, rows: List[Json]) -> List[Json]:
|
|
@@ -73,6 +81,8 @@ class SQLGateway(Gateway):
|
|
result = {k: obj[k] for k in obj.keys() if k in known}
|
|
result = {k: obj[k] for k in obj.keys() if k in known}
|
|
if "id" in result and result["id"] is None:
|
|
if "id" in result and result["id"] is None:
|
|
del result["id"]
|
|
del result["id"]
|
|
|
|
+ if self.multitenant:
|
|
|
|
+ result["tenant"] = self.current_tenant
|
|
return result
|
|
return result
|
|
|
|
|
|
@asynccontextmanager
|
|
@asynccontextmanager
|
|
@@ -83,6 +93,14 @@ class SQLGateway(Gateway):
|
|
async with self.provider.transaction() as provider:
|
|
async with self.provider.transaction() as provider:
|
|
yield self.__class__(provider, nested=True)
|
|
yield self.__class__(provider, nested=True)
|
|
|
|
|
|
|
|
+ @property
|
|
|
|
+ def current_tenant(self) -> Optional[int]:
|
|
|
|
+ if not self.multitenant:
|
|
|
|
+ return None
|
|
|
|
+ if ctx.tenant is None:
|
|
|
|
+ raise RuntimeError(f"{self.__class__} requires a tenant in the context")
|
|
|
|
+ return ctx.tenant.id
|
|
|
|
+
|
|
async def get_related(self, items: List[Json]) -> None:
|
|
async def get_related(self, items: List[Json]) -> None:
|
|
pass
|
|
pass
|
|
|
|
|
|
@@ -114,7 +132,7 @@ class SQLGateway(Gateway):
|
|
id_ = item.get("id")
|
|
id_ = item.get("id")
|
|
if id_ is None:
|
|
if id_ is None:
|
|
raise DoesNotExist("record", id_)
|
|
raise DoesNotExist("record", id_)
|
|
- q = self.table.c.id == id_
|
|
|
|
|
|
+ q = self._id_filter_to_sql(id_)
|
|
if if_unmodified_since is not None:
|
|
if if_unmodified_since is not None:
|
|
q &= self.table.c.updated_at == if_unmodified_since
|
|
q &= self.table.c.updated_at == if_unmodified_since
|
|
query = (
|
|
query = (
|
|
@@ -137,7 +155,7 @@ class SQLGateway(Gateway):
|
|
async def _select_for_update(self, id: int) -> Json:
|
|
async def _select_for_update(self, id: int) -> Json:
|
|
async with self.transaction() as transaction:
|
|
async with self.transaction() as transaction:
|
|
result = await transaction.execute(
|
|
result = await transaction.execute(
|
|
- select(self.table).with_for_update().where(self.table.c.id == id),
|
|
|
|
|
|
+ select(self.table).with_for_update().where(self._id_filter_to_sql(id)),
|
|
)
|
|
)
|
|
if not result:
|
|
if not result:
|
|
raise DoesNotExist("record", id)
|
|
raise DoesNotExist("record", id)
|
|
@@ -157,7 +175,10 @@ class SQLGateway(Gateway):
|
|
query = (
|
|
query = (
|
|
insert(self.table)
|
|
insert(self.table)
|
|
.values(**values)
|
|
.values(**values)
|
|
- .on_conflict_do_update(index_elements=["id"], set_=values)
|
|
|
|
|
|
+ .on_conflict_do_update(
|
|
|
|
+ index_elements=["id", "tenant"] if self.multitenant else ["id"],
|
|
|
|
+ set_=values,
|
|
|
|
+ )
|
|
.returning(self.table)
|
|
.returning(self.table)
|
|
)
|
|
)
|
|
async with self.transaction() as transaction:
|
|
async with self.transaction() as transaction:
|
|
@@ -167,13 +188,15 @@ class SQLGateway(Gateway):
|
|
|
|
|
|
async def remove(self, id) -> bool:
|
|
async def remove(self, id) -> bool:
|
|
query = (
|
|
query = (
|
|
- delete(self.table).where(self.table.c.id == id).returning(self.table.c.id)
|
|
|
|
|
|
+ delete(self.table)
|
|
|
|
+ .where(self._id_filter_to_sql(id))
|
|
|
|
+ .returning(self.table.c.id)
|
|
)
|
|
)
|
|
async with self.transaction() as transaction:
|
|
async with self.transaction() as transaction:
|
|
result = await transaction.execute(query)
|
|
result = await transaction.execute(query)
|
|
return bool(result)
|
|
return bool(result)
|
|
|
|
|
|
- def _to_sqlalchemy_expression(self, filter: Filter) -> ColumnElement:
|
|
|
|
|
|
+ def _filter_to_sql(self, filter: Filter) -> ColumnElement:
|
|
try:
|
|
try:
|
|
column = getattr(self.table.c, filter.field)
|
|
column = getattr(self.table.c, filter.field)
|
|
except AttributeError:
|
|
except AttributeError:
|
|
@@ -185,12 +208,19 @@ class SQLGateway(Gateway):
|
|
else:
|
|
else:
|
|
return column.in_(filter.values)
|
|
return column.in_(filter.values)
|
|
|
|
|
|
|
|
+ def _filters_to_sql(self, filters: List[Filter]) -> ColumnElement:
|
|
|
|
+ qs = [self._filter_to_sql(x) for x in filters]
|
|
|
|
+ if self.multitenant:
|
|
|
|
+ qs.append(self.table.c.tenant == self.current_tenant)
|
|
|
|
+ return and_(*qs)
|
|
|
|
+
|
|
|
|
+ def _id_filter_to_sql(self, id: int) -> ColumnElement:
|
|
|
|
+ return self._filters_to_sql([Filter(field="id", values=[id])])
|
|
|
|
+
|
|
async def filter(
|
|
async def filter(
|
|
self, filters: List[Filter], params: Optional[PageOptions] = None
|
|
self, filters: List[Filter], params: Optional[PageOptions] = None
|
|
) -> List[Json]:
|
|
) -> List[Json]:
|
|
- query = select(self.table).where(
|
|
|
|
- *[self._to_sqlalchemy_expression(x) for x in filters]
|
|
|
|
- )
|
|
|
|
|
|
+ query = select(self.table).where(self._filters_to_sql(filters))
|
|
if params is not None:
|
|
if params is not None:
|
|
sort = asc(params.order_by) if params.ascending else desc(params.order_by)
|
|
sort = asc(params.order_by) if params.ascending else desc(params.order_by)
|
|
query = query.order_by(sort).limit(params.limit).offset(params.offset)
|
|
query = query.order_by(sort).limit(params.limit).offset(params.offset)
|
|
@@ -203,7 +233,7 @@ class SQLGateway(Gateway):
|
|
query = (
|
|
query = (
|
|
select(func.count().label("count"))
|
|
select(func.count().label("count"))
|
|
.select_from(self.table)
|
|
.select_from(self.table)
|
|
- .where(*[self._to_sqlalchemy_expression(x) for x in filters])
|
|
|
|
|
|
+ .where(self._filters_to_sql(filters))
|
|
)
|
|
)
|
|
async with self.transaction() as transaction:
|
|
async with self.transaction() as transaction:
|
|
return (await transaction.execute(query))[0]["count"]
|
|
return (await transaction.execute(query))[0]["count"]
|
|
@@ -212,7 +242,7 @@ class SQLGateway(Gateway):
|
|
query = (
|
|
query = (
|
|
select(true().label("exists"))
|
|
select(true().label("exists"))
|
|
.select_from(self.table)
|
|
.select_from(self.table)
|
|
- .where(*[self._to_sqlalchemy_expression(x) for x in filters])
|
|
|
|
|
|
+ .where(self._filters_to_sql(filters))
|
|
.limit(1)
|
|
.limit(1)
|
|
)
|
|
)
|
|
async with self.transaction() as transaction:
|
|
async with self.transaction() as transaction:
|
|
@@ -257,6 +287,7 @@ class SQLGateway(Gateway):
|
|
]
|
|
]
|
|
}
|
|
}
|
|
"""
|
|
"""
|
|
|
|
+ assert not self.multitenant
|
|
for x in items:
|
|
for x in items:
|
|
x[field_name] = []
|
|
x[field_name] = []
|
|
item_lut = {x["id"]: x for x in items}
|
|
item_lut = {x["id"]: x for x in items}
|
|
@@ -308,7 +339,7 @@ class SQLGateway(Gateway):
|
|
]
|
|
]
|
|
}
|
|
}
|
|
"""
|
|
"""
|
|
-
|
|
|
|
|
|
+ assert not self.multitenant
|
|
# list existing related objects
|
|
# list existing related objects
|
|
existing_lut = {
|
|
existing_lut = {
|
|
x["id"]: x
|
|
x["id"]: x
|