diff --git a/taskiq/cli/worker/args.py b/taskiq/cli/worker/args.py index 4c225b07..7350bfa0 100644 --- a/taskiq/cli/worker/args.py +++ b/taskiq/cli/worker/args.py @@ -43,6 +43,8 @@ class WorkerArgs: no_propagate_errors: bool = False max_fails: int = -1 ack_type: AcknowledgeType = AcknowledgeType.WHEN_SAVED + max_tasks_per_child: Optional[int] = None + wait_tasks_timeout: Optional[float] = None @classmethod def from_cli( @@ -197,6 +199,19 @@ def from_cli( choices=[ack_type.name.lower() for ack_type in AcknowledgeType], help="When to acknowledge message.", ) + parser.add_argument( + "--max-tasks-per-child", + type=int, + default=None, + help="Maximum number of tasks to execute per child process.", + ) + parser.add_argument( + "--wait-tasks-timeout", + type=float, + default=None, + help="Maximum time to wait for all current tasks " + "to finish before exiting.", + ) namespace = parser.parse_args(args) # If there are any patterns specified, remove default. diff --git a/taskiq/cli/worker/run.py b/taskiq/cli/worker/run.py index fd253d77..d67b1271 100644 --- a/taskiq/cli/worker/run.py +++ b/taskiq/cli/worker/run.py @@ -140,6 +140,8 @@ def interrupt_handler(signum: int, _frame: Any) -> None: max_prefetch=args.max_prefetch, propagate_exceptions=not args.no_propagate_errors, ack_type=args.ack_type, + max_tasks_to_execute=args.max_tasks_per_child, + wait_tasks_timeout=args.wait_tasks_timeout, **receiver_kwargs, # type: ignore ) loop.run_until_complete(receiver.listen()) diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index 73a56ed3..6394d09c 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -56,6 +56,8 @@ def __init__( run_starup: bool = True, ack_type: Optional[AcknowledgeType] = None, on_exit: Optional[Callable[["Receiver"], None]] = None, + max_tasks_to_execute: Optional[int] = None, + wait_tasks_timeout: Optional[float] = None, ) -> None: self.broker = broker self.executor = executor @@ -68,6 +70,8 @@ def __init__( self.on_exit = on_exit self.ack_time = ack_type or AcknowledgeType.WHEN_SAVED self.known_tasks: Set[str] = set() + self.max_tasks_to_execute = max_tasks_to_execute + self.wait_tasks_timeout = wait_tasks_timeout for task in self.broker.get_all_tasks().values(): self._prepare_task(task.task_name, task.original_func) self.sem: "Optional[asyncio.Semaphore]" = None @@ -342,12 +346,20 @@ async def prefetcher( :param queue: queue for prefetched data. """ + fetched_tasks: int = 0 iterator = self.broker.listen() while True: try: await self.sem_prefetch.acquire() + if ( + self.max_tasks_to_execute + and fetched_tasks >= self.max_tasks_to_execute + ): + logger.info("Max number of tasks executed.") + break message = await iterator.__anext__() + fetched_tasks += 1 await queue.put(message) except asyncio.CancelledError: break @@ -389,6 +401,8 @@ def task_cb(task: "asyncio.Task[Any]") -> None: self.sem_prefetch.release() message = await queue.get() if message is QUEUE_DONE: + logger.info("Waiting for running tasks to complete.") + await asyncio.wait(tasks, timeout=self.wait_tasks_timeout) break task = asyncio.create_task(