From 7111fc47c4dcb44bc360989ab5708542faf736d7 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 3 Sep 2021 10:15:57 -0700 Subject: [PATCH] Better handle invalid trigger config (#55637) --- .../components/device_automation/trigger.py | 11 +++++--- .../components/hue/device_trigger.py | 16 +++++++----- homeassistant/scripts/check_config.py | 4 +++ tests/scripts/test_check_config.py | 25 +++++++++++-------- 4 files changed, 37 insertions(+), 19 deletions(-) diff --git a/homeassistant/components/device_automation/trigger.py b/homeassistant/components/device_automation/trigger.py index a1b6e53c5c3..1a63dcb9e9b 100644 --- a/homeassistant/components/device_automation/trigger.py +++ b/homeassistant/components/device_automation/trigger.py @@ -7,6 +7,8 @@ from homeassistant.components.device_automation import ( ) from homeassistant.const import CONF_DOMAIN +from .exceptions import InvalidDeviceAutomationConfig + # mypy: allow-untyped-defs, no-check-untyped-defs TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA) @@ -17,10 +19,13 @@ async def async_validate_trigger_config(hass, config): platform = await async_get_device_automation_platform( hass, config[CONF_DOMAIN], "trigger" ) - if hasattr(platform, "async_validate_trigger_config"): - return await getattr(platform, "async_validate_trigger_config")(hass, config) + if not hasattr(platform, "async_validate_trigger_config"): + return platform.TRIGGER_SCHEMA(config) - return platform.TRIGGER_SCHEMA(config) + try: + return await getattr(platform, "async_validate_trigger_config")(hass, config) + except InvalidDeviceAutomationConfig as err: + raise vol.Invalid(str(err) or "Invalid trigger configuration") from err async def async_attach_trigger(hass, config, action, automation_info): diff --git a/homeassistant/components/hue/device_trigger.py b/homeassistant/components/hue/device_trigger.py index ea91cd07d8c..77561e47dc5 100644 --- a/homeassistant/components/hue/device_trigger.py +++ b/homeassistant/components/hue/device_trigger.py @@ -118,12 +118,16 @@ async def async_validate_trigger_config(hass, config): trigger = (config[CONF_TYPE], config[CONF_SUBTYPE]) - if ( - not device - or device.model not in REMOTES - or trigger not in REMOTES[device.model] - ): - raise InvalidDeviceAutomationConfig + if not device: + raise InvalidDeviceAutomationConfig("Device {config[CONF_DEVICE_ID]} not found") + + if device.model not in REMOTES: + raise InvalidDeviceAutomationConfig( + f"Device model {device.model} is not a remote" + ) + + if trigger not in REMOTES[device.model]: + raise InvalidDeviceAutomationConfig("Device does not support trigger {trigger}") return config diff --git a/homeassistant/scripts/check_config.py b/homeassistant/scripts/check_config.py index 551f91b2b54..0ff339169a7 100644 --- a/homeassistant/scripts/check_config.py +++ b/homeassistant/scripts/check_config.py @@ -14,6 +14,7 @@ from unittest.mock import patch from homeassistant import core from homeassistant.config import get_default_config_dir from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import area_registry, device_registry, entity_registry from homeassistant.helpers.check_config import async_check_ha_config_file from homeassistant.util.yaml import Secrets import homeassistant.util.yaml.loader as yaml_loader @@ -229,6 +230,9 @@ async def async_check_config(config_dir): """Check the HA config.""" hass = core.HomeAssistant() hass.config.config_dir = config_dir + await area_registry.async_load(hass) + await device_registry.async_load(hass) + await entity_registry.async_load(hass) components = await async_check_ha_config_file(hass) await hass.async_stop(force=True) return components diff --git a/tests/scripts/test_check_config.py b/tests/scripts/test_check_config.py index ea6048dfc9e..1a96568f8ef 100644 --- a/tests/scripts/test_check_config.py +++ b/tests/scripts/test_check_config.py @@ -27,14 +27,23 @@ async def apply_stop_hass(stop_hass): """Make sure all hass are stopped.""" +@pytest.fixture +def mock_is_file(): + """Mock is_file.""" + # All files exist except for the old entity registry file + with patch( + "os.path.isfile", lambda path: not path.endswith("entity_registry.yaml") + ): + yield + + def normalize_yaml_files(check_dict): """Remove configuration path from ['yaml_files'].""" root = get_test_config_dir() return [key.replace(root, "...") for key in sorted(check_dict["yaml_files"].keys())] -@patch("os.path.isfile", return_value=True) -def test_bad_core_config(isfile_patch, loop): +def test_bad_core_config(mock_is_file, loop): """Test a bad core config setup.""" files = {YAML_CONFIG_FILE: BAD_CORE_CONFIG} with patch_yaml_files(files): @@ -43,8 +52,7 @@ def test_bad_core_config(isfile_patch, loop): assert res["except"]["homeassistant"][1] == {"unit_system": "bad"} -@patch("os.path.isfile", return_value=True) -def test_config_platform_valid(isfile_patch, loop): +def test_config_platform_valid(mock_is_file, loop): """Test a valid platform setup.""" files = {YAML_CONFIG_FILE: BASE_CONFIG + "light:\n platform: demo"} with patch_yaml_files(files): @@ -57,8 +65,7 @@ def test_config_platform_valid(isfile_patch, loop): assert len(res["yaml_files"]) == 1 -@patch("os.path.isfile", return_value=True) -def test_component_platform_not_found(isfile_patch, loop): +def test_component_platform_not_found(mock_is_file, loop): """Test errors if component or platform not found.""" # Make sure they don't exist files = {YAML_CONFIG_FILE: BASE_CONFIG + "beer:"} @@ -89,8 +96,7 @@ def test_component_platform_not_found(isfile_patch, loop): assert len(res["yaml_files"]) == 1 -@patch("os.path.isfile", return_value=True) -def test_secrets(isfile_patch, loop): +def test_secrets(mock_is_file, loop): """Test secrets config checking method.""" secrets_path = get_test_config_dir("secrets.yaml") @@ -121,8 +127,7 @@ def test_secrets(isfile_patch, loop): ] -@patch("os.path.isfile", return_value=True) -def test_package_invalid(isfile_patch, loop): +def test_package_invalid(mock_is_file, loop): """Test an invalid package.""" files = { YAML_CONFIG_FILE: BASE_CONFIG + (" packages:\n p1:\n" ' group: ["a"]')