Add check for typed ConfigEntry in quality scale validation (#132028)

This commit is contained in:
epenet 2024-12-10 13:07:08 +01:00 committed by GitHub
parent 46d4081ec6
commit 95107cf670
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 101 additions and 22 deletions

View File

@ -1348,16 +1348,19 @@ def validate_iqs_file(config: Config, integration: Integration) -> None:
"quality_scale", f"Invalid {name}: {humanize_error(data, err)}"
)
rules_done = set[str]()
rules_met = set[str]()
for rule_name, rule_value in data.get("rules", {}).items():
status = rule_value["status"] if isinstance(rule_value, dict) else rule_value
if status not in {"done", "exempt"}:
continue
rules_met.add(rule_name)
if (
status == "done"
and (validator := VALIDATORS.get(rule_name))
and (errors := validator.validate(integration))
if status == "done":
rules_done.add(rule_name)
for rule_name in rules_done:
if (validator := VALIDATORS.get(rule_name)) and (
errors := validator.validate(integration, rules_done=rules_done)
):
for error in errors:
integration.add_error("quality_scale", f"[{rule_name}] {error}")

View File

@ -8,7 +8,9 @@ from script.hassfest.model import Integration
class RuleValidationProtocol(Protocol):
"""Protocol for rule validation."""
def validate(self, integration: Integration) -> list[str] | None:
def validate(
self, integration: Integration, *, rules_done: set[str]
) -> list[str] | None:
"""Validate a quality scale rule.
Returns error (if any).

View File

@ -17,7 +17,7 @@ def _has_unload_entry_function(module: ast.Module) -> bool:
)
def validate(integration: Integration) -> list[str] | None:
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
"""Validate that the integration has a config flow."""
init_file = integration.path / "__init__.py"

View File

@ -6,7 +6,7 @@ https://developers.home-assistant.io/docs/core/integration-quality-scale/rules/c
from script.hassfest.model import Integration
def validate(integration: Integration) -> list[str] | None:
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
"""Validate that the integration implements config flow."""
if not integration.config_flow:

View File

@ -22,7 +22,7 @@ def _has_diagnostics_function(module: ast.Module) -> bool:
)
def validate(integration: Integration) -> list[str] | None:
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
"""Validate that the integration implements diagnostics."""
diagnostics_file = integration.path / "diagnostics.py"

View File

@ -38,7 +38,7 @@ def _has_discovery_function(module: ast.Module) -> bool:
)
def validate(integration: Integration) -> list[str] | None:
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
"""Validate that the integration implements diagnostics."""
config_flow_file = integration.path / "config_flow.py"

View File

@ -18,7 +18,7 @@ def _has_parallel_updates_defined(module: ast.Module) -> bool:
)
def validate(integration: Integration) -> list[str] | None:
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
"""Validate that the integration sets PARALLEL_UPDATES constant."""
errors = []

View File

@ -17,7 +17,7 @@ def _has_step_reauth_function(module: ast.Module) -> bool:
)
def validate(integration: Integration) -> list[str] | None:
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
"""Validate that the integration has a reauthentication flow."""
config_flow_file = integration.path / "config_flow.py"

View File

@ -17,7 +17,7 @@ def _has_step_reconfigure_function(module: ast.Module) -> bool:
)
def validate(integration: Integration) -> list[str] | None:
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
"""Validate that the integration has a reconfiguration flow."""
config_flow_file = integration.path / "config_flow.py"

View File

@ -4,10 +4,31 @@ https://developers.home-assistant.io/docs/core/integration-quality-scale/rules/r
"""
import ast
import re
from homeassistant.const import Platform
from script.hassfest import ast_parse_module
from script.hassfest.model import Integration
_ANNOTATION_MATCH = re.compile(r"^[A-Za-z]+ConfigEntry$")
_FUNCTIONS: dict[str, dict[str, int]] = {
"__init__": { # based on ComponentProtocol
"async_migrate_entry": 2,
"async_remove_config_entry_device": 2,
"async_remove_entry": 2,
"async_setup_entry": 2,
"async_unload_entry": 2,
},
"diagnostics": { # based on DiagnosticsProtocol
"async_get_config_entry_diagnostics": 2,
"async_get_device_diagnostics": 2,
},
}
for platform in Platform: # based on EntityPlatformModule
_FUNCTIONS[platform.value] = {
"async_setup_entry": 2,
}
def _sets_runtime_data(
async_setup_entry_function: ast.AsyncFunctionDef, config_entry_argument: ast.arg
@ -25,30 +46,83 @@ def _sets_runtime_data(
return False
def _get_setup_entry_function(module: ast.Module) -> ast.AsyncFunctionDef | None:
"""Get async_setup_entry function."""
def _get_async_function(module: ast.Module, name: str) -> ast.AsyncFunctionDef | None:
"""Get async function."""
for item in module.body:
if isinstance(item, ast.AsyncFunctionDef) and item.name == "async_setup_entry":
if isinstance(item, ast.AsyncFunctionDef) and item.name == name:
return item
return None
def validate(integration: Integration) -> list[str] | None:
def _check_function_annotation(
function: ast.AsyncFunctionDef, position: int
) -> str | None:
"""Ensure function uses CustomConfigEntry type annotation."""
if len(function.args.args) < position:
return f"{function.name} has incorrect signature"
argument = function.args.args[position - 1]
if not (
(annotation := argument.annotation)
and isinstance(annotation, ast.Name)
and _ANNOTATION_MATCH.match(annotation.id)
):
return f"([+ strict-typing]) {function.name} does not use typed ConfigEntry"
return None
def _check_typed_config_entry(integration: Integration) -> list[str]:
"""Ensure integration uses CustomConfigEntry type annotation."""
errors: list[str] = []
# Check body level function annotations
for file, functions in _FUNCTIONS.items():
module_file = integration.path / f"{file}.py"
if not module_file.exists():
continue
module = ast_parse_module(module_file)
for function, position in functions.items():
if not (async_function := _get_async_function(module, function)):
continue
if error := _check_function_annotation(async_function, position):
errors.append(f"{error} in {module_file}")
# Check config_flow annotations
config_flow_file = integration.path / "config_flow.py"
config_flow = ast_parse_module(config_flow_file)
for node in config_flow.body:
if not isinstance(node, ast.ClassDef):
continue
if any(
isinstance(async_function, ast.FunctionDef)
and async_function.name == "async_get_options_flow"
and (error := _check_function_annotation(async_function, 1))
for async_function in node.body
):
errors.append(f"{error} in {config_flow_file}")
return errors
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
"""Validate correct use of ConfigEntry.runtime_data."""
init_file = integration.path / "__init__.py"
init = ast_parse_module(init_file)
# Should not happen, but better to be safe
if not (async_setup_entry := _get_setup_entry_function(init)):
if not (async_setup_entry := _get_async_function(init, "async_setup_entry")):
return [f"Could not find `async_setup_entry` in {init_file}"]
if len(async_setup_entry.args.args) != 2:
return [f"async_setup_entry has incorrect signature in {init_file}"]
config_entry_argument = async_setup_entry.args.args[1]
errors: list[str] = []
if not _sets_runtime_data(async_setup_entry, config_entry_argument):
return [
errors.append(
"Integration does not set entry.runtime_data in async_setup_entry"
f"({init_file})"
]
)
return None
# Extra checks, if strict-typing is marked as done
if "strict-typing" in rules_done:
errors.extend(_check_typed_config_entry(integration))
return errors

View File

@ -24,7 +24,7 @@ def _strict_typing_components() -> set[str]:
)
def validate(integration: Integration) -> list[str] | None:
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
"""Validate that the integration has strict typing enabled."""
if integration.domain not in _strict_typing_components():

View File

@ -30,7 +30,7 @@ def _has_abort_unique_id_configured(module: ast.Module) -> bool:
)
def validate(integration: Integration) -> list[str] | None:
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
"""Validate that the integration prevents duplicate devices."""
if integration.manifest.get("single_config_entry"):