Skip to content

Commit

Permalink
Minor refactor of backend code (#702)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dreamsorcerer committed Apr 30, 2023
1 parent 665dfb8 commit fa287a2
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
8 changes: 8 additions & 0 deletions aiohttp_admin/backends/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ async def _get_many(self, request: web.Request) -> web.Response:
query = parse_obj_as(GetManyParams, request.query)

results = await self.get_many(query)
if not results:
raise web.HTTPNotFound()

results = [await self.filter_by_permissions(request, "view", r) for r in results
if await permits(request, f"admin.{self.name}.view", context=(request, r))]
for r in results:
Expand Down Expand Up @@ -255,6 +258,8 @@ async def _update_many(self, request: web.Request) -> web.Response:

# Check original records are allowed by permission filters.
originals = await self.get_many({"ids": query["ids"]})
if not originals:
raise web.HTTPNotFound()
allowed = (permits(request, f"admin.{self.name}.edit", context=(request, r))
for r in originals)
allowed_f = (permits(request, f"admin.{self.name}.{k}.edit", context=(request, r))
Expand All @@ -266,6 +271,7 @@ async def _update_many(self, request: web.Request) -> web.Response:
raise web.HTTPForbidden()

ids = await self.update_many(query)
# get_many() is called above, so we can be sure there will be results here.
return json_response({"data": ids})

async def _delete(self, request: web.Request) -> web.Response:
Expand Down Expand Up @@ -293,6 +299,8 @@ async def _delete_many(self, request: web.Request) -> web.Response:
raise web.HTTPForbidden()

ids = await self.delete_many(query)
if not ids:
raise web.HTTPNotFound()
return json_response({"data": ids})

@cached_property
Expand Down
15 changes: 3 additions & 12 deletions aiohttp_admin/backends/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,7 @@ async def get_many(self, params: GetManyParams) -> list[Record]:
async with self._db.connect() as conn:
stmt = sa.select(self._table).where(self._table.c[self.primary_key].in_(params["ids"]))
result = await conn.execute(stmt)
records = [r._asdict() for r in result]
if records:
return records
raise web.HTTPNotFound()
return [r._asdict() for r in result]

@handle_errors
async def create(self, params: CreateParams) -> Record:
Expand All @@ -241,10 +238,7 @@ async def update_many(self, params: UpdateManyParams) -> list[Union[str, int]]:
async with self._db.begin() as conn:
stmt = sa.update(self._table).where(self._table.c[self.primary_key].in_(params["ids"]))
stmt = stmt.values(params["data"]).returning(self._table.c[self.primary_key])
r = await conn.scalars(stmt)
# The security check has already called get_many(), so we can be sure
# there will be results here.
return list(r)
return list(await conn.scalars(stmt))

@handle_errors
async def delete(self, params: DeleteParams) -> Record:
Expand All @@ -258,10 +252,7 @@ async def delete_many(self, params: DeleteManyParams) -> list[Union[str, int]]:
async with self._db.begin() as conn:
stmt = sa.delete(self._table).where(self._table.c[self.primary_key].in_(params["ids"]))
r = await conn.scalars(stmt.returning(self._table.c[self.primary_key]))
ids = list(r)
if ids:
return ids
raise web.HTTPNotFound()
return list(r)

def _get_validators(
self, table: sa.Table, c: sa.Column[object]
Expand Down

0 comments on commit fa287a2

Please sign in to comment.