diff --git a/homeassistant/components/fan/device_action.py b/homeassistant/components/fan/device_action.py index 55bd862349b..fc7f1ddce1f 100644 --- a/homeassistant/components/fan/device_action.py +++ b/homeassistant/components/fan/device_action.py @@ -3,14 +3,24 @@ from __future__ import annotations import voluptuous as vol -from homeassistant.components.device_automation import toggle_entity +from homeassistant.components.device_automation import ( + async_validate_entity_schema, + toggle_entity, +) from homeassistant.const import CONF_DOMAIN from homeassistant.core import Context, HomeAssistant from homeassistant.helpers.typing import ConfigType, TemplateVarsType from . import DOMAIN -ACTION_SCHEMA = toggle_entity.ACTION_SCHEMA.extend({vol.Required(CONF_DOMAIN): DOMAIN}) +_ACTION_SCHEMA = toggle_entity.ACTION_SCHEMA.extend({vol.Required(CONF_DOMAIN): DOMAIN}) + + +async def async_validate_action_config( + hass: HomeAssistant, config: ConfigType +) -> ConfigType: + """Validate config.""" + return async_validate_entity_schema(hass, config, _ACTION_SCHEMA) async def async_get_actions( diff --git a/tests/components/fan/test_device_action.py b/tests/components/fan/test_device_action.py index 3b179bc158c..b8756d9ace5 100644 --- a/tests/components/fan/test_device_action.py +++ b/tests/components/fan/test_device_action.py @@ -171,6 +171,7 @@ async def test_action( hass.bus.async_fire("test_event_turn_off") await hass.async_block_till_done() assert len(turn_off_calls) == 1 + assert turn_off_calls[0].data["entity_id"] == entry.entity_id assert len(turn_on_calls) == 0 assert len(toggle_calls) == 0 @@ -178,6 +179,7 @@ async def test_action( await hass.async_block_till_done() assert len(turn_off_calls) == 1 assert len(turn_on_calls) == 1 + assert turn_on_calls[0].data["entity_id"] == entry.entity_id assert len(toggle_calls) == 0 hass.bus.async_fire("test_event_toggle") @@ -185,6 +187,7 @@ async def test_action( assert len(turn_off_calls) == 1 assert len(turn_on_calls) == 1 assert len(toggle_calls) == 1 + assert toggle_calls[0].data["entity_id"] == entry.entity_id async def test_action_legacy(