diff --git a/homeassistant/components/nut/device_action.py b/homeassistant/components/nut/device_action.py index 86f7fe5a7e6..c622e63a12c 100644 --- a/homeassistant/components/nut/device_action.py +++ b/homeassistant/components/nut/device_action.py @@ -2,15 +2,18 @@ from __future__ import annotations +from typing import cast + import voluptuous as vol from homeassistant.components.device_automation import InvalidDeviceAutomationConfig +from homeassistant.config_entries import ConfigEntryState from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_TYPE from homeassistant.core import Context, HomeAssistant from homeassistant.helpers import config_validation as cv, device_registry as dr from homeassistant.helpers.typing import ConfigType, TemplateVarsType -from . import NutRuntimeData +from . import NutConfigEntry, NutRuntimeData from .const import DOMAIN, INTEGRATION_SUPPORTED_COMMANDS ACTION_TYPES = {cmd.replace(".", "_") for cmd in INTEGRATION_SUPPORTED_COMMANDS} @@ -48,16 +51,11 @@ async def async_call_action_from_config( device_action_name: str = config[CONF_TYPE] command_name = _get_command_name(device_action_name) device_id: str = config[CONF_DEVICE_ID] - runtime_data = _get_runtime_data_from_device_id(hass, device_id) - if not runtime_data: - raise InvalidDeviceAutomationConfig( - translation_domain=DOMAIN, - translation_key="device_invalid", - translation_placeholders={ - "device_id": device_id, - }, - ) - await runtime_data.data.async_run_command(command_name) + + if runtime_data := _get_runtime_data_from_device_id_exception_on_failure( + hass, device_id + ): + await runtime_data.data.async_run_command(command_name) def _get_device_action_name(command_name: str) -> str: @@ -69,13 +67,55 @@ def _get_command_name(device_action_name: str) -> str: def _get_runtime_data_from_device_id( - hass: HomeAssistant, device_id: str + hass: HomeAssistant, + device_id: str, ) -> NutRuntimeData | None: + """Find the runtime data for device ID and return None on error.""" device_registry = dr.async_get(hass) if (device := device_registry.async_get(device_id)) is None: return None - entry = hass.config_entries.async_get_entry( - next(entry_id for entry_id in device.config_entries) + return _get_runtime_data_for_device(hass, device) + + +def _get_runtime_data_for_device( + hass: HomeAssistant, device: dr.DeviceEntry +) -> NutRuntimeData | None: + """Find the runtime data for device and return None on error.""" + for config_entry_id in device.config_entries: + entry = hass.config_entries.async_get_entry(config_entry_id) + if ( + entry + and entry.domain == DOMAIN + and entry.state is ConfigEntryState.LOADED + and hasattr(entry, "runtime_data") + ): + return cast(NutConfigEntry, entry).runtime_data + + return None + + +def _get_runtime_data_from_device_id_exception_on_failure( + hass: HomeAssistant, + device_id: str, +) -> NutRuntimeData | None: + """Find the runtime data for device ID and raise exception on error.""" + device_registry = dr.async_get(hass) + if (device := device_registry.async_get(device_id)) is None: + raise InvalidDeviceAutomationConfig( + translation_domain=DOMAIN, + translation_key="device_not_found", + translation_placeholders={ + "device_id": device_id, + }, + ) + + if runtime_data := _get_runtime_data_for_device(hass, device): + return runtime_data + + raise InvalidDeviceAutomationConfig( + translation_domain=DOMAIN, + translation_key="config_invalid", + translation_placeholders={ + "device_id": device_id, + }, ) - assert entry and isinstance(entry.runtime_data, NutRuntimeData) - return entry.runtime_data diff --git a/homeassistant/components/nut/strings.json b/homeassistant/components/nut/strings.json index df251ae632f..a9a3b470cca 100644 --- a/homeassistant/components/nut/strings.json +++ b/homeassistant/components/nut/strings.json @@ -312,13 +312,16 @@ } }, "exceptions": { + "config_invalid": { + "message": "Invalid configuration entries for NUT device with ID {device_id}" + }, "data_fetch_error": { "message": "Error fetching UPS state: {err}" }, "device_authentication": { "message": "Device authentication error: {err}" }, - "device_invalid": { + "device_not_found": { "message": "Unable to find a NUT device with ID {device_id}" }, "nut_command_error": { diff --git a/tests/components/nut/test_device_action.py b/tests/components/nut/test_device_action.py index ea6b7306a5f..3f48d073f9f 100644 --- a/tests/components/nut/test_device_action.py +++ b/tests/components/nut/test_device_action.py @@ -21,7 +21,7 @@ from homeassistant.setup import async_setup_component from .util import async_init_integration -from tests.common import async_get_device_automations +from tests.common import MockConfigEntry, async_get_device_automations async def test_get_all_actions_for_specified_user( @@ -79,10 +79,10 @@ async def test_no_actions_for_anonymous_user( assert len(actions) == 0 -async def test_no_actions_invalid_device( +async def test_no_actions_device_not_found( hass: HomeAssistant, ) -> None: - """Test we get no actions for an invalid device.""" + """Test we get no actions for a device that cannot be found.""" list_commands_return_value = {"beeper.enable": None} await async_init_integration( hass, @@ -99,6 +99,30 @@ async def test_no_actions_invalid_device( assert len(actions) == 0 +async def test_no_actions_device_invalid( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, +) -> None: + """Test we get no actions for a device that is invalid.""" + list_commands_return_value = {"beeper.enable": None} + entry = await async_init_integration( + hass, + list_vars={"ups.status": "OL"}, + list_commands_return_value=list_commands_return_value, + ) + device_entry = next(device for device in device_registry.devices.values()) + + assert await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done() + + platform = await device_automation.async_get_device_automation_platform( + hass, DOMAIN, DeviceAutomationType.ACTION + ) + actions = await platform.async_get_actions(hass, device_entry.id) + + assert len(actions) == 0 + + async def test_list_commands_exception( hass: HomeAssistant, device_registry: dr.DeviceRegistry ) -> None: @@ -227,8 +251,8 @@ async def test_run_command_exception( ) -async def test_action_exception_invalid_device(hass: HomeAssistant) -> None: - """Test raises exception if invalid device.""" +async def test_action_exception_device_not_found(hass: HomeAssistant) -> None: + """Test raises exception if device not found.""" list_commands_return_value = {"beeper.enable": None} await async_init_integration( hass, @@ -249,3 +273,64 @@ async def test_action_exception_invalid_device(hass: HomeAssistant) -> None: {}, None, ) + + +async def test_action_exception_invalid_config( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, +) -> None: + """Test raises exception if no NUT config entry found.""" + + config_entry = MockConfigEntry() + config_entry.add_to_hass(hass) + await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + + device_entry = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + identifiers={(DOMAIN, "mock-identifier")}, + ) + + platform = await device_automation.async_get_device_automation_platform( + hass, DOMAIN, DeviceAutomationType.ACTION + ) + + with pytest.raises(InvalidDeviceAutomationConfig): + await platform.async_call_action_from_config( + hass, + {CONF_TYPE: "beeper.enable", CONF_DEVICE_ID: device_entry.id}, + {}, + None, + ) + + +async def test_action_exception_device_invalid( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, +) -> None: + """Test raises exception if config entry for device is invalid.""" + list_commands_return_value = {"beeper.enable": None} + entry = await async_init_integration( + hass, + list_vars={"ups.status": "OL"}, + list_commands_return_value=list_commands_return_value, + ) + device_entry = next(device for device in device_registry.devices.values()) + + assert await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done() + + platform = await device_automation.async_get_device_automation_platform( + hass, DOMAIN, DeviceAutomationType.ACTION + ) + + error_message = ( + f"Invalid configuration entries for NUT device with ID {device_entry.id}" + ) + with pytest.raises(InvalidDeviceAutomationConfig, match=error_message): + await platform.async_call_action_from_config( + hass, + {CONF_TYPE: "beeper.enable", CONF_DEVICE_ID: device_entry.id}, + {}, + None, + )