diff --git a/docs/guide/cli.md b/docs/guide/cli.md index 9fa70b52..2e7b927c 100644 --- a/docs/guide/cli.md +++ b/docs/guide/cli.md @@ -26,7 +26,7 @@ That's why taskiq can auto-discover tasks in current directory recursively. We have two options for this: - `--tasks-pattern` or `-tp`. - It's a name of files to import. By default is searches for all `tasks.py` files. + It's a glob pattern of files to import. By default it is `**/tasks.py` which searches for all `tasks.py` files. May be specified multiple times. - `--fs-discover` or `-fsd`. This option enables search of task files in current directory recursively, using the given pattern. ### Acknowledgements @@ -118,7 +118,7 @@ taskiq scheduler my_project.broker:scheduler my_project.module1 my_project.modul Path to scheduler is the only required argument. - `--tasks-pattern` or `-tp`. - It's a name of files to import. By default is searches for all `tasks.py` files. + It's a glob pattern of files to import. By default it is `**/tasks.py` which searches for all `tasks.py` files. May be specified multiple times. - `--fs-discover` or `-fsd`. This option enables search of task files in current directory recursively, using the given pattern. - `--no-configure-logging` - use this parameter if your application configures custom logging. - `--log-level` is used to set a log level (default `INFO`). diff --git a/taskiq/cli/scheduler/args.py b/taskiq/cli/scheduler/args.py index a596b61b..59230db8 100644 --- a/taskiq/cli/scheduler/args.py +++ b/taskiq/cli/scheduler/args.py @@ -15,7 +15,7 @@ class SchedulerArgs: log_level: str = LogLevel.INFO.name configure_logging: bool = True fs_discover: bool = False - tasks_pattern: str = "tasks.py" + tasks_pattern: Sequence[str] = ("**/tasks.py",) skip_first_run: bool = False @classmethod @@ -52,8 +52,9 @@ def from_cli(cls, args: Optional[Sequence[str]] = None) -> "SchedulerArgs": parser.add_argument( "--tasks-pattern", "-tp", - default="tasks.py", - help="Name of files in which taskiq will try to find modules.", + default=["**/tasks.py"], + action="append", + help="Glob patterns of files in which taskiq will try to find the tasks.", ) parser.add_argument( "--log-level", @@ -76,4 +77,10 @@ def from_cli(cls, args: Optional[Sequence[str]] = None) -> "SchedulerArgs": "This option skips running tasks immediately after scheduler start." ), ) - return cls(**parser.parse_args(args).__dict__) + + namespace = parser.parse_args(args) + # If there are any patterns specified, remove default. + # This is an argparse limitation. + if len(namespace.tasks_pattern) > 1: + namespace.tasks_pattern.pop(0) + return cls(**namespace.__dict__) diff --git a/taskiq/cli/utils.py b/taskiq/cli/utils.py index fd1e16c8..aa3e5918 100644 --- a/taskiq/cli/utils.py +++ b/taskiq/cli/utils.py @@ -4,7 +4,7 @@ from importlib import import_module from logging import getLogger from pathlib import Path -from typing import Any, Generator, List +from typing import Any, Generator, List, Sequence, Union from taskiq.utils import remove_suffix @@ -69,7 +69,11 @@ def import_from_modules(modules: List[str]) -> None: logger.exception(err) -def import_tasks(modules: List[str], pattern: str, fs_discover: bool) -> None: +def import_tasks( + modules: List[str], + pattern: Union[str, Sequence[str]], + fs_discover: bool, +) -> None: """ Import tasks modules. @@ -82,9 +86,14 @@ def import_tasks(modules: List[str], pattern: str, fs_discover: bool) -> None: from filesystem. """ if fs_discover: - for path in Path().rglob(pattern): - modules.append( - remove_suffix(str(path), ".py").replace(os.path.sep, "."), - ) - + if isinstance(pattern, str): + pattern = (pattern,) + discovered_modules = set() + for glob_pattern in pattern: + for path in Path().glob(glob_pattern): + discovered_modules.add( + remove_suffix(str(path), ".py").replace(os.path.sep, "."), + ) + + modules.extend(list(discovered_modules)) import_from_modules(modules) diff --git a/taskiq/cli/worker/args.py b/taskiq/cli/worker/args.py index 8de0b3cc..03a35c7a 100644 --- a/taskiq/cli/worker/args.py +++ b/taskiq/cli/worker/args.py @@ -26,7 +26,7 @@ class WorkerArgs: broker: str modules: List[str] - tasks_pattern: str = "tasks.py" + tasks_pattern: Sequence[str] = ("**/tasks.py",) fs_discover: bool = False configure_logging: bool = True log_level: LogLevel = LogLevel.INFO @@ -87,8 +87,9 @@ def from_cli( parser.add_argument( "--tasks-pattern", "-tp", - default="tasks.py", - help="Name of files in which taskiq will try to find modules.", + default=["**/tasks.py"], + action="append", + help="Glob patterns of files in which taskiq will try to find the tasks.", ) parser.add_argument( "modules", @@ -198,4 +199,8 @@ def from_cli( ) namespace = parser.parse_args(args) + # If there are any patterns specified, remove default. + # This is an argparse limitation. + if len(namespace.tasks_pattern) > 1: + namespace.tasks_pattern.pop(0) return WorkerArgs(**namespace.__dict__) diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py new file mode 100644 index 00000000..ebc85a8d --- /dev/null +++ b/tests/cli/test_utils.py @@ -0,0 +1,43 @@ +from unittest.mock import patch + +from taskiq.cli.utils import import_tasks + + +def test_import_tasks_list_pattern() -> None: + modules = ["taskiq.tasks"] + with patch("taskiq.cli.utils.import_from_modules", autospec=True) as mock: + import_tasks(modules, ["tests/**/test_utils.py"], True) + assert set(modules) == { + "taskiq.tasks", + "tests.test_utils", + "tests.cli.test_utils", + } + mock.assert_called_with(modules) + + +def test_import_tasks_str_pattern() -> None: + modules = ["taskiq.tasks"] + with patch("taskiq.cli.utils.import_from_modules", autospec=True) as mock: + import_tasks(modules, "tests/**/test_utils.py", True) + assert set(modules) == { + "taskiq.tasks", + "tests.test_utils", + "tests.cli.test_utils", + } + mock.assert_called_with(modules) + + +def test_import_tasks_empty_pattern() -> None: + modules = ["taskiq.tasks"] + with patch("taskiq.cli.utils.import_from_modules", autospec=True) as mock: + import_tasks(modules, [], True) + assert modules == ["taskiq.tasks"] + mock.assert_called_with(modules) + + +def test_import_tasks_no_discover() -> None: + modules = ["taskiq.tasks"] + with patch("taskiq.cli.utils.import_from_modules", autospec=True) as mock: + import_tasks(modules, "tests/**/test_utils.py", False) + assert modules == ["taskiq.tasks"] + mock.assert_called_with(modules)