diff --git a/docs/examples/router/multiple_brokers.py b/docs/examples/router/multiple_brokers.py new file mode 100644 index 00000000..c1d5c1aa --- /dev/null +++ b/docs/examples/router/multiple_brokers.py @@ -0,0 +1,65 @@ +"""Route one task through several brokers with a shared router.""" + +import asyncio + +from taskiq import Flow, InMemoryBroker, TaskiqRouter + +router = TaskiqRouter() + +default_email_flow = Flow.queue("emails.default") +priority_email_flow = Flow.queue("emails.priority") +bulk_email_flow = Flow.queue("emails.bulk") + +default_broker = InMemoryBroker( + router=router, + broker_name="default", + default_flow=default_email_flow, + await_inplace=True, +) +priority_broker = InMemoryBroker( + router=router, + broker_name="priority", + default_flow=priority_email_flow, + await_inplace=True, +) + + +@default_broker.task(task_name="examples.send_email", domain="notifications") +async def send_email(user_id: int, template: str) -> str: + """Pretend to render and send an email.""" + return f"{template} email sent to user {user_id}" + + +router.route_task( + send_email.task_name, + broker="priority", + flow=priority_email_flow, +) + + +async def _main() -> None: + await default_broker.startup() + await priority_broker.startup() + try: + direct_result = await send_email(7, "welcome") + + routed_task = await send_email.kiq(7, "welcome") + routed_result = await routed_task.wait_result(timeout=2) + + bulk_task = await send_email.kicker().with_route( + "default", + bulk_email_flow, + ).kiq(8, "digest") + bulk_result = await bulk_task.wait_result(timeout=2) + + print(f"Direct call: {direct_result}") + print(f"Default route: {router.resolve_route(send_email.task_name)}") + print(f"Routed call: {routed_result.return_value}") + print(f"Route override: {bulk_result.return_value}") + finally: + await priority_broker.shutdown() + await default_broker.shutdown() + + +if __name__ == "__main__": + asyncio.run(_main()) diff --git a/docs/examples/router/shared_task_package.py b/docs/examples/router/shared_task_package.py new file mode 100644 index 00000000..0614679e --- /dev/null +++ b/docs/examples/router/shared_task_package.py @@ -0,0 +1,48 @@ +"""Declare shared task definitions and bind them in the final application.""" + +import asyncio + +from taskiq import Flow, InMemoryBroker, TaskiqRouter, task_builder + + +@task_builder("billing.calculate_total", domain="billing") +async def calculate_total(price: int, quantity: int) -> int: + """Package-level task definition that is not bound to any broker.""" + return price * quantity + + +router = TaskiqRouter() +billing_flow = Flow.queue("billing.tasks") +priority_billing_flow = Flow.queue("billing.priority") + +billing_broker = InMemoryBroker( + router=router, + broker_name="billing", + default_flow=billing_flow, + await_inplace=True, +) + +registered_calculate_total = billing_broker.register_task(calculate_total) + + +async def _main() -> None: + await billing_broker.startup() + try: + direct_result = await calculate_total.call(19, 3) + + prepared_task = registered_calculate_total.kicker().with_flow( + priority_billing_flow, + ).prepare(19, 3) + + queued_task = await prepared_task.kiq() + queued_result = await queued_task.wait_result(timeout=2) + + print(f"Shared task direct call: {direct_result}") + print(f"Prepared message: {prepared_task.message.task_name}") + print(f"Registered queued call: {queued_result.return_value}") + finally: + await billing_broker.shutdown() + + +if __name__ == "__main__": + asyncio.run(_main()) diff --git a/taskiq/__init__.py b/taskiq/__init__.py index 2414754f..856beb3c 100644 --- a/taskiq/__init__.py +++ b/taskiq/__init__.py @@ -24,7 +24,9 @@ TaskiqError, TaskiqResultTimeoutError, ) +from taskiq.flow import Flow, FlowKind from taskiq.funcs import gather +from taskiq.kicker import PreparedKiq from taskiq.message import BrokerMessage, TaskiqMessage from taskiq.middlewares import ( PrometheusMiddleware, @@ -32,10 +34,12 @@ SmartRetryMiddleware, ) from taskiq.result import TaskiqResult +from taskiq.router import TaskiqRoute, TaskiqRouter from taskiq.scheduler.scheduled_task import ScheduledTask from taskiq.scheduler.scheduler import TaskiqScheduler from taskiq.state import TaskiqState from taskiq.task import AsyncTaskiqTask +from taskiq.task_builder import TaskDefinition, task_builder __version__ = version("taskiq") @@ -47,8 +51,11 @@ "AsyncTaskiqTask", "BrokerMessage", "Context", + "Flow", + "FlowKind", "InMemoryBroker", "NoResultError", + "PreparedKiq", "PrometheusMiddleware", "ResultGetError", "ResultIsReadyError", @@ -58,6 +65,7 @@ "SendTaskError", "SimpleRetryMiddleware", "SmartRetryMiddleware", + "TaskDefinition", "TaskiqDepends", "TaskiqError", "TaskiqEvents", @@ -66,10 +74,13 @@ "TaskiqMiddleware", "TaskiqResult", "TaskiqResultTimeoutError", + "TaskiqRoute", + "TaskiqRouter", "TaskiqScheduler", "TaskiqState", "ZeroMQBroker", "__version__", "async_shared_broker", "gather", + "task_builder", ] diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index ea2e86c0..71cd1c79 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -24,11 +24,14 @@ from taskiq.decor import AsyncTaskiqDecoratedTask from taskiq.events import TaskiqEvents from taskiq.exceptions import TaskBrokerMismatchError +from taskiq.flow import Flow from taskiq.formatters.proxy_formatter import ProxyFormatter from taskiq.message import BrokerMessage from taskiq.result_backends.dummy import DummyResultBackend +from taskiq.router import TaskiqRouter from taskiq.serializers.json_serializer import JSONSerializer from taskiq.state import TaskiqState +from taskiq.task_builder import TaskDefinition from taskiq.utils import maybe_awaitable from taskiq.warnings import TaskiqDeprecationWarning @@ -78,6 +81,10 @@ def __init__( self, result_backend: "AsyncResultBackend[_T] | None" = None, task_id_generator: Callable[[], str] | None = None, + *, + router: TaskiqRouter | None = None, + broker_name: str | None = None, + default_flow: Flow | None = None, ) -> None: if result_backend is None: result_backend = DummyResultBackend() @@ -103,6 +110,13 @@ def __init__( self.serializer: TaskiqSerializer = JSONSerializer() self.formatter: TaskiqFormatter = ProxyFormatter(self) self.id_generator = task_id_generator + self.router = router or TaskiqRouter() + self.default_flow = default_flow + self.broker_name = self.router.set_broker( + self, + name=broker_name, + default_flow=default_flow, + ) self.local_task_registry: dict[str, AsyncTaskiqDecoratedTask[Any, Any]] = {} # Every event has a list of handlers. # Every handler is a function which takes state as a first argument. @@ -133,10 +147,14 @@ def find_task(self, task_name: str) -> AsyncTaskiqDecoratedTask[Any, Any] | None :param task_name: name of a task. :returns: found task or None. """ - return self.local_task_registry.get( - task_name, - ) or self.global_task_registry.get( - task_name, + return ( + self.local_task_registry.get( + task_name, + ) + or self.router.find_task(task_name) + or self.global_task_registry.get( + task_name, + ) ) def get_all_tasks(self) -> dict[str, AsyncTaskiqDecoratedTask[Any, Any]]: @@ -152,7 +170,11 @@ def get_all_tasks(self) -> dict[str, AsyncTaskiqDecoratedTask[Any, Any]]: :return: dict of all tasks. Keys are task names, values are tasks. """ - return {**self.global_task_registry, **self.local_task_registry} + return { + **self.global_task_registry, + **self.router.get_all_tasks(), + **self.local_task_registry, + } def add_dependency_context(self, new_ctx: dict[Any, Any]) -> None: """ @@ -237,6 +259,23 @@ async def kick( :param message: name of a task. """ + async def kick_to_flow( + self, + message: BrokerMessage, + flow: Flow | None = None, + ) -> None: + """ + Send message to a flow-aware broker. + + Existing brokers can keep implementing only `kick`. New brokers may + override this method and use `flow` to route to a concrete queue, topic, + stream or any other transport address. + + :param message: message to send. + :param flow: optional transport-neutral flow. + """ + await self.kick(message) + @abstractmethod def listen(self) -> AsyncGenerator[bytes | AckableMessage, None]: """ @@ -362,7 +401,8 @@ def inner( def register_task( self, - func: Callable[_FuncParams, _ReturnType], + func: Callable[_FuncParams, _ReturnType] + | TaskDefinition[_FuncParams, _ReturnType], task_name: str | None = None, **labels: Any, ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: @@ -380,6 +420,12 @@ def register_task( :returns: registered task. """ + if isinstance(func, TaskDefinition): + return self.router.register_task( + func, + broker=self, + flow=self.default_flow, + ) return self.task(task_name=task_name, **labels)(func) def on_event(self, *events: TaskiqEvents) -> Callable[[EventHandler], EventHandler]: @@ -533,6 +579,11 @@ def _register_task( if task.broker != self: raise TaskBrokerMismatchError(broker=task.broker) self.local_task_registry[task_name] = task + self.router.register_task( + task, + broker=self, + flow=self.default_flow, + ) async def __aenter__(self) -> None: """Starts the broker as ctx manager.""" diff --git a/taskiq/brokers/inmemory_broker.py b/taskiq/brokers/inmemory_broker.py index 0a7cc98e..f341d4dc 100644 --- a/taskiq/brokers/inmemory_broker.py +++ b/taskiq/brokers/inmemory_broker.py @@ -9,8 +9,10 @@ from taskiq.depends.progress_tracker import TaskProgress from taskiq.events import TaskiqEvents from taskiq.exceptions import UnknownTaskError +from taskiq.flow import Flow from taskiq.message import BrokerMessage from taskiq.receiver import Receiver +from taskiq.router import TaskiqRouter from taskiq.utils import maybe_awaitable _ReturnType = TypeVar("_ReturnType") @@ -130,8 +132,16 @@ def __init__( max_async_tasks_jitter: int = 0, propagate_exceptions: bool = True, await_inplace: bool = False, + *, + router: TaskiqRouter | None = None, + broker_name: str | None = None, + default_flow: Flow | None = None, ) -> None: - super().__init__() + super().__init__( + router=router, + broker_name=broker_name, + default_flow=default_flow, + ) self.result_backend: InmemoryResultBackend[Any] = InmemoryResultBackend( max_stored_results=max_stored_results, ) diff --git a/taskiq/context.py b/taskiq/context.py index f9f1d0ee..d29b8cb3 100644 --- a/taskiq/context.py +++ b/taskiq/context.py @@ -30,7 +30,7 @@ async def requeue(self) -> None: requeue_count = int(self.message.labels.get("X-Taskiq-requeue", 0)) requeue_count += 1 self.message.labels["X-Taskiq-requeue"] = str(requeue_count) - await self.broker.kick(self.broker.formatter.dumps(self.message)) + await self.broker.router.requeue(self.message, broker=self.broker) raise NoResultError def reject(self) -> None: diff --git a/taskiq/flow.py b/taskiq/flow.py new file mode 100644 index 00000000..5fde2b46 --- /dev/null +++ b/taskiq/flow.py @@ -0,0 +1,58 @@ +import enum +from dataclasses import dataclass, field, replace +from typing import Any + +__all__ = ("Flow", "FlowKind") + + +@enum.unique +class FlowKind(str, enum.Enum): + """Transport-neutral flow shape.""" + + QUEUE = "queue" + TOPIC = "topic" + STREAM = "stream" + + +@dataclass(frozen=True, slots=True) +class Flow: + """Transport-neutral publish or subscribe address. + + Plain flows are intentionally generic. Every broker may interpret a flow + using its own defaults: queue name, topic, stream, channel, list key, or any + other transport address. + + Broker packages can subclass this value object to expose transport-specific + details while still accepting plain Flow instances. + """ + + name: str + kind: FlowKind = FlowKind.QUEUE + options: dict[str, Any] = field( + default_factory=dict, + compare=False, + hash=False, + ) + + @classmethod + def queue(cls, name: str, **options: Any) -> "Flow": + """Create a queue-like flow.""" + return cls(name=name, kind=FlowKind.QUEUE, options=options) + + @classmethod + def topic(cls, name: str, **options: Any) -> "Flow": + """Create a topic-like flow.""" + return cls(name=name, kind=FlowKind.TOPIC, options=options) + + @classmethod + def stream(cls, name: str, **options: Any) -> "Flow": + """Create a stream-like flow.""" + return cls(name=name, kind=FlowKind.STREAM, options=options) + + def with_options(self, **options: Any) -> "Flow": + """Return the same flow with additional generic options.""" + return replace(self, options={**self.options, **options}) + + def broker_options(self, broker_name: str) -> dict[str, Any]: + """Return transport options for broker-specific implementations.""" + return dict(self.options) diff --git a/taskiq/kicker.py b/taskiq/kicker.py index dc113a7e..96ee9f88 100644 --- a/taskiq/kicker.py +++ b/taskiq/kicker.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Coroutine from dataclasses import asdict, is_dataclass from datetime import datetime, timedelta @@ -9,17 +11,20 @@ Generic, ParamSpec, TypeVar, - Union, + cast, overload, ) from pydantic import BaseModel from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.abc.result_backend import AsyncResultBackend from taskiq.compat import model_dump from taskiq.exceptions import SendTaskError +from taskiq.flow import Flow from taskiq.labels import prepare_label from taskiq.message import TaskiqMessage +from taskiq.router import TaskiqRouter from taskiq.scheduler.created_schedule import CreatedSchedule from taskiq.scheduler.scheduled_task import CronSpec, ScheduledTask from taskiq.task import AsyncTaskiqTask @@ -36,13 +41,29 @@ logger = getLogger("taskiq") +class PreparedKiq(Generic[_ReturnType]): + """Prepared task invocation that can be sent later.""" + + def __init__( + self, + kicker: AsyncKicker[..., _ReturnType], + message: TaskiqMessage, + ) -> None: + self.kicker = kicker + self.message = message + + async def kiq(self) -> AsyncTaskiqTask[_ReturnType]: + """Send prepared invocation.""" + return await self.kicker.kiq_message(self.message) + + class AsyncKicker(Generic[_FuncParams, _ReturnType]): """Class that used to modify data before sending it to broker.""" def __init__( self, task_name: str, - broker: "AsyncBroker", + broker: AsyncBroker, labels: dict[str, Any], return_type: type[_ReturnType] | None = None, ) -> None: @@ -52,11 +73,13 @@ def __init__( self.custom_task_id: str | None = None self.custom_schedule_id: str | None = None self.return_type = return_type + self.route_broker: AsyncBroker | str | None = None + self.route_flow: Flow | None = None def with_labels( self, **labels: str | float, - ) -> "AsyncKicker[_FuncParams, _ReturnType]": + ) -> AsyncKicker[_FuncParams, _ReturnType]: """ Update function's labels before sending. @@ -69,7 +92,7 @@ def with_labels( def with_task_id( self, task_id: str | None, - ) -> "AsyncKicker[_FuncParams, _ReturnType]": + ) -> AsyncKicker[_FuncParams, _ReturnType]: """ Set task_id for current execution. @@ -85,7 +108,7 @@ def with_task_id( def with_schedule_id( self, schedule_id: str, - ) -> "AsyncKicker[_FuncParams, _ReturnType]": + ) -> AsyncKicker[_FuncParams, _ReturnType]: """ Set schedule_id for current execution. @@ -97,8 +120,8 @@ def with_schedule_id( def with_broker( self, - broker: "AsyncBroker", - ) -> "AsyncKicker[_FuncParams, _ReturnType]": + broker: AsyncBroker, + ) -> AsyncKicker[_FuncParams, _ReturnType]: """ Replace broker for the function. @@ -109,11 +132,55 @@ def with_broker( :return: Kicker with new broker. """ self.broker = broker + self.route_broker = broker + return self + + def with_flow( + self, + flow: Flow | None, + ) -> AsyncKicker[_FuncParams, _ReturnType]: + """ + Replace flow for the current invocation. + + :param flow: flow to send message to. + :return: Kicker with a route flow override. + """ + self.route_flow = flow return self + def with_route( + self, + broker: AsyncBroker | str, + flow: Flow | None, + ) -> AsyncKicker[_FuncParams, _ReturnType]: + """ + Replace broker and flow for the current invocation. + + :param broker: broker instance or broker name. + :param flow: flow to send message to. + :return: Kicker with a route override. + """ + self.route_broker = broker + self.route_flow = flow + return self + + def prepare( + self, + *args: _FuncParams.args, + **kwargs: _FuncParams.kwargs, + ) -> PreparedKiq[_ReturnType]: + """ + Prepare a task invocation without sending it. + + :param args: function's arguments. + :param kwargs: function's key word arguments. + :return: prepared task invocation. + """ + return PreparedKiq(self, self._prepare_message(*args, **kwargs)) + @overload async def kiq( - self: "AsyncKicker[_FuncParams, CoroutineType[Any, Any, _T]]", + self: AsyncKicker[_FuncParams, CoroutineType[Any, Any, _T]], *args: _FuncParams.args, **kwargs: _FuncParams.kwargs, ) -> AsyncTaskiqTask[_T]: # pragma: no cover @@ -121,7 +188,7 @@ async def kiq( @overload async def kiq( - self: "AsyncKicker[_FuncParams, Coroutine[Any, Any, _T]]", + self: AsyncKicker[_FuncParams, Coroutine[Any, Any, _T]], *args: _FuncParams.args, **kwargs: _FuncParams.kwargs, ) -> AsyncTaskiqTask[_T]: # pragma: no cover @@ -129,7 +196,7 @@ async def kiq( @overload async def kiq( - self: "AsyncKicker[_FuncParams, _ReturnType]", + self: AsyncKicker[_FuncParams, _ReturnType], *args: _FuncParams.args, **kwargs: _FuncParams.kwargs, ) -> AsyncTaskiqTask[_ReturnType]: # pragma: no cover @@ -156,29 +223,60 @@ async def kiq( logger.debug( f"Kicking {self.task_name} with args={args} and kwargs={kwargs}.", ) - message = self._prepare_message(*args, **kwargs) - for middleware in self.broker.middlewares: - if middleware.__class__.pre_send != TaskiqMiddleware.pre_send: - message = await maybe_awaitable(middleware.pre_send(message)) + return await self.kiq_message(self._prepare_message(*args, **kwargs)) + + async def kiq_message( + self, + message: TaskiqMessage, + ) -> AsyncTaskiqTask[_ReturnType]: + """Send a prepared message.""" try: - await self.broker.kick(self.broker.formatter.dumps(message)) + router = getattr(self.broker, "router", None) + if isinstance(router, TaskiqRouter): + return await router.kiq( + message, + broker=self.route_broker, + flow=self.route_flow, + return_type=self.return_type, + ) + return await self._legacy_kiq(message) except Exception as exc: raise SendTaskError from exc - for middleware in reversed(self.broker.middlewares): + async def _legacy_kiq(self, message: TaskiqMessage) -> AsyncTaskiqTask[_ReturnType]: + """ + Send message through the pre-router broker path. + + This keeps middleware tests and external broker-like mocks compatible + while real AsyncBroker instances use TaskiqRouter. + """ + middlewares = getattr(self.broker, "middlewares", []) + if not isinstance(middlewares, list): + middlewares = [] + + for middleware in middlewares: + if middleware.__class__.pre_send != TaskiqMiddleware.pre_send: + message = await maybe_awaitable(middleware.pre_send(message)) + + await self.broker.kick(self.broker.formatter.dumps(message)) + + for middleware in reversed(middlewares): if middleware.__class__.post_send != TaskiqMiddleware.post_send: await maybe_awaitable(middleware.post_send(message)) return AsyncTaskiqTask( task_id=message.task_id, - result_backend=self.broker.result_backend, - return_type=self.return_type, # type: ignore # (pyright issue) + result_backend=cast( + AsyncResultBackend[_ReturnType], + self.broker.result_backend, + ), + return_type=self.return_type, ) async def schedule_by_cron( self, - source: "ScheduleSource", - cron: Union[str, "CronSpec"], + source: ScheduleSource, + cron: str | CronSpec, *args: _FuncParams.args, **kwargs: _FuncParams.kwargs, ) -> CreatedSchedule[_ReturnType]: @@ -217,7 +315,7 @@ async def schedule_by_cron( async def schedule_by_interval( self, - source: "ScheduleSource", + source: ScheduleSource, interval: int | timedelta, *args: _FuncParams.args, **kwargs: _FuncParams.kwargs, @@ -249,7 +347,7 @@ async def schedule_by_interval( async def schedule_by_time( self, - source: "ScheduleSource", + source: ScheduleSource, time: datetime, *args: _FuncParams.args, **kwargs: _FuncParams.kwargs, diff --git a/taskiq/router.py b/taskiq/router.py new file mode 100644 index 00000000..07129022 --- /dev/null +++ b/taskiq/router.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from logging import getLogger +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast, overload + +from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.abc.result_backend import AsyncResultBackend +from taskiq.flow import Flow +from taskiq.message import TaskiqMessage +from taskiq.task import AsyncTaskiqTask +from taskiq.task_builder import TaskDefinition +from taskiq.utils import maybe_awaitable + +if TYPE_CHECKING: # pragma: no cover + from taskiq.abc.broker import AsyncBroker + from taskiq.decor import AsyncTaskiqDecoratedTask + +__all__ = ("TaskiqRoute", "TaskiqRouter") + +_FuncParams = ParamSpec("_FuncParams") +_ReturnType = TypeVar("_ReturnType") + +logger = getLogger("taskiq.router") + + +@dataclass(frozen=True, slots=True) +class TaskiqRoute: + """Resolved outbound route for a task invocation.""" + + broker_name: str + flow: Flow | None = None + + +class TaskiqRouter: + """Registry and routing layer shared by one or more brokers.""" + + def __init__(self) -> None: + self.brokers: dict[str, AsyncBroker] = {} + self.default_broker_name: str | None = None + self.task_registry: dict[str, AsyncTaskiqDecoratedTask[Any, Any]] = {} + self.routes: dict[str, TaskiqRoute] = {} + + def set_broker( + self, + broker: AsyncBroker, + name: str | None = None, + default_flow: Flow | None = None, + ) -> str: + """Register broker as a transport in this router.""" + broker_name = name or broker.__class__.__name__ + registered = self.brokers.get(broker_name) + if registered is not None and registered is not broker: + raise ValueError( + f"Broker name {broker_name!r} is already registered. " + "Please provide an explicit unique broker_name.", + ) + self.brokers[broker_name] = broker + if self.default_broker_name is None: + self.default_broker_name = broker_name + return broker_name + + def find_task( + self, + task_name: str, + ) -> AsyncTaskiqDecoratedTask[Any, Any] | None: + """Find a task by name.""" + return self.task_registry.get(task_name) + + def get_all_tasks(self) -> dict[str, AsyncTaskiqDecoratedTask[Any, Any]]: + """Return all tasks registered in this router.""" + return dict(self.task_registry) + + def register_task( + self, + task: ( + AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType] + | TaskDefinition[_FuncParams, _ReturnType] + ), + broker: AsyncBroker | str | None = None, + flow: Flow | None = None, + ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: + """Register a bound task or bind a task definition to a broker.""" + if isinstance(task, TaskDefinition): + target_broker = self._resolve_broker(broker) + registered_task = target_broker.register_task( + task.original_func, + task_name=task.task_name, + **task.labels, + ) + if flow is not None: + self.route_task(task.task_name, broker=target_broker, flow=flow) + return registered_task + + self.task_registry[task.task_name] = task + route_broker: AsyncBroker | str | None = broker + if route_broker is None: + route_broker = getattr(task, "broker", None) + if route_broker is not None or flow is not None: + self.route_task(task.task_name, broker=route_broker, flow=flow) + return task + + @overload + def task( + self, + task_name: Callable[_FuncParams, _ReturnType], + *, + broker: AsyncBroker | str | None = None, + flow: Flow | None = None, + **labels: Any, + ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: ... + + @overload + def task( + self, + task_name: str | None = None, + *, + broker: AsyncBroker | str | None = None, + flow: Flow | None = None, + **labels: Any, + ) -> Callable[ + [Callable[_FuncParams, _ReturnType]], + AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType], + ]: ... + + def task( + self, + task_name: str | Callable[_FuncParams, _ReturnType] | None = None, + *, + broker: AsyncBroker | str | None = None, + flow: Flow | None = None, + **labels: Any, + ) -> Any: + """Decorate and register a task through this router.""" + + def register( + func: Callable[_FuncParams, _ReturnType], + ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: + target_broker = self._resolve_broker(broker) + real_task_name = task_name if not callable(task_name) else None + task = target_broker.task(task_name=real_task_name, **labels)(func) + if flow is not None: + self.route_task(task.task_name, broker=target_broker, flow=flow) + return task + + if callable(task_name): + function = task_name + task_name = None + return register(function) + + return register + + def route_task( + self, + task_name: str, + broker: AsyncBroker | str | None = None, + flow: Flow | None = None, + ) -> TaskiqRoute: + """Set default outbound route for a task.""" + broker_name = self._resolve_broker_name(broker) + route = TaskiqRoute(broker_name=broker_name, flow=flow) + self.routes[task_name] = route + return route + + def resolve_route( + self, + task_name: str, + broker: AsyncBroker | str | None = None, + flow: Flow | None = None, + ) -> TaskiqRoute: + """Resolve outbound route for a task invocation.""" + if broker is not None: + broker_name = self._resolve_broker_name(broker) + route_flow = flow + if route_flow is None: + route_flow = self._broker_default_flow(broker_name) + return TaskiqRoute( + broker_name=broker_name, + flow=route_flow, + ) + + route = self.routes.get(task_name) + if route is not None: + if flow is None: + return route + return TaskiqRoute(broker_name=route.broker_name, flow=flow) + + broker_name = self._resolve_broker_name(None) + route_flow = flow + if route_flow is None: + route_flow = self._broker_default_flow(broker_name) + return TaskiqRoute( + broker_name=broker_name, + flow=route_flow, + ) + + async def kiq( + self, + message: TaskiqMessage, + *, + broker: AsyncBroker | str | None = None, + flow: Flow | None = None, + return_type: type[_ReturnType] | None = None, + ) -> AsyncTaskiqTask[_ReturnType]: + """Send message through the resolved broker and flow.""" + route = self.resolve_route(message.task_name, broker=broker, flow=flow) + target_broker = self.brokers[route.broker_name] + + for middleware in target_broker.middlewares: + if middleware.__class__.pre_send != TaskiqMiddleware.pre_send: + message = await maybe_awaitable(middleware.pre_send(message)) + broker_message = target_broker.formatter.dumps(message) + await target_broker.kick_to_flow(broker_message, route.flow) + + for middleware in reversed(target_broker.middlewares): + if middleware.__class__.post_send != TaskiqMiddleware.post_send: + await maybe_awaitable(middleware.post_send(message)) + + return AsyncTaskiqTask( + task_id=message.task_id, + result_backend=cast( + AsyncResultBackend[_ReturnType], + target_broker.result_backend, + ), + return_type=return_type, + ) + + async def requeue( + self, + message: TaskiqMessage, + *, + broker: AsyncBroker | str | None = None, + flow: Flow | None = None, + ) -> None: + """Send an existing message again through the resolved route.""" + route = self.resolve_route(message.task_name, broker=broker, flow=flow) + target_broker = self.brokers[route.broker_name] + await target_broker.kick_to_flow( + target_broker.formatter.dumps(message), + route.flow, + ) + + def _resolve_broker(self, broker: AsyncBroker | str | None) -> AsyncBroker: + broker_name = self._resolve_broker_name(broker) + return self.brokers[broker_name] + + def _resolve_broker_name(self, broker: AsyncBroker | str | None) -> str: + if isinstance(broker, str): + if broker not in self.brokers: + raise ValueError(f"Unknown broker {broker!r}.") + return broker + + if broker is not None: + broker_name = getattr(broker, "broker_name", None) + if broker_name is not None and broker_name in self.brokers: + return broker_name + for registered_name, registered_broker in self.brokers.items(): + if registered_broker is broker: + return registered_name + raise ValueError("Broker is not registered in this router.") + + if self.default_broker_name is None: + raise ValueError("Router doesn't have registered brokers.") + return self.default_broker_name + + def _broker_default_flow(self, broker_name: str) -> Flow | None: + return getattr(self.brokers[broker_name], "default_flow", None) diff --git a/taskiq/task_builder.py b/taskiq/task_builder.py new file mode 100644 index 00000000..d2d3b0d1 --- /dev/null +++ b/taskiq/task_builder.py @@ -0,0 +1,107 @@ +import inspect +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, Generic, ParamSpec, TypeVar, overload + +from taskiq.message import TaskiqMessage + +__all__ = ("TaskDefinition", "task_builder") + +_FuncParams = ParamSpec("_FuncParams") +_ReturnType = TypeVar("_ReturnType") + + +@dataclass(frozen=True, slots=True) +class TaskDefinition(Generic[_FuncParams, _ReturnType]): + """Unbound task declaration that can be registered later.""" + + task_name: str + original_func: Callable[_FuncParams, _ReturnType] + labels: dict[str, Any] = field(default_factory=dict) + return_type: type[_ReturnType] | None = None + + def __call__( + self, + *args: _FuncParams.args, + **kwargs: _FuncParams.kwargs, + ) -> _ReturnType: + """Call original function directly.""" + return self.original_func(*args, **kwargs) + + async def call( + self, + *args: _FuncParams.args, + **kwargs: _FuncParams.kwargs, + ) -> _ReturnType: + """Execute original function in the current process.""" + result = self.original_func(*args, **kwargs) + if inspect.isawaitable(result): + return await result + return result + + def message( + self, + task_id: str, + *args: _FuncParams.args, + **kwargs: _FuncParams.kwargs, + ) -> TaskiqMessage: + """Build a TaskiqMessage without binding this definition to a router.""" + return TaskiqMessage( + task_id=task_id, + task_name=self.task_name, + labels=dict(self.labels), + args=list(args), + kwargs=dict(kwargs), + ) + + +@overload +def task_builder( + task_name: Callable[_FuncParams, _ReturnType], + **labels: Any, +) -> TaskDefinition[_FuncParams, _ReturnType]: ... + + +@overload +def task_builder( + task_name: str | None = None, + **labels: Any, +) -> Callable[ + [Callable[_FuncParams, _ReturnType]], + TaskDefinition[_FuncParams, _ReturnType], +]: ... + + +def task_builder( + task_name: str | Callable[_FuncParams, _ReturnType] | None = None, + **labels: Any, +) -> Any: + """Build an unbound task definition. + + This decorator is intended for library/package tasks that should be + registered by the final application. + """ + + def build( + func: Callable[_FuncParams, _ReturnType], + ) -> TaskDefinition[_FuncParams, _ReturnType]: + real_task_name = task_name + if real_task_name is None or callable(real_task_name): + real_task_name = f"{func.__module__}:{func.__name__}" + return_type = None + signature = inspect.signature(func) + if signature.return_annotation is not inspect.Signature.empty: + return_type = signature.return_annotation + return TaskDefinition( + task_name=real_task_name, + original_func=func, + labels=dict(labels), + return_type=return_type, + ) + + if callable(task_name): + function = task_name + task_name = None + return build(function) + + return build diff --git a/tests/test_router.py b/tests/test_router.py new file mode 100644 index 00000000..00b1883d --- /dev/null +++ b/tests/test_router.py @@ -0,0 +1,164 @@ +from collections.abc import AsyncGenerator + +import pytest + +from taskiq import Flow, TaskiqRouter, task_builder +from taskiq.abc.broker import AsyncBroker +from taskiq.message import BrokerMessage + + +class RecordingBroker(AsyncBroker): + """Broker that records sent messages and flows.""" + + def __init__( + self, + *, + router: TaskiqRouter | None = None, + broker_name: str | None = None, + default_flow: Flow | None = None, + ) -> None: + self.sent: list[tuple[BrokerMessage, Flow | None]] = [] + super().__init__( + router=router, + broker_name=broker_name, + default_flow=default_flow, + ) + + async def kick(self, message: BrokerMessage) -> None: + """Record old-style send.""" + self.sent.append((message, None)) + + async def kick_to_flow( + self, + message: BrokerMessage, + flow: Flow | None = None, + ) -> None: + """Record flow-aware send.""" + self.sent.append((message, flow)) + + async def listen(self) -> AsyncGenerator[bytes, None]: + """Recording broker doesn't listen in these tests.""" + if False: + yield b"" + + +def test_broker_creates_default_router() -> None: + broker = RecordingBroker() + + assert broker.router.brokers[broker.broker_name] is broker + assert broker.router.default_broker_name == broker.broker_name + + +async def test_old_broker_task_api_registers_task_in_router() -> None: + broker = RecordingBroker() + + @broker.task(task_name="demo.task") + async def demo_task() -> None: + return None + + assert broker.find_task("demo.task") is demo_task + assert broker.router.find_task("demo.task") is demo_task + + await demo_task.kiq() + + assert broker.sent[0][0].task_name == "demo.task" + assert broker.sent[0][1] is None + + +async def test_router_can_route_task_to_another_broker_flow() -> None: + router = TaskiqRouter() + source = RecordingBroker(router=router, broker_name="source") + target = RecordingBroker(router=router, broker_name="target") + flow = Flow("events") + + @source.task(task_name="demo.task") + async def demo_task() -> None: + return None + + router.route_task("demo.task", broker="target", flow=flow) + + await demo_task.kiq() + + assert source.sent == [] + assert target.sent[0][0].task_name == "demo.task" + assert target.sent[0][1] == flow + + +async def test_kicker_route_override_wins_over_registered_route() -> None: + router = TaskiqRouter() + first = RecordingBroker(router=router, broker_name="first") + second = RecordingBroker(router=router, broker_name="second") + first_flow = Flow("first") + second_flow = Flow("second") + + @first.task(task_name="demo.task") + async def demo_task() -> None: + return None + + router.route_task("demo.task", broker="first", flow=first_flow) + + await demo_task.kicker().with_route("second", second_flow).kiq() + + assert first.sent == [] + assert second.sent[0][1] == second_flow + + +async def test_kicker_can_prepare_invocation_for_later() -> None: + broker = RecordingBroker() + + @broker.task(task_name="demo.task") + async def demo_task(value: int) -> None: + return None + + prepared = demo_task.kicker().with_labels(trace_id="abc").prepare(1) + + assert prepared.message.task_name == "demo.task" + assert prepared.message.args == [1] + assert prepared.message.labels["trace_id"] == "abc" + + await prepared.kiq() + + assert broker.sent[0][0].task_id == prepared.message.task_id + + +async def test_task_builder_can_be_registered_later() -> None: + broker = RecordingBroker() + + @task_builder("shared.add", queue="shared") + def add(left: int, right: int) -> int: + return left + right + + assert await add.call(1, 2) == 3 + + registered = broker.register_task(add) + + assert registered.task_name == "shared.add" + assert registered.labels == {"queue": "shared"} + assert broker.router.find_task("shared.add") is registered + + await registered.kiq(1, 2) + + assert broker.sent[0][0].task_name == "shared.add" + + +async def test_router_task_decorator_can_choose_broker_and_flow() -> None: + router = TaskiqRouter() + target = RecordingBroker(router=router, broker_name="target") + flow = Flow("target-flow") + + @router.task("demo.task", broker="target", flow=flow) + async def demo_task() -> None: + return None + + await demo_task.kiq() + + assert target.sent[0][0].task_name == "demo.task" + assert target.sent[0][1] == flow + + +def test_router_rejects_duplicate_broker_names() -> None: + router = TaskiqRouter() + RecordingBroker(router=router, broker_name="broker") + + with pytest.raises(ValueError, match="already registered"): + RecordingBroker(router=router, broker_name="broker")