diff --git a/mypy.ini b/mypy.ini index 886b0fce2ce..f41ebd8b1ce 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,6 +1,6 @@ # Automatically generated by hassfest. # -# To update, run python3 -m script.hassfest +# To update, run python3 -m script.hassfest -p mypy_config [mypy] python_version = 3.9 diff --git a/script/hassfest/__main__.py b/script/hassfest/__main__.py index d4935196cc7..8a8e1155ab9 100644 --- a/script/hassfest/__main__.py +++ b/script/hassfest/__main__.py @@ -43,6 +43,11 @@ HASS_PLUGINS = [ mypy_config, ] +ALL_PLUGIN_NAMES = [ + plugin.__name__.rsplit(".", maxsplit=1)[-1] + for plugin in (*INTEGRATION_PLUGINS, *HASS_PLUGINS) +] + def valid_integration_path(integration_path): """Test if it's a valid integration.""" @@ -53,6 +58,17 @@ def valid_integration_path(integration_path): return path +def validate_plugins(plugin_names: str) -> list[str]: + """Split and validate plugin names.""" + all_plugin_names = set(ALL_PLUGIN_NAMES) + plugins = plugin_names.split(",") + for plugin in plugins: + if plugin not in all_plugin_names: + raise argparse.ArgumentTypeError(f"{plugin} is not a valid plugin name") + + return plugins + + def get_config() -> Config: """Return config.""" parser = argparse.ArgumentParser(description="Hassfest") @@ -70,6 +86,13 @@ def get_config() -> Config: action="store_true", help="Validate requirements", ) + parser.add_argument( + "-p", + "--plugins", + type=validate_plugins, + default=ALL_PLUGIN_NAMES, + help="Comma-separate list of plugins to run. Valid plugin names: %(default)s", + ) parsed = parser.parse_args() if parsed.action is None: @@ -91,6 +114,7 @@ def get_config() -> Config: specific_integrations=parsed.integration_path, action=parsed.action, requirements=parsed.requirements, + plugins=set(parsed.plugins), ) @@ -117,9 +141,12 @@ def main(): plugins += HASS_PLUGINS for plugin in plugins: + plugin_name = plugin.__name__.rsplit(".", maxsplit=1)[-1] + if plugin_name not in config.plugins: + continue try: start = monotonic() - print(f"Validating {plugin.__name__.split('.')[-1]}...", end="", flush=True) + print(f"Validating {plugin_name}...", end="", flush=True) if ( plugin is requirements and config.requirements @@ -161,6 +188,9 @@ def main(): if config.action == "generate": for plugin in plugins: + plugin_name = plugin.__name__.rsplit(".", maxsplit=1)[-1] + if plugin_name not in config.plugins: + continue if hasattr(plugin, "generate"): plugin.generate(integrations, config) return 0 diff --git a/script/hassfest/model.py b/script/hassfest/model.py index 69810686cc1..7006c1e6032 100644 --- a/script/hassfest/model.py +++ b/script/hassfest/model.py @@ -32,6 +32,7 @@ class Config: requirements: bool = attr.ib() errors: list[Error] = attr.ib(factory=list) cache: dict[str, Any] = attr.ib(factory=dict) + plugins: set[str] = attr.ib(factory=set) def add_error(self, *args: Any, **kwargs: Any) -> None: """Add an error.""" diff --git a/script/hassfest/mypy_config.py b/script/hassfest/mypy_config.py index d2bd437c2d9..06c1353ce73 100644 --- a/script/hassfest/mypy_config.py +++ b/script/hassfest/mypy_config.py @@ -90,7 +90,7 @@ NO_IMPLICIT_REEXPORT_MODULES: set[str] = { HEADER: Final = """ # Automatically generated by hassfest. # -# To update, run python3 -m script.hassfest +# To update, run python3 -m script.hassfest -p mypy_config """.lstrip()