Make hassfest strictly typed (#82091)

This commit is contained in:
Aarni Koskela 2022-11-23 20:05:31 +02:00 committed by GitHub
parent 0b5357de44
commit 97b40b5f49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 132 additions and 97 deletions

View File

@ -1,4 +1,6 @@
"""Validate manifests.""" """Validate manifests."""
from __future__ import annotations
import argparse import argparse
import pathlib import pathlib
import sys import sys
@ -55,7 +57,7 @@ ALL_PLUGIN_NAMES = [
] ]
def valid_integration_path(integration_path): def valid_integration_path(integration_path: pathlib.Path | str) -> pathlib.Path:
"""Test if it's a valid integration.""" """Test if it's a valid integration."""
path = pathlib.Path(integration_path) path = pathlib.Path(integration_path)
if not path.is_dir(): if not path.is_dir():
@ -124,7 +126,7 @@ def get_config() -> Config:
) )
def main(): def main() -> int:
"""Validate manifests.""" """Validate manifests."""
try: try:
config = get_config() config = get_config()
@ -218,7 +220,12 @@ def main():
return 1 return 1
def print_integrations_status(config, integrations, *, show_fixable_errors=True): def print_integrations_status(
config: Config,
integrations: list[Integration],
*,
show_fixable_errors: bool = True,
) -> None:
"""Print integration status.""" """Print integration status."""
for integration in sorted(integrations, key=lambda itg: itg.domain): for integration in sorted(integrations, key=lambda itg: itg.domain):
extra = f" - {integration.path}" if config.specific_integrations else "" extra = f" - {integration.path}" if config.specific_integrations else ""

View File

@ -41,7 +41,7 @@ def validate(integrations: dict[str, Integration], config: Config) -> None:
) )
def generate(integrations: dict[str, Integration], config: Config): def generate(integrations: dict[str, Integration], config: Config) -> None:
"""Generate application_credentials data.""" """Generate application_credentials data."""
application_credentials_path = ( application_credentials_path = (
config.root / "homeassistant/generated/application_credentials.py" config.root / "homeassistant/generated/application_credentials.py"

View File

@ -5,7 +5,7 @@ from .model import Config, Integration
from .serializer import format_python_namespace from .serializer import format_python_namespace
def generate_and_validate(integrations: list[dict[str, str]]): def generate_and_validate(integrations: dict[str, Integration]) -> str:
"""Validate and generate bluetooth data.""" """Validate and generate bluetooth data."""
match_list = [] match_list = []
@ -29,7 +29,7 @@ def generate_and_validate(integrations: list[dict[str, str]]):
) )
def validate(integrations: dict[str, Integration], config: Config): def validate(integrations: dict[str, Integration], config: Config) -> None:
"""Validate bluetooth file.""" """Validate bluetooth file."""
bluetooth_path = config.root / "homeassistant/generated/bluetooth.py" bluetooth_path = config.root / "homeassistant/generated/bluetooth.py"
config.cache["bluetooth"] = content = generate_and_validate(integrations) config.cache["bluetooth"] = content = generate_and_validate(integrations)
@ -48,7 +48,7 @@ def validate(integrations: dict[str, Integration], config: Config):
return return
def generate(integrations: dict[str, Integration], config: Config): def generate(integrations: dict[str, Integration], config: Config) -> None:
"""Generate bluetooth file.""" """Generate bluetooth file."""
bluetooth_path = config.root / "homeassistant/generated/bluetooth.py" bluetooth_path = config.root / "homeassistant/generated/bluetooth.py"
with open(str(bluetooth_path), "w") as fp: with open(str(bluetooth_path), "w") as fp:

View File

@ -55,7 +55,7 @@ def _validate_brand(
): ):
config.add_error( config.add_error(
"brand", "brand",
f"{brand.path.name}: Brand '{brand.brand['domain']}' " f"{brand.path.name}: Brand '{brand.domain}' "
f"is an integration but is missing in the brand's 'integrations' list'", f"is an integration but is missing in the brand's 'integrations' list'",
) )

View File

@ -42,7 +42,7 @@ REMOVE_CODEOWNERS = """
""" """
def generate_and_validate(integrations: dict[str, Integration], config: Config): def generate_and_validate(integrations: dict[str, Integration], config: Config) -> str:
"""Generate CODEOWNERS.""" """Generate CODEOWNERS."""
parts = [BASE] parts = [BASE]
@ -77,7 +77,7 @@ def generate_and_validate(integrations: dict[str, Integration], config: Config):
return "\n".join(parts) return "\n".join(parts)
def validate(integrations: dict[str, Integration], config: Config): def validate(integrations: dict[str, Integration], config: Config) -> None:
"""Validate CODEOWNERS.""" """Validate CODEOWNERS."""
codeowners_path = config.root / "CODEOWNERS" codeowners_path = config.root / "CODEOWNERS"
config.cache["codeowners"] = content = generate_and_validate(integrations, config) config.cache["codeowners"] = content = generate_and_validate(integrations, config)
@ -95,7 +95,7 @@ def validate(integrations: dict[str, Integration], config: Config):
return return
def generate(integrations: dict[str, Integration], config: Config): def generate(integrations: dict[str, Integration], config: Config) -> None:
"""Generate CODEOWNERS.""" """Generate CODEOWNERS."""
codeowners_path = config.root / "CODEOWNERS" codeowners_path = config.root / "CODEOWNERS"
with open(str(codeowners_path), "w") as fp: with open(str(codeowners_path), "w") as fp:

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import json import json
import pathlib import pathlib
from typing import Any
from .brand import validate as validate_brands from .brand import validate as validate_brands
from .model import Brand, Config, Integration from .model import Brand, Config, Integration
@ -11,12 +12,12 @@ from .serializer import format_python_namespace
UNIQUE_ID_IGNORE = {"huawei_lte", "mqtt", "adguard"} UNIQUE_ID_IGNORE = {"huawei_lte", "mqtt", "adguard"}
def _validate_integration(config: Config, integration: Integration): def _validate_integration(config: Config, integration: Integration) -> None:
"""Validate config flow of an integration.""" """Validate config flow of an integration."""
config_flow_file = integration.path / "config_flow.py" config_flow_file = integration.path / "config_flow.py"
if not config_flow_file.is_file(): if not config_flow_file.is_file():
if integration.manifest.get("config_flow"): if (integration.manifest or {}).get("config_flow"):
integration.add_error( integration.add_error(
"config_flow", "config_flow",
"Config flows need to be defined in the file config_flow.py", "Config flows need to be defined in the file config_flow.py",
@ -60,9 +61,9 @@ def _validate_integration(config: Config, integration: Integration):
) )
def _generate_and_validate(integrations: dict[str, Integration], config: Config): def _generate_and_validate(integrations: dict[str, Integration], config: Config) -> str:
"""Validate and generate config flow data.""" """Validate and generate config flow data."""
domains = { domains: dict[str, list[str]] = {
"integration": [], "integration": [],
"helper": [], "helper": [],
} }
@ -84,9 +85,9 @@ def _generate_and_validate(integrations: dict[str, Integration], config: Config)
def _populate_brand_integrations( def _populate_brand_integrations(
integration_data: dict, integration_data: dict[str, Any],
integrations: dict[str, Integration], integrations: dict[str, Integration],
brand_metadata: dict, brand_metadata: dict[str, Any],
sub_integrations: list[str], sub_integrations: list[str],
) -> None: ) -> None:
"""Add referenced integrations to a brand's metadata.""" """Add referenced integrations to a brand's metadata."""
@ -99,7 +100,7 @@ def _populate_brand_integrations(
"system", "system",
): ):
continue continue
metadata = { metadata: dict[str, Any] = {
"integration_type": integration.integration_type, "integration_type": integration.integration_type,
} }
# Always set the config_flow key to avoid breaking the frontend # Always set the config_flow key to avoid breaking the frontend
@ -119,11 +120,13 @@ def _populate_brand_integrations(
def _generate_integrations( def _generate_integrations(
brands: dict[str, Brand], integrations: dict[str, Integration], config: Config brands: dict[str, Brand],
): integrations: dict[str, Integration],
config: Config,
) -> str:
"""Generate integrations data.""" """Generate integrations data."""
result = { result: dict[str, Any] = {
"integration": {}, "integration": {},
"helper": {}, "helper": {},
"translated_name": set(), "translated_name": set(),
@ -154,7 +157,7 @@ def _generate_integrations(
# Generate the config flow index # Generate the config flow index
for domain in sorted(primary_domains): for domain in sorted(primary_domains):
metadata = {} metadata: dict[str, Any] = {}
if brand := brands.get(domain): if brand := brands.get(domain):
metadata["name"] = brand.name metadata["name"] = brand.name
@ -199,7 +202,7 @@ def _generate_integrations(
) )
def validate(integrations: dict[str, Integration], config: Config): def validate(integrations: dict[str, Integration], config: Config) -> None:
"""Validate config flow file.""" """Validate config flow file."""
config_flow_path = config.root / "homeassistant/generated/config_flows.py" config_flow_path = config.root / "homeassistant/generated/config_flows.py"
integrations_path = config.root / "homeassistant/generated/integrations.json" integrations_path = config.root / "homeassistant/generated/integrations.json"
@ -233,7 +236,7 @@ def validate(integrations: dict[str, Integration], config: Config):
) )
def generate(integrations: dict[str, Integration], config: Config): def generate(integrations: dict[str, Integration], config: Config) -> None:
"""Generate config flow file.""" """Generate config flow file."""
config_flow_path = config.root / "homeassistant/generated/config_flows.py" config_flow_path = config.root / "homeassistant/generated/config_flows.py"
integrations_path = config.root / "homeassistant/generated/integrations.json" integrations_path = config.root / "homeassistant/generated/integrations.json"

View File

@ -30,7 +30,7 @@ ALLOWED_IGNORE_VIOLATIONS = {
} }
def validate(integrations: dict[str, Integration], config: Config): def validate(integrations: dict[str, Integration], config: Config) -> None:
"""Validate coverage.""" """Validate coverage."""
coverage_path = config.root / ".coveragerc" coverage_path = config.root / ".coveragerc"

View File

@ -7,19 +7,19 @@ from pathlib import Path
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.requirements import DISCOVERY_INTEGRATIONS from homeassistant.requirements import DISCOVERY_INTEGRATIONS
from .model import Integration from .model import Config, Integration
class ImportCollector(ast.NodeVisitor): class ImportCollector(ast.NodeVisitor):
"""Collect all integrations referenced.""" """Collect all integrations referenced."""
def __init__(self, integration: Integration): def __init__(self, integration: Integration) -> None:
"""Initialize the import collector.""" """Initialize the import collector."""
self.integration = integration self.integration = integration
self.referenced: dict[Path, set[str]] = {} self.referenced: dict[Path, set[str]] = {}
# Current file or dir we're inspecting # Current file or dir we're inspecting
self._cur_fil_dir = None self._cur_fil_dir: Path | None = None
def collect(self) -> None: def collect(self) -> None:
"""Collect imports from a source file.""" """Collect imports from a source file."""
@ -32,11 +32,12 @@ class ImportCollector(ast.NodeVisitor):
self.visit(ast.parse(fil.read_text())) self.visit(ast.parse(fil.read_text()))
self._cur_fil_dir = None self._cur_fil_dir = None
def _add_reference(self, reference_domain: str): def _add_reference(self, reference_domain: str) -> None:
"""Add a reference.""" """Add a reference."""
assert self._cur_fil_dir
self.referenced[self._cur_fil_dir].add(reference_domain) self.referenced[self._cur_fil_dir].add(reference_domain)
def visit_ImportFrom(self, node): def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
"""Visit ImportFrom node.""" """Visit ImportFrom node."""
if node.module is None: if node.module is None:
return return
@ -59,14 +60,14 @@ class ImportCollector(ast.NodeVisitor):
for name_node in node.names: for name_node in node.names:
self._add_reference(name_node.name) self._add_reference(name_node.name)
def visit_Import(self, node): def visit_Import(self, node: ast.Import) -> None:
"""Visit Import node.""" """Visit Import node."""
# import homeassistant.components.hue as hue # import homeassistant.components.hue as hue
for name_node in node.names: for name_node in node.names:
if name_node.name.startswith("homeassistant.components."): if name_node.name.startswith("homeassistant.components."):
self._add_reference(name_node.name.split(".")[2]) self._add_reference(name_node.name.split(".")[2])
def visit_Attribute(self, node): def visit_Attribute(self, node: ast.Attribute) -> None:
"""Visit Attribute node.""" """Visit Attribute node."""
# hass.components.hue.async_create() # hass.components.hue.async_create()
# Name(id=hass) # Name(id=hass)
@ -156,15 +157,16 @@ IGNORE_VIOLATIONS = {
def calc_allowed_references(integration: Integration) -> set[str]: def calc_allowed_references(integration: Integration) -> set[str]:
"""Return a set of allowed references.""" """Return a set of allowed references."""
manifest = integration.manifest
allowed_references = ( allowed_references = (
ALLOWED_USED_COMPONENTS ALLOWED_USED_COMPONENTS
| set(integration.manifest.get("dependencies", [])) | set(manifest.get("dependencies", []))
| set(integration.manifest.get("after_dependencies", [])) | set(manifest.get("after_dependencies", []))
) )
# Discovery requirements are ok if referenced in manifest # Discovery requirements are ok if referenced in manifest
for check_domain, to_check in DISCOVERY_INTEGRATIONS.items(): for check_domain, to_check in DISCOVERY_INTEGRATIONS.items():
if any(check in integration.manifest for check in to_check): if any(check in manifest for check in to_check):
allowed_references.add(check_domain) allowed_references.add(check_domain)
return allowed_references return allowed_references
@ -174,7 +176,7 @@ def find_non_referenced_integrations(
integrations: dict[str, Integration], integrations: dict[str, Integration],
integration: Integration, integration: Integration,
references: dict[Path, set[str]], references: dict[Path, set[str]],
): ) -> set[str]:
"""Find intergrations that are not allowed to be referenced.""" """Find intergrations that are not allowed to be referenced."""
allowed_references = calc_allowed_references(integration) allowed_references = calc_allowed_references(integration)
referenced = set() referenced = set()
@ -219,8 +221,9 @@ def find_non_referenced_integrations(
def validate_dependencies( def validate_dependencies(
integrations: dict[str, Integration], integration: Integration integrations: dict[str, Integration],
): integration: Integration,
) -> None:
"""Validate all dependencies.""" """Validate all dependencies."""
# Some integrations are allowed to have violations. # Some integrations are allowed to have violations.
if integration.domain in IGNORE_VIOLATIONS: if integration.domain in IGNORE_VIOLATIONS:
@ -242,7 +245,7 @@ def validate_dependencies(
) )
def validate(integrations: dict[str, Integration], config): def validate(integrations: dict[str, Integration], config: Config) -> None:
"""Handle dependencies for integrations.""" """Handle dependencies for integrations."""
# check for non-existing dependencies # check for non-existing dependencies
for integration in integrations.values(): for integration in integrations.values():

View File

@ -5,7 +5,7 @@ from .model import Config, Integration
from .serializer import format_python_namespace from .serializer import format_python_namespace
def generate_and_validate(integrations: list[dict[str, str]]): def generate_and_validate(integrations: dict[str, Integration]) -> str:
"""Validate and generate dhcp data.""" """Validate and generate dhcp data."""
match_list = [] match_list = []
@ -29,7 +29,7 @@ def generate_and_validate(integrations: list[dict[str, str]]):
) )
def validate(integrations: dict[str, Integration], config: Config): def validate(integrations: dict[str, Integration], config: Config) -> None:
"""Validate dhcp file.""" """Validate dhcp file."""
dhcp_path = config.root / "homeassistant/generated/dhcp.py" dhcp_path = config.root / "homeassistant/generated/dhcp.py"
config.cache["dhcp"] = content = generate_and_validate(integrations) config.cache["dhcp"] = content = generate_and_validate(integrations)
@ -48,7 +48,7 @@ def validate(integrations: dict[str, Integration], config: Config):
return return
def generate(integrations: dict[str, Integration], config: Config): def generate(integrations: dict[str, Integration], config: Config) -> None:
"""Generate dhcp file.""" """Generate dhcp file."""
dhcp_path = config.root / "homeassistant/generated/dhcp.py" dhcp_path = config.root / "homeassistant/generated/dhcp.py"
with open(str(dhcp_path), "w") as fp: with open(str(dhcp_path), "w") as fp:

View File

@ -3,10 +3,10 @@ from __future__ import annotations
import json import json
from .model import Integration from .model import Config, Integration
def validate_json_files(integration: Integration): def validate_json_files(integration: Integration) -> None:
"""Validate JSON files for integration.""" """Validate JSON files for integration."""
for json_file in integration.path.glob("**/*.json"): for json_file in integration.path.glob("**/*.json"):
if not json_file.is_file(): if not json_file.is_file():
@ -18,10 +18,8 @@ def validate_json_files(integration: Integration):
relative_path = json_file.relative_to(integration.path) relative_path = json_file.relative_to(integration.path)
integration.add_error("json", f"Invalid JSON file {relative_path}") integration.add_error("json", f"Invalid JSON file {relative_path}")
return
def validate(integrations: dict[str, Integration], config: Config) -> None:
def validate(integrations: dict[str, Integration], config):
"""Handle JSON files inside integrations.""" """Handle JSON files inside integrations."""
if not config.specific_integrations: if not config.specific_integrations:
return return

View File

@ -119,7 +119,7 @@ def documentation_url(value: str) -> str:
return value return value
def verify_lowercase(value: str): def verify_lowercase(value: str) -> str:
"""Verify a value is lowercase.""" """Verify a value is lowercase."""
if value.lower() != value: if value.lower() != value:
raise vol.Invalid("Value needs to be lowercase") raise vol.Invalid("Value needs to be lowercase")
@ -127,7 +127,7 @@ def verify_lowercase(value: str):
return value return value
def verify_uppercase(value: str): def verify_uppercase(value: str) -> str:
"""Verify a value is uppercase.""" """Verify a value is uppercase."""
if value.upper() != value: if value.upper() != value:
raise vol.Invalid("Value needs to be uppercase") raise vol.Invalid("Value needs to be uppercase")
@ -135,7 +135,7 @@ def verify_uppercase(value: str):
return value return value
def verify_version(value: str): def verify_version(value: str) -> str:
"""Verify the version.""" """Verify the version."""
try: try:
AwesomeVersion( AwesomeVersion(
@ -153,7 +153,7 @@ def verify_version(value: str):
return value return value
def verify_wildcard(value: str): def verify_wildcard(value: str) -> str:
"""Verify the matcher contains a wildcard.""" """Verify the matcher contains a wildcard."""
if "*" not in value: if "*" not in value:
raise vol.Invalid(f"'{value}' needs to contain a wildcard matcher") raise vol.Invalid(f"'{value}' needs to contain a wildcard matcher")
@ -286,13 +286,13 @@ CUSTOM_INTEGRATION_MANIFEST_SCHEMA = INTEGRATION_MANIFEST_SCHEMA.extend(
) )
def validate_version(integration: Integration): def validate_version(integration: Integration) -> None:
""" """
Validate the version of the integration. Validate the version of the integration.
Will be removed when the version key is no longer optional for custom integrations. Will be removed when the version key is no longer optional for custom integrations.
""" """
if not integration.manifest.get("version"): if not (integration.manifest and integration.manifest.get("version")):
integration.add_error("manifest", "No 'version' key in the manifest file.") integration.add_error("manifest", "No 'version' key in the manifest file.")
return return

View File

@ -25,7 +25,7 @@ class Error:
class Config: class Config:
"""Config for the run.""" """Config for the run."""
specific_integrations: pathlib.Path | None = attr.ib() specific_integrations: list[pathlib.Path] | None = attr.ib()
root: pathlib.Path = attr.ib() root: pathlib.Path = attr.ib()
action: str = attr.ib() action: str = attr.ib()
requirements: bool = attr.ib() requirements: bool = attr.ib()

View File

@ -7,7 +7,7 @@ from .model import Config, Integration
from .serializer import format_python_namespace from .serializer import format_python_namespace
def generate_and_validate(integrations: dict[str, Integration]): def generate_and_validate(integrations: dict[str, Integration]) -> str:
"""Validate and generate MQTT data.""" """Validate and generate MQTT data."""
data = defaultdict(list) data = defaultdict(list)
@ -29,7 +29,7 @@ def generate_and_validate(integrations: dict[str, Integration]):
return format_python_namespace({"MQTT": data}) return format_python_namespace({"MQTT": data})
def validate(integrations: dict[str, Integration], config: Config): def validate(integrations: dict[str, Integration], config: Config) -> None:
"""Validate MQTT file.""" """Validate MQTT file."""
mqtt_path = config.root / "homeassistant/generated/mqtt.py" mqtt_path = config.root / "homeassistant/generated/mqtt.py"
config.cache["mqtt"] = content = generate_and_validate(integrations) config.cache["mqtt"] = content = generate_and_validate(integrations)
@ -44,10 +44,9 @@ def validate(integrations: dict[str, Integration], config: Config):
"File mqtt.py is not up to date. Run python3 -m script.hassfest", "File mqtt.py is not up to date. Run python3 -m script.hassfest",
fixable=True, fixable=True,
) )
return
def generate(integrations: dict[str, Integration], config: Config): def generate(integrations: dict[str, Integration], config: Config) -> None:
"""Generate MQTT file.""" """Generate MQTT file."""
mqtt_path = config.root / "homeassistant/generated/mqtt.py" mqtt_path = config.root / "homeassistant/generated/mqtt.py"
with open(str(mqtt_path), "w") as fp: with open(str(mqtt_path), "w") as fp:

View File

@ -8,6 +8,7 @@ import os
import re import re
import subprocess import subprocess
import sys import sys
from typing import Any
from awesomeversion import AwesomeVersion, AwesomeVersionStrategy from awesomeversion import AwesomeVersion, AwesomeVersionStrategy
from stdlib_list import stdlib_list from stdlib_list import stdlib_list
@ -53,7 +54,7 @@ IGNORE_VIOLATIONS = {
} }
def validate(integrations: dict[str, Integration], config: Config): def validate(integrations: dict[str, Integration], config: Config) -> None:
"""Handle requirements for integrations.""" """Handle requirements for integrations."""
# Check if we are doing format-only validation. # Check if we are doing format-only validation.
if not config.requirements: if not config.requirements:
@ -63,7 +64,7 @@ def validate(integrations: dict[str, Integration], config: Config):
# check for incompatible requirements # check for incompatible requirements
disable_tqdm = config.specific_integrations or os.environ.get("CI", False) disable_tqdm = bool(config.specific_integrations or os.environ.get("CI"))
for integration in tqdm(integrations.values(), disable=disable_tqdm): for integration in tqdm(integrations.values(), disable=disable_tqdm):
if not integration.manifest: if not integration.manifest:
@ -87,7 +88,13 @@ def validate_requirements_format(integration: Integration) -> bool:
) )
continue continue
pkg, sep, version = PACKAGE_REGEX.match(req).groups() if not (match := PACKAGE_REGEX.match(req)):
integration.add_error(
"requirements",
f'Requirement "{req}" does not match package regex pattern',
)
continue
pkg, sep, version = match.groups()
if integration.core and sep != "==": if integration.core and sep != "==":
integration.add_error( integration.add_error(
@ -115,7 +122,7 @@ def validate_requirements_format(integration: Integration) -> bool:
return len(integration.errors) == start_errors return len(integration.errors) == start_errors
def validate_requirements(integration: Integration): def validate_requirements(integration: Integration) -> None:
"""Validate requirements.""" """Validate requirements."""
if not validate_requirements_format(integration): if not validate_requirements_format(integration):
return return
@ -167,7 +174,7 @@ def validate_requirements(integration: Integration):
@cache @cache
def get_pipdeptree(): def get_pipdeptree() -> dict[str, dict[str, Any]]:
"""Get pipdeptree output. Cached on first invocation. """Get pipdeptree output. Cached on first invocation.
{ {
@ -254,7 +261,7 @@ def install_requirements(integration: Integration, requirements: set[str]) -> bo
if normalized and "==" in requirement_arg: if normalized and "==" in requirement_arg:
ver = requirement_arg.split("==")[-1] ver = requirement_arg.split("==")[-1]
item = deptree.get(normalized) item = deptree.get(normalized)
is_installed = item and item["installed_version"] == ver is_installed = bool(item and item["installed_version"] == ver)
if not is_installed: if not is_installed:
try: try:

View File

@ -5,6 +5,7 @@ from collections.abc import Collection, Iterable, Mapping
from typing import Any from typing import Any
import black import black
from black.mode import Mode
DEFAULT_GENERATOR = "script.hassfest" DEFAULT_GENERATOR = "script.hassfest"
@ -13,7 +14,7 @@ def _wrap_items(
items: Iterable[str], items: Iterable[str],
opener: str, opener: str,
closer: str, closer: str,
sort=False, sort: bool = False,
) -> str: ) -> str:
"""Wrap pre-formatted Python reprs in braces, optionally sorting them.""" """Wrap pre-formatted Python reprs in braces, optionally sorting them."""
# The trailing comma is imperative so Black doesn't format some items # The trailing comma is imperative so Black doesn't format some items
@ -23,7 +24,7 @@ def _wrap_items(
return f"{opener}{','.join(items)},{closer}" return f"{opener}{','.join(items)},{closer}"
def _mapping_to_str(data: Mapping) -> str: def _mapping_to_str(data: Mapping[Any, Any]) -> str:
"""Return a string representation of a mapping.""" """Return a string representation of a mapping."""
return _wrap_items( return _wrap_items(
(f"{to_string(key)}:{to_string(value)}" for key, value in data.items()), (f"{to_string(key)}:{to_string(value)}" for key, value in data.items()),
@ -34,7 +35,10 @@ def _mapping_to_str(data: Mapping) -> str:
def _collection_to_str( def _collection_to_str(
data: Collection, opener: str = "[", closer: str = "]", sort=False data: Collection[Any],
opener: str = "[",
closer: str = "]",
sort: bool = False,
) -> str: ) -> str:
"""Return a string representation of a collection.""" """Return a string representation of a collection."""
items = (to_string(value) for value in data) items = (to_string(value) for value in data)
@ -66,7 +70,7 @@ To update, run python3 -m {generator}
{content} {content}
""" """
return black.format_str(content.strip(), mode=black.Mode()) return black.format_str(content.strip(), mode=Mode())
def format_python_namespace( def format_python_namespace(

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import pathlib import pathlib
import re import re
from typing import Any
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
@ -12,10 +13,10 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv, selector from homeassistant.helpers import config_validation as cv, selector
from homeassistant.util.yaml import load_yaml from homeassistant.util.yaml import load_yaml
from .model import Integration from .model import Config, Integration
def exists(value): def exists(value: Any) -> Any:
"""Check if value exists.""" """Check if value exists."""
if value is None: if value is None:
raise vol.Invalid("Value cannot be None") raise vol.Invalid("Value cannot be None")
@ -63,7 +64,7 @@ def grep_dir(path: pathlib.Path, glob_pattern: str, search_pattern: str) -> bool
return False return False
def validate_services(integration: Integration): def validate_services(integration: Integration) -> None:
"""Validate services.""" """Validate services."""
try: try:
data = load_yaml(str(integration.path / "services.yaml")) data = load_yaml(str(integration.path / "services.yaml"))
@ -92,7 +93,7 @@ def validate_services(integration: Integration):
) )
def validate(integrations: dict[str, Integration], config): def validate(integrations: dict[str, Integration], config: Config) -> None:
"""Handle dependencies for integrations.""" """Handle dependencies for integrations."""
# check services.yaml is cool # check services.yaml is cool
for integration in integrations.values(): for integration in integrations.values():

View File

@ -7,7 +7,7 @@ from .model import Config, Integration
from .serializer import format_python_namespace from .serializer import format_python_namespace
def generate_and_validate(integrations: dict[str, Integration]): def generate_and_validate(integrations: dict[str, Integration]) -> str:
"""Validate and generate ssdp data.""" """Validate and generate ssdp data."""
data = defaultdict(list) data = defaultdict(list)
@ -29,7 +29,7 @@ def generate_and_validate(integrations: dict[str, Integration]):
return format_python_namespace({"SSDP": data}) return format_python_namespace({"SSDP": data})
def validate(integrations: dict[str, Integration], config: Config): def validate(integrations: dict[str, Integration], config: Config) -> None:
"""Validate ssdp file.""" """Validate ssdp file."""
ssdp_path = config.root / "homeassistant/generated/ssdp.py" ssdp_path = config.root / "homeassistant/generated/ssdp.py"
config.cache["ssdp"] = content = generate_and_validate(integrations) config.cache["ssdp"] = content = generate_and_validate(integrations)
@ -44,10 +44,9 @@ def validate(integrations: dict[str, Integration], config: Config):
"File ssdp.py is not up to date. Run python3 -m script.hassfest", "File ssdp.py is not up to date. Run python3 -m script.hassfest",
fixable=True, fixable=True,
) )
return
def generate(integrations: dict[str, Integration], config: Config): def generate(integrations: dict[str, Integration], config: Config) -> None:
"""Generate ssdp file.""" """Generate ssdp file."""
ssdp_path = config.root / "homeassistant/generated/ssdp.py" ssdp_path = config.root / "homeassistant/generated/ssdp.py"
with open(str(ssdp_path), "w") as fp: with open(str(ssdp_path), "w") as fp:

View File

@ -5,6 +5,7 @@ from functools import partial
from itertools import chain from itertools import chain
import json import json
import re import re
from typing import Any
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
@ -51,7 +52,7 @@ MOVED_TRANSLATIONS_DIRECTORY_MSG = (
) )
def allow_name_translation(integration: Integration): def allow_name_translation(integration: Integration) -> bool:
"""Validate that the translation name is not the same as the integration name.""" """Validate that the translation name is not the same as the integration name."""
# Only enforce for core because custom integrations can't be # Only enforce for core because custom integrations can't be
# added to allow list. # added to allow list.
@ -74,7 +75,11 @@ def check_translations_directory_name(integration: Integration) -> None:
integration.add_error("translations", MOVED_TRANSLATIONS_DIRECTORY_MSG) integration.add_error("translations", MOVED_TRANSLATIONS_DIRECTORY_MSG)
def find_references(strings, prefix, found): def find_references(
strings: dict[str, Any],
prefix: str,
found: list[dict[str, str]],
) -> None:
"""Find references.""" """Find references."""
for key, value in strings.items(): for key, value in strings.items():
if isinstance(value, dict): if isinstance(value, dict):
@ -87,7 +92,11 @@ def find_references(strings, prefix, found):
found.append({"source": f"{prefix}::{key}", "ref": match.groups()[0]}) found.append({"source": f"{prefix}::{key}", "ref": match.groups()[0]})
def removed_title_validator(config, integration, value): def removed_title_validator(
config: Config,
integration: Integration,
value: Any,
) -> Any:
"""Mark removed title.""" """Mark removed title."""
if not config.specific_integrations: if not config.specific_integrations:
raise vol.Invalid(REMOVED_TITLE_MSG) raise vol.Invalid(REMOVED_TITLE_MSG)
@ -97,7 +106,7 @@ def removed_title_validator(config, integration, value):
return value return value
def lowercase_validator(value): def lowercase_validator(value: str) -> str:
"""Validate value is lowercase.""" """Validate value is lowercase."""
if value.lower() != value: if value.lower() != value:
raise vol.Invalid("Needs to be lowercase") raise vol.Invalid("Needs to be lowercase")
@ -112,7 +121,7 @@ def gen_data_entry_schema(
flow_title: int, flow_title: int,
require_step_title: bool, require_step_title: bool,
mandatory_description: str | None = None, mandatory_description: str | None = None,
): ) -> vol.All:
"""Generate a data entry schema.""" """Generate a data entry schema."""
step_title_class = vol.Required if require_step_title else vol.Optional step_title_class = vol.Required if require_step_title else vol.Optional
schema = { schema = {
@ -138,7 +147,7 @@ def gen_data_entry_schema(
removed_title_validator, config, integration removed_title_validator, config, integration
) )
def data_description_validator(value): def data_description_validator(value: dict[str, Any]) -> dict[str, Any]:
"""Validate data description.""" """Validate data description."""
for step_info in value["step"].values(): for step_info in value["step"].values():
if "data_description" not in step_info: if "data_description" not in step_info:
@ -154,7 +163,7 @@ def gen_data_entry_schema(
if mandatory_description is not None: if mandatory_description is not None:
def validate_description_set(value): def validate_description_set(value: dict[str, Any]) -> dict[str, Any]:
"""Validate description is set.""" """Validate description is set."""
steps = value["step"] steps = value["step"]
if mandatory_description not in steps: if mandatory_description not in steps:
@ -169,7 +178,7 @@ def gen_data_entry_schema(
if not allow_name_translation(integration): if not allow_name_translation(integration):
def name_validator(value): def name_validator(value: dict[str, Any]) -> dict[str, Any]:
"""Validate name.""" """Validate name."""
for step_id, info in value["step"].items(): for step_id, info in value["step"].items():
if info.get("title") == integration.name: if info.get("title") == integration.name:
@ -250,7 +259,7 @@ def gen_strings_schema(config: Config, integration: Integration) -> vol.Schema:
) )
def gen_auth_schema(config: Config, integration: Integration): def gen_auth_schema(config: Config, integration: Integration) -> vol.Schema:
"""Generate auth schema.""" """Generate auth schema."""
return vol.Schema( return vol.Schema(
{ {
@ -266,7 +275,7 @@ def gen_auth_schema(config: Config, integration: Integration):
) )
def gen_platform_strings_schema(config: Config, integration: Integration): def gen_platform_strings_schema(config: Config, integration: Integration) -> vol.Schema:
"""Generate platform strings schema like strings.sensor.json. """Generate platform strings schema like strings.sensor.json.
Example of valid data: Example of valid data:
@ -279,7 +288,7 @@ def gen_platform_strings_schema(config: Config, integration: Integration):
} }
""" """
def device_class_validator(value): def device_class_validator(value: str) -> str:
"""Key validator for platform states. """Key validator for platform states.
Platform states are only allowed to provide states for device classes they prefix. Platform states are only allowed to provide states for device classes they prefix.
@ -313,8 +322,10 @@ ONBOARDING_SCHEMA = vol.Schema({vol.Required("area"): {str: cv.string_with_no_ht
def validate_translation_file( # noqa: C901 def validate_translation_file( # noqa: C901
config: Config, integration: Integration, all_strings config: Config,
): integration: Integration,
all_strings: dict[str, Any] | None,
) -> None:
"""Validate translation files for integration.""" """Validate translation files for integration."""
if config.specific_integrations: if config.specific_integrations:
check_translations_directory_name(integration) check_translations_directory_name(integration)
@ -326,7 +337,7 @@ def validate_translation_file( # noqa: C901
# Only English needs to be always complete # Only English needs to be always complete
strings_files.append(integration.path / "translations/en.json") strings_files.append(integration.path / "translations/en.json")
references = [] references: list[dict[str, str]] = []
if integration.domain == "auth": if integration.domain == "auth":
strings_schema = gen_auth_schema(config, integration) strings_schema = gen_auth_schema(config, integration)
@ -405,6 +416,9 @@ def validate_translation_file( # noqa: C901
if config.specific_integrations: if config.specific_integrations:
return return
if not all_strings: # Nothing to validate against
return
# Validate references # Validate references
for reference in references: for reference in references:
parts = reference["ref"].split("::") parts = reference["ref"].split("::")
@ -421,12 +435,12 @@ def validate_translation_file( # noqa: C901
) )
def validate(integrations: dict[str, Integration], config: Config): def validate(integrations: dict[str, Integration], config: Config) -> None:
"""Handle JSON files inside integrations.""" """Handle JSON files inside integrations."""
if config.specific_integrations: if config.specific_integrations:
all_strings = None all_strings = None
else: else:
all_strings = upload.generate_upload_data() all_strings = upload.generate_upload_data() # type: ignore[no-untyped-call]
for integration in integrations.values(): for integration in integrations.values():
validate_translation_file(config, integration, all_strings) validate_translation_file(config, integration, all_strings)

View File

@ -5,7 +5,7 @@ from .model import Config, Integration
from .serializer import format_python_namespace from .serializer import format_python_namespace
def generate_and_validate(integrations: list[dict[str, str]]) -> str: def generate_and_validate(integrations: dict[str, Integration]) -> str:
"""Validate and generate usb data.""" """Validate and generate usb data."""
match_list = [] match_list = []

View File

@ -9,10 +9,10 @@ from .model import Config, Integration
from .serializer import format_python_namespace from .serializer import format_python_namespace
def generate_and_validate(integrations: dict[str, Integration]): def generate_and_validate(integrations: dict[str, Integration]) -> str:
"""Validate and generate zeroconf data.""" """Validate and generate zeroconf data."""
service_type_dict = defaultdict(list) service_type_dict = defaultdict(list)
homekit_dict = {} homekit_dict: dict[str, str] = {}
for domain in sorted(integrations): for domain in sorted(integrations):
integration = integrations[domain] integration = integrations[domain]
@ -77,7 +77,7 @@ def generate_and_validate(integrations: dict[str, Integration]):
) )
def validate(integrations: dict[str, Integration], config: Config): def validate(integrations: dict[str, Integration], config: Config) -> None:
"""Validate zeroconf file.""" """Validate zeroconf file."""
zeroconf_path = config.root / "homeassistant/generated/zeroconf.py" zeroconf_path = config.root / "homeassistant/generated/zeroconf.py"
config.cache["zeroconf"] = content = generate_and_validate(integrations) config.cache["zeroconf"] = content = generate_and_validate(integrations)
@ -96,7 +96,7 @@ def validate(integrations: dict[str, Integration], config: Config):
return return
def generate(integrations: dict[str, Integration], config: Config): def generate(integrations: dict[str, Integration], config: Config) -> None:
"""Generate zeroconf file.""" """Generate zeroconf file."""
zeroconf_path = config.root / "homeassistant/generated/zeroconf.py" zeroconf_path = config.root / "homeassistant/generated/zeroconf.py"
with open(str(zeroconf_path), "w") as fp: with open(str(zeroconf_path), "w") as fp: