Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions taskiq/cli/worker/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class WorkerArgs:
no_propagate_errors: bool = False
max_fails: int = -1
ack_type: AcknowledgeType = AcknowledgeType.WHEN_SAVED
max_tasks_per_child: Optional[int] = None
wait_tasks_timeout: Optional[float] = None

@classmethod
def from_cli(
Expand Down Expand Up @@ -197,6 +199,19 @@ def from_cli(
choices=[ack_type.name.lower() for ack_type in AcknowledgeType],
help="When to acknowledge message.",
)
parser.add_argument(
"--max-tasks-per-child",
type=int,
default=None,
help="Maximum number of tasks to execute per child process.",
)
parser.add_argument(
"--wait-tasks-timeout",
type=float,
default=None,
help="Maximum time to wait for all current tasks "
"to finish before exiting.",
)

namespace = parser.parse_args(args)
# If there are any patterns specified, remove default.
Expand Down
2 changes: 2 additions & 0 deletions taskiq/cli/worker/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
max_prefetch=args.max_prefetch,
propagate_exceptions=not args.no_propagate_errors,
ack_type=args.ack_type,
max_tasks_to_execute=args.max_tasks_per_child,
wait_tasks_timeout=args.wait_tasks_timeout,
**receiver_kwargs, # type: ignore
)
loop.run_until_complete(receiver.listen())
Expand Down
14 changes: 14 additions & 0 deletions taskiq/receiver/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def __init__(
run_starup: bool = True,
ack_type: Optional[AcknowledgeType] = None,
on_exit: Optional[Callable[["Receiver"], None]] = None,
max_tasks_to_execute: Optional[int] = None,
wait_tasks_timeout: Optional[float] = None,
) -> None:
self.broker = broker
self.executor = executor
Expand All @@ -68,6 +70,8 @@ def __init__(
self.on_exit = on_exit
self.ack_time = ack_type or AcknowledgeType.WHEN_SAVED
self.known_tasks: Set[str] = set()
self.max_tasks_to_execute = max_tasks_to_execute
self.wait_tasks_timeout = wait_tasks_timeout
for task in self.broker.get_all_tasks().values():
self._prepare_task(task.task_name, task.original_func)
self.sem: "Optional[asyncio.Semaphore]" = None
Expand Down Expand Up @@ -342,12 +346,20 @@ async def prefetcher(

:param queue: queue for prefetched data.
"""
fetched_tasks: int = 0
iterator = self.broker.listen()

while True:
try:
await self.sem_prefetch.acquire()
if (
self.max_tasks_to_execute
and fetched_tasks >= self.max_tasks_to_execute
):
logger.info("Max number of tasks executed.")
break
message = await iterator.__anext__()
fetched_tasks += 1
await queue.put(message)
except asyncio.CancelledError:
break
Expand Down Expand Up @@ -389,6 +401,8 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
self.sem_prefetch.release()
message = await queue.get()
if message is QUEUE_DONE:
logger.info("Waiting for running tasks to complete.")
await asyncio.wait(tasks, timeout=self.wait_tasks_timeout)
break

task = asyncio.create_task(
Expand Down