From 95107cf6708d11891b92572c4d4e01a5833e079f Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Tue, 10 Dec 2024 13:07:08 +0100 Subject: [PATCH] Add check for typed ConfigEntry in quality scale validation (#132028) --- script/hassfest/quality_scale.py | 11 ++- .../quality_scale_validation/__init__.py | 4 +- .../config_entry_unloading.py | 2 +- .../quality_scale_validation/config_flow.py | 2 +- .../quality_scale_validation/diagnostics.py | 2 +- .../quality_scale_validation/discovery.py | 2 +- .../parallel_updates.py | 2 +- .../reauthentication_flow.py | 2 +- .../reconfiguration_flow.py | 2 +- .../quality_scale_validation/runtime_data.py | 90 +++++++++++++++++-- .../quality_scale_validation/strict_typing.py | 2 +- .../unique_config_entry.py | 2 +- 12 files changed, 101 insertions(+), 22 deletions(-) diff --git a/script/hassfest/quality_scale.py b/script/hassfest/quality_scale.py index ff67bbbe416..9f6d1e0b783 100644 --- a/script/hassfest/quality_scale.py +++ b/script/hassfest/quality_scale.py @@ -1348,16 +1348,19 @@ def validate_iqs_file(config: Config, integration: Integration) -> None: "quality_scale", f"Invalid {name}: {humanize_error(data, err)}" ) + rules_done = set[str]() rules_met = set[str]() for rule_name, rule_value in data.get("rules", {}).items(): status = rule_value["status"] if isinstance(rule_value, dict) else rule_value if status not in {"done", "exempt"}: continue rules_met.add(rule_name) - if ( - status == "done" - and (validator := VALIDATORS.get(rule_name)) - and (errors := validator.validate(integration)) + if status == "done": + rules_done.add(rule_name) + + for rule_name in rules_done: + if (validator := VALIDATORS.get(rule_name)) and ( + errors := validator.validate(integration, rules_done=rules_done) ): for error in errors: integration.add_error("quality_scale", f"[{rule_name}] {error}") diff --git a/script/hassfest/quality_scale_validation/__init__.py b/script/hassfest/quality_scale_validation/__init__.py index 836c1082763..892bb70fabd 100644 --- a/script/hassfest/quality_scale_validation/__init__.py +++ b/script/hassfest/quality_scale_validation/__init__.py @@ -8,7 +8,9 @@ from script.hassfest.model import Integration class RuleValidationProtocol(Protocol): """Protocol for rule validation.""" - def validate(self, integration: Integration) -> list[str] | None: + def validate( + self, integration: Integration, *, rules_done: set[str] + ) -> list[str] | None: """Validate a quality scale rule. Returns error (if any). diff --git a/script/hassfest/quality_scale_validation/config_entry_unloading.py b/script/hassfest/quality_scale_validation/config_entry_unloading.py index b25a72e427f..fb636a7f2ed 100644 --- a/script/hassfest/quality_scale_validation/config_entry_unloading.py +++ b/script/hassfest/quality_scale_validation/config_entry_unloading.py @@ -17,7 +17,7 @@ def _has_unload_entry_function(module: ast.Module) -> bool: ) -def validate(integration: Integration) -> list[str] | None: +def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None: """Validate that the integration has a config flow.""" init_file = integration.path / "__init__.py" diff --git a/script/hassfest/quality_scale_validation/config_flow.py b/script/hassfest/quality_scale_validation/config_flow.py index e1361d6550f..6e88aa462f4 100644 --- a/script/hassfest/quality_scale_validation/config_flow.py +++ b/script/hassfest/quality_scale_validation/config_flow.py @@ -6,7 +6,7 @@ https://developers.home-assistant.io/docs/core/integration-quality-scale/rules/c from script.hassfest.model import Integration -def validate(integration: Integration) -> list[str] | None: +def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None: """Validate that the integration implements config flow.""" if not integration.config_flow: diff --git a/script/hassfest/quality_scale_validation/diagnostics.py b/script/hassfest/quality_scale_validation/diagnostics.py index d3ef38474f8..44012208bcb 100644 --- a/script/hassfest/quality_scale_validation/diagnostics.py +++ b/script/hassfest/quality_scale_validation/diagnostics.py @@ -22,7 +22,7 @@ def _has_diagnostics_function(module: ast.Module) -> bool: ) -def validate(integration: Integration) -> list[str] | None: +def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None: """Validate that the integration implements diagnostics.""" diagnostics_file = integration.path / "diagnostics.py" diff --git a/script/hassfest/quality_scale_validation/discovery.py b/script/hassfest/quality_scale_validation/discovery.py index 66a08456314..db50cdba55a 100644 --- a/script/hassfest/quality_scale_validation/discovery.py +++ b/script/hassfest/quality_scale_validation/discovery.py @@ -38,7 +38,7 @@ def _has_discovery_function(module: ast.Module) -> bool: ) -def validate(integration: Integration) -> list[str] | None: +def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None: """Validate that the integration implements diagnostics.""" config_flow_file = integration.path / "config_flow.py" diff --git a/script/hassfest/quality_scale_validation/parallel_updates.py b/script/hassfest/quality_scale_validation/parallel_updates.py index 74ec55991f9..3483a44f504 100644 --- a/script/hassfest/quality_scale_validation/parallel_updates.py +++ b/script/hassfest/quality_scale_validation/parallel_updates.py @@ -18,7 +18,7 @@ def _has_parallel_updates_defined(module: ast.Module) -> bool: ) -def validate(integration: Integration) -> list[str] | None: +def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None: """Validate that the integration sets PARALLEL_UPDATES constant.""" errors = [] diff --git a/script/hassfest/quality_scale_validation/reauthentication_flow.py b/script/hassfest/quality_scale_validation/reauthentication_flow.py index 4ae8fed5696..81d34ec4f7f 100644 --- a/script/hassfest/quality_scale_validation/reauthentication_flow.py +++ b/script/hassfest/quality_scale_validation/reauthentication_flow.py @@ -17,7 +17,7 @@ def _has_step_reauth_function(module: ast.Module) -> bool: ) -def validate(integration: Integration) -> list[str] | None: +def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None: """Validate that the integration has a reauthentication flow.""" config_flow_file = integration.path / "config_flow.py" diff --git a/script/hassfest/quality_scale_validation/reconfiguration_flow.py b/script/hassfest/quality_scale_validation/reconfiguration_flow.py index 19192cb28d0..b27475e8c70 100644 --- a/script/hassfest/quality_scale_validation/reconfiguration_flow.py +++ b/script/hassfest/quality_scale_validation/reconfiguration_flow.py @@ -17,7 +17,7 @@ def _has_step_reconfigure_function(module: ast.Module) -> bool: ) -def validate(integration: Integration) -> list[str] | None: +def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None: """Validate that the integration has a reconfiguration flow.""" config_flow_file = integration.path / "config_flow.py" diff --git a/script/hassfest/quality_scale_validation/runtime_data.py b/script/hassfest/quality_scale_validation/runtime_data.py index c426496636b..8ad721a218c 100644 --- a/script/hassfest/quality_scale_validation/runtime_data.py +++ b/script/hassfest/quality_scale_validation/runtime_data.py @@ -4,10 +4,31 @@ https://developers.home-assistant.io/docs/core/integration-quality-scale/rules/r """ import ast +import re +from homeassistant.const import Platform from script.hassfest import ast_parse_module from script.hassfest.model import Integration +_ANNOTATION_MATCH = re.compile(r"^[A-Za-z]+ConfigEntry$") +_FUNCTIONS: dict[str, dict[str, int]] = { + "__init__": { # based on ComponentProtocol + "async_migrate_entry": 2, + "async_remove_config_entry_device": 2, + "async_remove_entry": 2, + "async_setup_entry": 2, + "async_unload_entry": 2, + }, + "diagnostics": { # based on DiagnosticsProtocol + "async_get_config_entry_diagnostics": 2, + "async_get_device_diagnostics": 2, + }, +} +for platform in Platform: # based on EntityPlatformModule + _FUNCTIONS[platform.value] = { + "async_setup_entry": 2, + } + def _sets_runtime_data( async_setup_entry_function: ast.AsyncFunctionDef, config_entry_argument: ast.arg @@ -25,30 +46,83 @@ def _sets_runtime_data( return False -def _get_setup_entry_function(module: ast.Module) -> ast.AsyncFunctionDef | None: - """Get async_setup_entry function.""" +def _get_async_function(module: ast.Module, name: str) -> ast.AsyncFunctionDef | None: + """Get async function.""" for item in module.body: - if isinstance(item, ast.AsyncFunctionDef) and item.name == "async_setup_entry": + if isinstance(item, ast.AsyncFunctionDef) and item.name == name: return item return None -def validate(integration: Integration) -> list[str] | None: +def _check_function_annotation( + function: ast.AsyncFunctionDef, position: int +) -> str | None: + """Ensure function uses CustomConfigEntry type annotation.""" + if len(function.args.args) < position: + return f"{function.name} has incorrect signature" + argument = function.args.args[position - 1] + if not ( + (annotation := argument.annotation) + and isinstance(annotation, ast.Name) + and _ANNOTATION_MATCH.match(annotation.id) + ): + return f"([+ strict-typing]) {function.name} does not use typed ConfigEntry" + return None + + +def _check_typed_config_entry(integration: Integration) -> list[str]: + """Ensure integration uses CustomConfigEntry type annotation.""" + errors: list[str] = [] + # Check body level function annotations + for file, functions in _FUNCTIONS.items(): + module_file = integration.path / f"{file}.py" + if not module_file.exists(): + continue + module = ast_parse_module(module_file) + for function, position in functions.items(): + if not (async_function := _get_async_function(module, function)): + continue + if error := _check_function_annotation(async_function, position): + errors.append(f"{error} in {module_file}") + + # Check config_flow annotations + config_flow_file = integration.path / "config_flow.py" + config_flow = ast_parse_module(config_flow_file) + for node in config_flow.body: + if not isinstance(node, ast.ClassDef): + continue + if any( + isinstance(async_function, ast.FunctionDef) + and async_function.name == "async_get_options_flow" + and (error := _check_function_annotation(async_function, 1)) + for async_function in node.body + ): + errors.append(f"{error} in {config_flow_file}") + + return errors + + +def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None: """Validate correct use of ConfigEntry.runtime_data.""" init_file = integration.path / "__init__.py" init = ast_parse_module(init_file) # Should not happen, but better to be safe - if not (async_setup_entry := _get_setup_entry_function(init)): + if not (async_setup_entry := _get_async_function(init, "async_setup_entry")): return [f"Could not find `async_setup_entry` in {init_file}"] if len(async_setup_entry.args.args) != 2: return [f"async_setup_entry has incorrect signature in {init_file}"] config_entry_argument = async_setup_entry.args.args[1] + errors: list[str] = [] if not _sets_runtime_data(async_setup_entry, config_entry_argument): - return [ + errors.append( "Integration does not set entry.runtime_data in async_setup_entry" f"({init_file})" - ] + ) - return None + # Extra checks, if strict-typing is marked as done + if "strict-typing" in rules_done: + errors.extend(_check_typed_config_entry(integration)) + + return errors diff --git a/script/hassfest/quality_scale_validation/strict_typing.py b/script/hassfest/quality_scale_validation/strict_typing.py index 285746a9eb6..a7755b6bb40 100644 --- a/script/hassfest/quality_scale_validation/strict_typing.py +++ b/script/hassfest/quality_scale_validation/strict_typing.py @@ -24,7 +24,7 @@ def _strict_typing_components() -> set[str]: ) -def validate(integration: Integration) -> list[str] | None: +def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None: """Validate that the integration has strict typing enabled.""" if integration.domain not in _strict_typing_components(): diff --git a/script/hassfest/quality_scale_validation/unique_config_entry.py b/script/hassfest/quality_scale_validation/unique_config_entry.py index bf9991d5635..8c38923e584 100644 --- a/script/hassfest/quality_scale_validation/unique_config_entry.py +++ b/script/hassfest/quality_scale_validation/unique_config_entry.py @@ -30,7 +30,7 @@ def _has_abort_unique_id_configured(module: ast.Module) -> bool: ) -def validate(integration: Integration) -> list[str] | None: +def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None: """Validate that the integration prevents duplicate devices.""" if integration.manifest.get("single_config_entry"):