mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 04:37:06 +00:00
Add check for typed ConfigEntry in quality scale validation (#132028)
This commit is contained in:
parent
46d4081ec6
commit
95107cf670
@ -1348,16 +1348,19 @@ def validate_iqs_file(config: Config, integration: Integration) -> None:
|
||||
"quality_scale", f"Invalid {name}: {humanize_error(data, err)}"
|
||||
)
|
||||
|
||||
rules_done = set[str]()
|
||||
rules_met = set[str]()
|
||||
for rule_name, rule_value in data.get("rules", {}).items():
|
||||
status = rule_value["status"] if isinstance(rule_value, dict) else rule_value
|
||||
if status not in {"done", "exempt"}:
|
||||
continue
|
||||
rules_met.add(rule_name)
|
||||
if (
|
||||
status == "done"
|
||||
and (validator := VALIDATORS.get(rule_name))
|
||||
and (errors := validator.validate(integration))
|
||||
if status == "done":
|
||||
rules_done.add(rule_name)
|
||||
|
||||
for rule_name in rules_done:
|
||||
if (validator := VALIDATORS.get(rule_name)) and (
|
||||
errors := validator.validate(integration, rules_done=rules_done)
|
||||
):
|
||||
for error in errors:
|
||||
integration.add_error("quality_scale", f"[{rule_name}] {error}")
|
||||
|
@ -8,7 +8,9 @@ from script.hassfest.model import Integration
|
||||
class RuleValidationProtocol(Protocol):
|
||||
"""Protocol for rule validation."""
|
||||
|
||||
def validate(self, integration: Integration) -> list[str] | None:
|
||||
def validate(
|
||||
self, integration: Integration, *, rules_done: set[str]
|
||||
) -> list[str] | None:
|
||||
"""Validate a quality scale rule.
|
||||
|
||||
Returns error (if any).
|
||||
|
@ -17,7 +17,7 @@ def _has_unload_entry_function(module: ast.Module) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def validate(integration: Integration) -> list[str] | None:
|
||||
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
|
||||
"""Validate that the integration has a config flow."""
|
||||
|
||||
init_file = integration.path / "__init__.py"
|
||||
|
@ -6,7 +6,7 @@ https://developers.home-assistant.io/docs/core/integration-quality-scale/rules/c
|
||||
from script.hassfest.model import Integration
|
||||
|
||||
|
||||
def validate(integration: Integration) -> list[str] | None:
|
||||
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
|
||||
"""Validate that the integration implements config flow."""
|
||||
|
||||
if not integration.config_flow:
|
||||
|
@ -22,7 +22,7 @@ def _has_diagnostics_function(module: ast.Module) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def validate(integration: Integration) -> list[str] | None:
|
||||
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
|
||||
"""Validate that the integration implements diagnostics."""
|
||||
|
||||
diagnostics_file = integration.path / "diagnostics.py"
|
||||
|
@ -38,7 +38,7 @@ def _has_discovery_function(module: ast.Module) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def validate(integration: Integration) -> list[str] | None:
|
||||
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
|
||||
"""Validate that the integration implements diagnostics."""
|
||||
|
||||
config_flow_file = integration.path / "config_flow.py"
|
||||
|
@ -18,7 +18,7 @@ def _has_parallel_updates_defined(module: ast.Module) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def validate(integration: Integration) -> list[str] | None:
|
||||
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
|
||||
"""Validate that the integration sets PARALLEL_UPDATES constant."""
|
||||
|
||||
errors = []
|
||||
|
@ -17,7 +17,7 @@ def _has_step_reauth_function(module: ast.Module) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def validate(integration: Integration) -> list[str] | None:
|
||||
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
|
||||
"""Validate that the integration has a reauthentication flow."""
|
||||
|
||||
config_flow_file = integration.path / "config_flow.py"
|
||||
|
@ -17,7 +17,7 @@ def _has_step_reconfigure_function(module: ast.Module) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def validate(integration: Integration) -> list[str] | None:
|
||||
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
|
||||
"""Validate that the integration has a reconfiguration flow."""
|
||||
|
||||
config_flow_file = integration.path / "config_flow.py"
|
||||
|
@ -4,10 +4,31 @@ https://developers.home-assistant.io/docs/core/integration-quality-scale/rules/r
|
||||
"""
|
||||
|
||||
import ast
|
||||
import re
|
||||
|
||||
from homeassistant.const import Platform
|
||||
from script.hassfest import ast_parse_module
|
||||
from script.hassfest.model import Integration
|
||||
|
||||
_ANNOTATION_MATCH = re.compile(r"^[A-Za-z]+ConfigEntry$")
|
||||
_FUNCTIONS: dict[str, dict[str, int]] = {
|
||||
"__init__": { # based on ComponentProtocol
|
||||
"async_migrate_entry": 2,
|
||||
"async_remove_config_entry_device": 2,
|
||||
"async_remove_entry": 2,
|
||||
"async_setup_entry": 2,
|
||||
"async_unload_entry": 2,
|
||||
},
|
||||
"diagnostics": { # based on DiagnosticsProtocol
|
||||
"async_get_config_entry_diagnostics": 2,
|
||||
"async_get_device_diagnostics": 2,
|
||||
},
|
||||
}
|
||||
for platform in Platform: # based on EntityPlatformModule
|
||||
_FUNCTIONS[platform.value] = {
|
||||
"async_setup_entry": 2,
|
||||
}
|
||||
|
||||
|
||||
def _sets_runtime_data(
|
||||
async_setup_entry_function: ast.AsyncFunctionDef, config_entry_argument: ast.arg
|
||||
@ -25,30 +46,83 @@ def _sets_runtime_data(
|
||||
return False
|
||||
|
||||
|
||||
def _get_setup_entry_function(module: ast.Module) -> ast.AsyncFunctionDef | None:
|
||||
"""Get async_setup_entry function."""
|
||||
def _get_async_function(module: ast.Module, name: str) -> ast.AsyncFunctionDef | None:
|
||||
"""Get async function."""
|
||||
for item in module.body:
|
||||
if isinstance(item, ast.AsyncFunctionDef) and item.name == "async_setup_entry":
|
||||
if isinstance(item, ast.AsyncFunctionDef) and item.name == name:
|
||||
return item
|
||||
return None
|
||||
|
||||
|
||||
def validate(integration: Integration) -> list[str] | None:
|
||||
def _check_function_annotation(
|
||||
function: ast.AsyncFunctionDef, position: int
|
||||
) -> str | None:
|
||||
"""Ensure function uses CustomConfigEntry type annotation."""
|
||||
if len(function.args.args) < position:
|
||||
return f"{function.name} has incorrect signature"
|
||||
argument = function.args.args[position - 1]
|
||||
if not (
|
||||
(annotation := argument.annotation)
|
||||
and isinstance(annotation, ast.Name)
|
||||
and _ANNOTATION_MATCH.match(annotation.id)
|
||||
):
|
||||
return f"([+ strict-typing]) {function.name} does not use typed ConfigEntry"
|
||||
return None
|
||||
|
||||
|
||||
def _check_typed_config_entry(integration: Integration) -> list[str]:
|
||||
"""Ensure integration uses CustomConfigEntry type annotation."""
|
||||
errors: list[str] = []
|
||||
# Check body level function annotations
|
||||
for file, functions in _FUNCTIONS.items():
|
||||
module_file = integration.path / f"{file}.py"
|
||||
if not module_file.exists():
|
||||
continue
|
||||
module = ast_parse_module(module_file)
|
||||
for function, position in functions.items():
|
||||
if not (async_function := _get_async_function(module, function)):
|
||||
continue
|
||||
if error := _check_function_annotation(async_function, position):
|
||||
errors.append(f"{error} in {module_file}")
|
||||
|
||||
# Check config_flow annotations
|
||||
config_flow_file = integration.path / "config_flow.py"
|
||||
config_flow = ast_parse_module(config_flow_file)
|
||||
for node in config_flow.body:
|
||||
if not isinstance(node, ast.ClassDef):
|
||||
continue
|
||||
if any(
|
||||
isinstance(async_function, ast.FunctionDef)
|
||||
and async_function.name == "async_get_options_flow"
|
||||
and (error := _check_function_annotation(async_function, 1))
|
||||
for async_function in node.body
|
||||
):
|
||||
errors.append(f"{error} in {config_flow_file}")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
|
||||
"""Validate correct use of ConfigEntry.runtime_data."""
|
||||
init_file = integration.path / "__init__.py"
|
||||
init = ast_parse_module(init_file)
|
||||
|
||||
# Should not happen, but better to be safe
|
||||
if not (async_setup_entry := _get_setup_entry_function(init)):
|
||||
if not (async_setup_entry := _get_async_function(init, "async_setup_entry")):
|
||||
return [f"Could not find `async_setup_entry` in {init_file}"]
|
||||
if len(async_setup_entry.args.args) != 2:
|
||||
return [f"async_setup_entry has incorrect signature in {init_file}"]
|
||||
config_entry_argument = async_setup_entry.args.args[1]
|
||||
|
||||
errors: list[str] = []
|
||||
if not _sets_runtime_data(async_setup_entry, config_entry_argument):
|
||||
return [
|
||||
errors.append(
|
||||
"Integration does not set entry.runtime_data in async_setup_entry"
|
||||
f"({init_file})"
|
||||
]
|
||||
)
|
||||
|
||||
return None
|
||||
# Extra checks, if strict-typing is marked as done
|
||||
if "strict-typing" in rules_done:
|
||||
errors.extend(_check_typed_config_entry(integration))
|
||||
|
||||
return errors
|
||||
|
@ -24,7 +24,7 @@ def _strict_typing_components() -> set[str]:
|
||||
)
|
||||
|
||||
|
||||
def validate(integration: Integration) -> list[str] | None:
|
||||
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
|
||||
"""Validate that the integration has strict typing enabled."""
|
||||
|
||||
if integration.domain not in _strict_typing_components():
|
||||
|
@ -30,7 +30,7 @@ def _has_abort_unique_id_configured(module: ast.Module) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def validate(integration: Integration) -> list[str] | None:
|
||||
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
|
||||
"""Validate that the integration prevents duplicate devices."""
|
||||
|
||||
if integration.manifest.get("single_config_entry"):
|
||||
|
Loading…
x
Reference in New Issue
Block a user