celery_rmq_broker.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # -*- coding: utf-8 -*-
  2. # (c) Nelen & Schuurmans
  3. import json
  4. import uuid
  5. from typing import Optional
  6. import pika
  7. from asgiref.sync import sync_to_async
  8. from pydantic import AnyUrl
  9. from clean_python.base.domain.value_object import ValueObject
  10. from clean_python.base.infrastructure.gateway import Gateway
  11. from clean_python.base.infrastructure.gateway import Json
  12. __all__ = ["CeleryRmqBroker"]
  13. class CeleryHeaders(ValueObject):
  14. lang: str = "py"
  15. task: str
  16. id: uuid.UUID
  17. root_id: uuid.UUID
  18. parent_id: Optional[uuid.UUID] = None
  19. group: Optional[uuid.UUID] = None
  20. argsrepr: Optional[str] = None
  21. kwargsrepr: Optional[str] = None
  22. origin: Optional[str] = None
  23. def json_dict(self):
  24. return json.loads(self.json())
  25. class CeleryRmqBroker(Gateway):
  26. def __init__(
  27. self, broker_url: AnyUrl, queue: str, origin: str, declare_queue: bool = False
  28. ):
  29. self._parameters = pika.URLParameters(broker_url)
  30. self._queue = queue
  31. self._origin = origin
  32. self._declare_queue = declare_queue
  33. @sync_to_async
  34. def add(self, item: Json) -> Json:
  35. task = item["task"]
  36. args = list(item.get("args") or [])
  37. kwargs = dict(item.get("kwargs") or {})
  38. task_id = uuid.uuid4()
  39. header = CeleryHeaders(
  40. task=task,
  41. id=task_id,
  42. root_id=task_id,
  43. argsrepr=json.dumps(args),
  44. kwargsrepr=json.dumps(kwargs),
  45. origin=self._origin,
  46. )
  47. body = json.dumps((args, kwargs, None))
  48. with pika.BlockingConnection(self._parameters) as connection:
  49. channel = connection.channel()
  50. if self._declare_queue:
  51. channel.queue_declare(queue=self._queue)
  52. else:
  53. pass # Configured by Lizard
  54. properties = pika.BasicProperties(
  55. correlation_id=str(task_id),
  56. content_type="application/json",
  57. content_encoding="utf-8",
  58. headers=header.json_dict(),
  59. )
  60. channel.basic_publish(
  61. exchange="",
  62. routing_key=self._queue,
  63. body=body,
  64. properties=properties,
  65. )
  66. return item