# celery_rmq_broker.py (2.1 KB)
  1. # (c) Nelen & Schuurmans
  2. import json
  3. import uuid
  4. from typing import Optional
  5. import pika
  6. from asgiref.sync import sync_to_async
  7. from pydantic import AnyUrl
  8. from clean_python import Gateway
  9. from clean_python import Json
  10. from clean_python import ValueObject
# Only the broker gateway is exported; CeleryHeaders is an internal helper.
__all__ = ["CeleryRmqBroker"]
  12. class CeleryHeaders(ValueObject):
  13. lang: str = "py"
  14. task: str
  15. id: uuid.UUID
  16. root_id: uuid.UUID
  17. parent_id: Optional[uuid.UUID] = None
  18. group: Optional[uuid.UUID] = None
  19. argsrepr: Optional[str] = None
  20. kwargsrepr: Optional[str] = None
  21. origin: Optional[str] = None
  22. def json_dict(self):
  23. return json.loads(self.model_dump_json())
  24. class CeleryRmqBroker(Gateway):
  25. def __init__(
  26. self, broker_url: AnyUrl, queue: str, origin: str, declare_queue: bool = False
  27. ):
  28. self._parameters = pika.URLParameters(str(broker_url))
  29. self._queue = queue
  30. self._origin = origin
  31. self._declare_queue = declare_queue
  32. @sync_to_async
  33. def add(self, item: Json) -> Json:
  34. task = item["task"]
  35. args = list(item.get("args") or [])
  36. kwargs = dict(item.get("kwargs") or {})
  37. task_id = uuid.uuid4()
  38. header = CeleryHeaders(
  39. task=task,
  40. id=task_id,
  41. root_id=task_id,
  42. argsrepr=json.dumps(args),
  43. kwargsrepr=json.dumps(kwargs),
  44. origin=self._origin,
  45. )
  46. body = json.dumps((args, kwargs, None))
  47. with pika.BlockingConnection(self._parameters) as connection:
  48. channel = connection.channel()
  49. if self._declare_queue:
  50. channel.queue_declare(queue=self._queue)
  51. else:
  52. pass # Configured by Lizard
  53. properties = pika.BasicProperties(
  54. correlation_id=str(task_id),
  55. content_type="application/json",
  56. content_encoding="utf-8",
  57. headers=header.json_dict(),
  58. )
  59. channel.basic_publish(
  60. exchange="",
  61. routing_key=self._queue,
  62. body=body,
  63. properties=properties,
  64. )
  65. return item