Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 40 additions & 53 deletions taskiq/cli/scheduler/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,45 +32,45 @@ def to_tz_aware(time: datetime) -> datetime:
return time


async def schedules_updater(
async def get_schedules(source: ScheduleSource) -> List[ScheduledTask]:
"""
Get schedules from source.

If source raises an exception, it will be
logged and an empty list will be returned.

:param source: source to get schedules from.
"""
try:
return await source.get_schedules()
except Exception as exc:
logger.warning(
"Cannot update schedules with source: %s",
source,
)
logger.debug(exc, exc_info=True)
return []


async def get_all_schedules(
scheduler: TaskiqScheduler,
current_schedules: Dict[ScheduleSource, List[ScheduledTask]],
event: asyncio.Event,
) -> None:
) -> Dict[ScheduleSource, List[ScheduledTask]]:
"""
Periodic update to schedules.
Task to update all schedules.

This task periodically checks for new schedules,
assembles the final list and replaces current
schedule with a new one.
This function updates all schedules
from all sources and returns a dict
with source as a key and list of
scheduled tasks as a value.

:param scheduler: current scheduler.
:param current_schedules: list of schedules.
:param event: event when schedules are updated.
:return: dict with source as a key and list of scheduled tasks as a value.
"""
while True:
logger.debug("Started schedule update.")
new_schedules: "Dict[ScheduleSource, List[ScheduledTask]]" = {}
for source in scheduler.sources:
try:
schedules = await source.get_schedules()
except Exception as exc:
logger.warning(
"Cannot update schedules with source: %s",
source,
)
logger.debug(exc, exc_info=True)
continue

new_schedules[source] = scheduler.merge_func(
new_schedules.get(source) or [],
schedules,
)

current_schedules.clear()
current_schedules.update(new_schedules)
event.set()
await asyncio.sleep(scheduler.refresh_delay)
logger.debug("Started schedule update.")
schedules = await asyncio.gather(
*[get_schedules(source) for source in scheduler.sources],
)
return dict(zip(scheduler.sources, schedules))


def get_task_delay(task: ScheduledTask) -> Optional[int]:
Expand Down Expand Up @@ -141,23 +141,14 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None:
:param scheduler: current scheduler.
"""
loop = asyncio.get_event_loop()
schedules: "Dict[ScheduleSource, List[ScheduledTask]]" = {}

current_task = asyncio.current_task()
first_update_event = asyncio.Event()
updater_task = loop.create_task(
schedules_updater(
scheduler,
schedules,
first_update_event,
),
)
if current_task is not None:
current_task.add_done_callback(lambda _: updater_task.cancel())
await first_update_event.wait()
running_schedules = set()
while True:
for source, task_list in schedules.items():
# We use this method to correctly sleep for one minute.
next_minute = datetime.now().replace(second=0, microsecond=0) + timedelta(
minutes=1,
)
scheduled_tasks = await get_all_schedules(scheduler)
for source, task_list in scheduled_tasks.items():
for task in task_list:
try:
task_delay = get_task_delay(task)
Expand All @@ -175,11 +166,7 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None:
running_schedules.add(send_task)
send_task.add_done_callback(running_schedules.discard)

delay = (
datetime.now().replace(second=1, microsecond=0)
+ timedelta(minutes=1)
- datetime.now()
)
delay = next_minute - datetime.now()
await asyncio.sleep(delay.total_seconds())


Expand Down
10 changes: 1 addition & 9 deletions taskiq/scheduler/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING, Callable, List
from typing import TYPE_CHECKING, List

from taskiq.kicker import AsyncKicker
from taskiq.scheduler.merge_functions import only_new
from taskiq.scheduler.scheduled_task import ScheduledTask
from taskiq.utils import maybe_awaitable

Expand All @@ -17,16 +16,9 @@ def __init__(
self,
broker: "AsyncBroker",
sources: List["ScheduleSource"],
merge_func: Callable[
[List["ScheduledTask"], List["ScheduledTask"]],
List["ScheduledTask"],
] = only_new,
refresh_delay: float = 30.0,
) -> None: # pragma: no cover
self.broker = broker
self.sources = sources
self.refresh_delay = refresh_delay
self.merge_func = merge_func

async def startup(self) -> None: # pragma: no cover
"""
Expand Down
87 changes: 87 additions & 0 deletions tests/cli/scheduler/test_updater.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from datetime import datetime
from typing import List, Union

import pytest

from taskiq import InMemoryBroker, ScheduleSource
from taskiq.cli.scheduler.run import get_all_schedules
from taskiq.scheduler.scheduled_task import ScheduledTask
from taskiq.scheduler.scheduler import TaskiqScheduler


class DummySource(ScheduleSource):
def __init__(self, schedules: Union[Exception, List[ScheduledTask]]) -> None:
self.schedules = schedules

async def get_schedules(self) -> List[ScheduledTask]:
"""Return test schedules, or raise an exception."""
if isinstance(self.schedules, Exception):
raise self.schedules
return self.schedules


@pytest.mark.anyio
async def test_get_schedules_success() -> None:
"""Tests that schedules are returned correctly."""
schedules1 = [
ScheduledTask(
task_name="a",
labels={},
args=[],
kwargs={},
time=datetime.now(),
),
ScheduledTask(
task_name="b",
labels={},
args=[],
kwargs={},
time=datetime.now(),
),
]
schedules2 = [
ScheduledTask(
task_name="c",
labels={},
args=[],
kwargs={},
time=datetime.now(),
),
]
sources: List[ScheduleSource] = [
DummySource(schedules1),
DummySource(schedules2),
]

schedules = await get_all_schedules(
TaskiqScheduler(InMemoryBroker(), sources),
)
assert schedules == {
sources[0]: schedules1,
sources[1]: schedules2,
}


@pytest.mark.anyio
async def test_get_schedules_error() -> None:
"""Tests that if source returned an error, empty list will be returned."""
source1 = DummySource(
[
ScheduledTask(
task_name="a",
labels={},
args=[],
kwargs={},
time=datetime.now(),
),
],
)
source2 = DummySource(Exception("test"))

schedules = await get_all_schedules(
TaskiqScheduler(InMemoryBroker(), [source1, source2]),
)
assert schedules == {
source1: source1.schedules,
source2: [],
}