Add a base class for template entities to inherit from (#139645)

* add-abstract-template-entity-base-class

* review 1 changes
This commit is contained in:
Petro31 2025-03-04 01:23:05 -05:00 committed by GitHub
parent a778092941
commit 890d3f4af4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 384 additions and 398 deletions

4
CODEOWNERS generated
View File

@ -1529,8 +1529,8 @@ build.json @home-assistant/supervisor
/tests/components/tedee/ @patrickhilker @zweckj /tests/components/tedee/ @patrickhilker @zweckj
/homeassistant/components/tellduslive/ @fredrike /homeassistant/components/tellduslive/ @fredrike
/tests/components/tellduslive/ @fredrike /tests/components/tellduslive/ @fredrike
/homeassistant/components/template/ @PhracturedBlue @home-assistant/core /homeassistant/components/template/ @Petro31 @PhracturedBlue @home-assistant/core
/tests/components/template/ @PhracturedBlue @home-assistant/core /tests/components/template/ @Petro31 @PhracturedBlue @home-assistant/core
/homeassistant/components/tesla_fleet/ @Bre77 /homeassistant/components/tesla_fleet/ @Bre77
/tests/components/tesla_fleet/ @Bre77 /tests/components/tesla_fleet/ @Bre77
/homeassistant/components/tesla_wall_connector/ @einarhauks /homeassistant/components/tesla_wall_connector/ @einarhauks

View File

@ -36,7 +36,6 @@ from homeassistant.helpers.entity_platform import (
AddEntitiesCallback, AddEntitiesCallback,
) )
from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.script import Script
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.util import slugify from homeassistant.util import slugify
@ -199,70 +198,31 @@ class AlarmControlPanelTemplate(TemplateEntity, AlarmControlPanelEntity, Restore
name = self._attr_name name = self._attr_name
assert name is not None assert name is not None
self._template = config.get(CONF_VALUE_TEMPLATE) self._template = config.get(CONF_VALUE_TEMPLATE)
self._disarm_script = None
self._attr_code_arm_required: bool = config[CONF_CODE_ARM_REQUIRED] self._attr_code_arm_required: bool = config[CONF_CODE_ARM_REQUIRED]
self._attr_code_format = config[CONF_CODE_FORMAT].value self._attr_code_format = config[CONF_CODE_FORMAT].value
if (disarm_action := config.get(CONF_DISARM_ACTION)) is not None:
self._disarm_script = Script(hass, disarm_action, name, DOMAIN) self._attr_supported_features = AlarmControlPanelEntityFeature(0)
self._arm_away_script = None for action_id, supported_feature in (
if (arm_away_action := config.get(CONF_ARM_AWAY_ACTION)) is not None: (CONF_DISARM_ACTION, 0),
self._arm_away_script = Script(hass, arm_away_action, name, DOMAIN) (CONF_ARM_AWAY_ACTION, AlarmControlPanelEntityFeature.ARM_AWAY),
self._arm_home_script = None (CONF_ARM_HOME_ACTION, AlarmControlPanelEntityFeature.ARM_HOME),
if (arm_home_action := config.get(CONF_ARM_HOME_ACTION)) is not None: (CONF_ARM_NIGHT_ACTION, AlarmControlPanelEntityFeature.ARM_NIGHT),
self._arm_home_script = Script(hass, arm_home_action, name, DOMAIN) (CONF_ARM_VACATION_ACTION, AlarmControlPanelEntityFeature.ARM_VACATION),
self._arm_night_script = None (
if (arm_night_action := config.get(CONF_ARM_NIGHT_ACTION)) is not None: CONF_ARM_CUSTOM_BYPASS_ACTION,
self._arm_night_script = Script(hass, arm_night_action, name, DOMAIN) AlarmControlPanelEntityFeature.ARM_CUSTOM_BYPASS,
self._arm_vacation_script = None ),
if (arm_vacation_action := config.get(CONF_ARM_VACATION_ACTION)) is not None: (CONF_TRIGGER_ACTION, AlarmControlPanelEntityFeature.TRIGGER),
self._arm_vacation_script = Script(hass, arm_vacation_action, name, DOMAIN) ):
self._arm_custom_bypass_script = None if action_config := config.get(action_id):
if ( self.add_script(action_id, action_config, name, DOMAIN)
arm_custom_bypass_action := config.get(CONF_ARM_CUSTOM_BYPASS_ACTION) self._attr_supported_features |= supported_feature
) is not None:
self._arm_custom_bypass_script = Script(
hass, arm_custom_bypass_action, name, DOMAIN
)
self._trigger_script = None
if (trigger_action := config.get(CONF_TRIGGER_ACTION)) is not None:
self._trigger_script = Script(hass, trigger_action, name, DOMAIN)
self._state: AlarmControlPanelState | None = None self._state: AlarmControlPanelState | None = None
self._attr_device_info = async_device_info_to_link_from_device_id( self._attr_device_info = async_device_info_to_link_from_device_id(
hass, hass,
config.get(CONF_DEVICE_ID), config.get(CONF_DEVICE_ID),
) )
supported_features = AlarmControlPanelEntityFeature(0)
if self._arm_night_script is not None:
supported_features = (
supported_features | AlarmControlPanelEntityFeature.ARM_NIGHT
)
if self._arm_home_script is not None:
supported_features = (
supported_features | AlarmControlPanelEntityFeature.ARM_HOME
)
if self._arm_away_script is not None:
supported_features = (
supported_features | AlarmControlPanelEntityFeature.ARM_AWAY
)
if self._arm_vacation_script is not None:
supported_features = (
supported_features | AlarmControlPanelEntityFeature.ARM_VACATION
)
if self._arm_custom_bypass_script is not None:
supported_features = (
supported_features | AlarmControlPanelEntityFeature.ARM_CUSTOM_BYPASS
)
if self._trigger_script is not None:
supported_features = (
supported_features | AlarmControlPanelEntityFeature.TRIGGER
)
self._attr_supported_features = supported_features
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Restore last state.""" """Restore last state."""
@ -330,7 +290,7 @@ class AlarmControlPanelTemplate(TemplateEntity, AlarmControlPanelEntity, Restore
"""Arm the panel to Away.""" """Arm the panel to Away."""
await self._async_alarm_arm( await self._async_alarm_arm(
AlarmControlPanelState.ARMED_AWAY, AlarmControlPanelState.ARMED_AWAY,
script=self._arm_away_script, script=self._action_scripts.get(CONF_ARM_AWAY_ACTION),
code=code, code=code,
) )
@ -338,7 +298,7 @@ class AlarmControlPanelTemplate(TemplateEntity, AlarmControlPanelEntity, Restore
"""Arm the panel to Home.""" """Arm the panel to Home."""
await self._async_alarm_arm( await self._async_alarm_arm(
AlarmControlPanelState.ARMED_HOME, AlarmControlPanelState.ARMED_HOME,
script=self._arm_home_script, script=self._action_scripts.get(CONF_ARM_HOME_ACTION),
code=code, code=code,
) )
@ -346,7 +306,7 @@ class AlarmControlPanelTemplate(TemplateEntity, AlarmControlPanelEntity, Restore
"""Arm the panel to Night.""" """Arm the panel to Night."""
await self._async_alarm_arm( await self._async_alarm_arm(
AlarmControlPanelState.ARMED_NIGHT, AlarmControlPanelState.ARMED_NIGHT,
script=self._arm_night_script, script=self._action_scripts.get(CONF_ARM_NIGHT_ACTION),
code=code, code=code,
) )
@ -354,7 +314,7 @@ class AlarmControlPanelTemplate(TemplateEntity, AlarmControlPanelEntity, Restore
"""Arm the panel to Vacation.""" """Arm the panel to Vacation."""
await self._async_alarm_arm( await self._async_alarm_arm(
AlarmControlPanelState.ARMED_VACATION, AlarmControlPanelState.ARMED_VACATION,
script=self._arm_vacation_script, script=self._action_scripts.get(CONF_ARM_VACATION_ACTION),
code=code, code=code,
) )
@ -362,20 +322,22 @@ class AlarmControlPanelTemplate(TemplateEntity, AlarmControlPanelEntity, Restore
"""Arm the panel to Custom Bypass.""" """Arm the panel to Custom Bypass."""
await self._async_alarm_arm( await self._async_alarm_arm(
AlarmControlPanelState.ARMED_CUSTOM_BYPASS, AlarmControlPanelState.ARMED_CUSTOM_BYPASS,
script=self._arm_custom_bypass_script, script=self._action_scripts.get(CONF_ARM_CUSTOM_BYPASS_ACTION),
code=code, code=code,
) )
async def async_alarm_disarm(self, code: str | None = None) -> None: async def async_alarm_disarm(self, code: str | None = None) -> None:
"""Disarm the panel.""" """Disarm the panel."""
await self._async_alarm_arm( await self._async_alarm_arm(
AlarmControlPanelState.DISARMED, script=self._disarm_script, code=code AlarmControlPanelState.DISARMED,
script=self._action_scripts.get(CONF_DISARM_ACTION),
code=code,
) )
async def async_alarm_trigger(self, code: str | None = None) -> None: async def async_alarm_trigger(self, code: str | None = None) -> None:
"""Trigger the panel.""" """Trigger the panel."""
await self._async_alarm_arm( await self._async_alarm_arm(
AlarmControlPanelState.TRIGGERED, AlarmControlPanelState.TRIGGERED,
script=self._trigger_script, script=self._action_scripts.get(CONF_TRIGGER_ACTION),
code=code, code=code,
) )

View File

@ -23,7 +23,6 @@ from homeassistant.helpers.entity_platform import (
AddConfigEntryEntitiesCallback, AddConfigEntryEntitiesCallback,
AddEntitiesCallback, AddEntitiesCallback,
) )
from homeassistant.helpers.script import Script
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from .const import CONF_PRESS, DOMAIN from .const import CONF_PRESS, DOMAIN
@ -121,11 +120,8 @@ class TemplateButtonEntity(TemplateEntity, ButtonEntity):
"""Initialize the button.""" """Initialize the button."""
super().__init__(hass, config=config, unique_id=unique_id) super().__init__(hass, config=config, unique_id=unique_id)
assert self._attr_name is not None assert self._attr_name is not None
self._command_press = ( if action := config.get(CONF_PRESS):
Script(hass, config.get(CONF_PRESS), self._attr_name, DOMAIN) self.add_script(CONF_PRESS, action, self._attr_name, DOMAIN)
if config.get(CONF_PRESS, None) is not None
else None
)
self._attr_device_class = config.get(CONF_DEVICE_CLASS) self._attr_device_class = config.get(CONF_DEVICE_CLASS)
self._attr_state = None self._attr_state = None
self._attr_device_info = async_device_info_to_link_from_device_id( self._attr_device_info = async_device_info_to_link_from_device_id(
@ -135,5 +131,5 @@ class TemplateButtonEntity(TemplateEntity, ButtonEntity):
async def async_press(self) -> None: async def async_press(self) -> None:
"""Press the button.""" """Press the button."""
if self._command_press: if script := self._action_scripts.get(CONF_PRESS):
await self.async_run_script(self._command_press, context=self._context) await self.async_run_script(script, context=self._context)

View File

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any from typing import TYPE_CHECKING, Any
import voluptuous as vol import voluptuous as vol
@ -30,7 +30,6 @@ from homeassistant.exceptions import TemplateError
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity import async_generate_entity_id from homeassistant.helpers.entity import async_generate_entity_id
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.script import Script
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from .const import DOMAIN from .const import DOMAIN
@ -103,7 +102,7 @@ PLATFORM_SCHEMA = COVER_PLATFORM_SCHEMA.extend(
) )
async def _async_create_entities(hass, config): async def _async_create_entities(hass: HomeAssistant, config):
"""Create the Template cover.""" """Create the Template cover."""
covers = [] covers = []
@ -141,11 +140,11 @@ class CoverTemplate(TemplateEntity, CoverEntity):
def __init__( def __init__(
self, self,
hass, hass: HomeAssistant,
object_id, object_id,
config, config: dict[str, Any],
unique_id, unique_id,
): ) -> None:
"""Initialize the Template cover.""" """Initialize the Template cover."""
super().__init__( super().__init__(
hass, config=config, fallback_name=object_id, unique_id=unique_id hass, config=config, fallback_name=object_id, unique_id=unique_id
@ -153,45 +152,40 @@ class CoverTemplate(TemplateEntity, CoverEntity):
self.entity_id = async_generate_entity_id( self.entity_id = async_generate_entity_id(
ENTITY_ID_FORMAT, object_id, hass=hass ENTITY_ID_FORMAT, object_id, hass=hass
) )
friendly_name = self._attr_name name = self._attr_name
if TYPE_CHECKING:
assert name is not None
self._template = config.get(CONF_VALUE_TEMPLATE) self._template = config.get(CONF_VALUE_TEMPLATE)
self._position_template = config.get(CONF_POSITION_TEMPLATE) self._position_template = config.get(CONF_POSITION_TEMPLATE)
self._tilt_template = config.get(CONF_TILT_TEMPLATE) self._tilt_template = config.get(CONF_TILT_TEMPLATE)
self._attr_device_class = config.get(CONF_DEVICE_CLASS) self._attr_device_class = config.get(CONF_DEVICE_CLASS)
self._open_script = None
if (open_action := config.get(OPEN_ACTION)) is not None: # The config requires (open and close scripts) or a set position script,
self._open_script = Script(hass, open_action, friendly_name, DOMAIN) # therefore the base supported features will always include them.
self._close_script = None self._attr_supported_features = (
if (close_action := config.get(CLOSE_ACTION)) is not None: CoverEntityFeature.OPEN | CoverEntityFeature.CLOSE
self._close_script = Script(hass, close_action, friendly_name, DOMAIN) )
self._stop_script = None for action_id, supported_feature in (
if (stop_action := config.get(STOP_ACTION)) is not None: (OPEN_ACTION, 0),
self._stop_script = Script(hass, stop_action, friendly_name, DOMAIN) (CLOSE_ACTION, 0),
self._position_script = None (STOP_ACTION, CoverEntityFeature.STOP),
if (position_action := config.get(POSITION_ACTION)) is not None: (POSITION_ACTION, CoverEntityFeature.SET_POSITION),
self._position_script = Script(hass, position_action, friendly_name, DOMAIN) (TILT_ACTION, TILT_FEATURES),
self._tilt_script = None ):
if (tilt_action := config.get(TILT_ACTION)) is not None: if action_config := config.get(action_id):
self._tilt_script = Script(hass, tilt_action, friendly_name, DOMAIN) self.add_script(action_id, action_config, name, DOMAIN)
self._attr_supported_features |= supported_feature
optimistic = config.get(CONF_OPTIMISTIC) optimistic = config.get(CONF_OPTIMISTIC)
self._optimistic = optimistic or ( self._optimistic = optimistic or (
optimistic is None and not self._template and not self._position_template optimistic is None and not self._template and not self._position_template
) )
tilt_optimistic = config.get(CONF_TILT_OPTIMISTIC) tilt_optimistic = config.get(CONF_TILT_OPTIMISTIC)
self._tilt_optimistic = tilt_optimistic or not self._tilt_template self._tilt_optimistic = tilt_optimistic or not self._tilt_template
self._position = None self._position: int | None = None
self._is_opening = False self._is_opening = False
self._is_closing = False self._is_closing = False
self._tilt_value = None self._tilt_value: int | None = None
supported_features = CoverEntityFeature.OPEN | CoverEntityFeature.CLOSE
if self._stop_script is not None:
supported_features |= CoverEntityFeature.STOP
if self._position_script is not None:
supported_features |= CoverEntityFeature.SET_POSITION
if self._tilt_script is not None:
supported_features |= TILT_FEATURES
self._attr_supported_features = supported_features
@callback @callback
def _async_setup_templates(self) -> None: def _async_setup_templates(self) -> None:
@ -317,7 +311,7 @@ class CoverTemplate(TemplateEntity, CoverEntity):
None is unknown, 0 is closed, 100 is fully open. None is unknown, 0 is closed, 100 is fully open.
""" """
if self._position_template or self._position_script: if self._position_template or self._action_scripts.get(POSITION_ACTION):
return self._position return self._position
return None return None
@ -331,11 +325,11 @@ class CoverTemplate(TemplateEntity, CoverEntity):
async def async_open_cover(self, **kwargs: Any) -> None: async def async_open_cover(self, **kwargs: Any) -> None:
"""Move the cover up.""" """Move the cover up."""
if self._open_script: if (open_script := self._action_scripts.get(OPEN_ACTION)) is not None:
await self.async_run_script(self._open_script, context=self._context) await self.async_run_script(open_script, context=self._context)
elif self._position_script: elif (position_script := self._action_scripts.get(POSITION_ACTION)) is not None:
await self.async_run_script( await self.async_run_script(
self._position_script, position_script,
run_variables={"position": 100}, run_variables={"position": 100},
context=self._context, context=self._context,
) )
@ -345,11 +339,11 @@ class CoverTemplate(TemplateEntity, CoverEntity):
async def async_close_cover(self, **kwargs: Any) -> None: async def async_close_cover(self, **kwargs: Any) -> None:
"""Move the cover down.""" """Move the cover down."""
if self._close_script: if (close_script := self._action_scripts.get(CLOSE_ACTION)) is not None:
await self.async_run_script(self._close_script, context=self._context) await self.async_run_script(close_script, context=self._context)
elif self._position_script: elif (position_script := self._action_scripts.get(POSITION_ACTION)) is not None:
await self.async_run_script( await self.async_run_script(
self._position_script, position_script,
run_variables={"position": 0}, run_variables={"position": 0},
context=self._context, context=self._context,
) )
@ -359,14 +353,14 @@ class CoverTemplate(TemplateEntity, CoverEntity):
async def async_stop_cover(self, **kwargs: Any) -> None: async def async_stop_cover(self, **kwargs: Any) -> None:
"""Fire the stop action.""" """Fire the stop action."""
if self._stop_script: if (stop_script := self._action_scripts.get(STOP_ACTION)) is not None:
await self.async_run_script(self._stop_script, context=self._context) await self.async_run_script(stop_script, context=self._context)
async def async_set_cover_position(self, **kwargs: Any) -> None: async def async_set_cover_position(self, **kwargs: Any) -> None:
"""Set cover position.""" """Set cover position."""
self._position = kwargs[ATTR_POSITION] self._position = kwargs[ATTR_POSITION]
await self.async_run_script( await self.async_run_script(
self._position_script, self._action_scripts[POSITION_ACTION],
run_variables={"position": self._position}, run_variables={"position": self._position},
context=self._context, context=self._context,
) )
@ -377,7 +371,7 @@ class CoverTemplate(TemplateEntity, CoverEntity):
"""Tilt the cover open.""" """Tilt the cover open."""
self._tilt_value = 100 self._tilt_value = 100
await self.async_run_script( await self.async_run_script(
self._tilt_script, self._action_scripts[TILT_ACTION],
run_variables={"tilt": self._tilt_value}, run_variables={"tilt": self._tilt_value},
context=self._context, context=self._context,
) )
@ -388,7 +382,7 @@ class CoverTemplate(TemplateEntity, CoverEntity):
"""Tilt the cover closed.""" """Tilt the cover closed."""
self._tilt_value = 0 self._tilt_value = 0
await self.async_run_script( await self.async_run_script(
self._tilt_script, self._action_scripts[TILT_ACTION],
run_variables={"tilt": self._tilt_value}, run_variables={"tilt": self._tilt_value},
context=self._context, context=self._context,
) )
@ -399,7 +393,7 @@ class CoverTemplate(TemplateEntity, CoverEntity):
"""Move the cover tilt to a specific position.""" """Move the cover tilt to a specific position."""
self._tilt_value = kwargs[ATTR_TILT_POSITION] self._tilt_value = kwargs[ATTR_TILT_POSITION]
await self.async_run_script( await self.async_run_script(
self._tilt_script, self._action_scripts[TILT_ACTION],
run_variables={"tilt": self._tilt_value}, run_variables={"tilt": self._tilt_value},
context=self._context, context=self._context,
) )

View File

@ -0,0 +1,66 @@
"""Template entity base class."""
from collections.abc import Sequence
from typing import Any
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.script import Script, _VarsType
from homeassistant.helpers.template import TemplateStateFromEntityId
class AbstractTemplateEntity(Entity):
"""Actions linked to a template entity."""
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the entity."""
self.hass = hass
self._action_scripts: dict[str, Script] = {}
@property
def referenced_blueprint(self) -> str | None:
"""Return referenced blueprint or None."""
raise NotImplementedError
@callback
def _render_script_variables(self) -> dict:
"""Render configured variables."""
raise NotImplementedError
def add_script(
self,
script_id: str,
config: Sequence[dict[str, Any]],
name: str,
domain: str,
):
"""Add an action script."""
# Cannot use self.hass because it may be None in child class
# at instantiation.
self._action_scripts[script_id] = Script(
self.hass,
config,
f"{name} {script_id}",
domain,
)
async def async_run_script(
self,
script: Script,
*,
run_variables: _VarsType | None = None,
context: Context | None = None,
) -> None:
"""Run an action script."""
if run_variables is None:
run_variables = {}
await script.async_run(
run_variables={
"this": TemplateStateFromEntityId(self.hass, self.entity_id),
**self._render_script_variables(),
**run_variables,
},
context=context,
)

View File

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any from typing import TYPE_CHECKING, Any
import voluptuous as vol import voluptuous as vol
@ -32,7 +32,6 @@ from homeassistant.exceptions import TemplateError
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity import async_generate_entity_id from homeassistant.helpers.entity import async_generate_entity_id
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.script import Script
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from .const import DOMAIN from .const import DOMAIN
@ -89,7 +88,7 @@ PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(
) )
async def _async_create_entities(hass, config): async def _async_create_entities(hass: HomeAssistant, config):
"""Create the Template Fans.""" """Create the Template Fans."""
fans = [] fans = []
@ -127,11 +126,11 @@ class TemplateFan(TemplateEntity, FanEntity):
def __init__( def __init__(
self, self,
hass, hass: HomeAssistant,
object_id, object_id,
config, config: dict[str, Any],
unique_id, unique_id,
): ) -> None:
"""Initialize the fan.""" """Initialize the fan."""
super().__init__( super().__init__(
hass, config=config, fallback_name=object_id, unique_id=unique_id hass, config=config, fallback_name=object_id, unique_id=unique_id
@ -140,7 +139,9 @@ class TemplateFan(TemplateEntity, FanEntity):
self.entity_id = async_generate_entity_id( self.entity_id = async_generate_entity_id(
ENTITY_ID_FORMAT, object_id, hass=hass ENTITY_ID_FORMAT, object_id, hass=hass
) )
friendly_name = self._attr_name name = self._attr_name
if TYPE_CHECKING:
assert name is not None
self._template = config.get(CONF_VALUE_TEMPLATE) self._template = config.get(CONF_VALUE_TEMPLATE)
self._percentage_template = config.get(CONF_PERCENTAGE_TEMPLATE) self._percentage_template = config.get(CONF_PERCENTAGE_TEMPLATE)
@ -148,44 +149,28 @@ class TemplateFan(TemplateEntity, FanEntity):
self._oscillating_template = config.get(CONF_OSCILLATING_TEMPLATE) self._oscillating_template = config.get(CONF_OSCILLATING_TEMPLATE)
self._direction_template = config.get(CONF_DIRECTION_TEMPLATE) self._direction_template = config.get(CONF_DIRECTION_TEMPLATE)
self._on_script = Script(hass, config[CONF_ON_ACTION], friendly_name, DOMAIN) for action_id in (
self._off_script = Script(hass, config[CONF_OFF_ACTION], friendly_name, DOMAIN) CONF_ON_ACTION,
CONF_OFF_ACTION,
self._set_percentage_script = None CONF_SET_PERCENTAGE_ACTION,
if set_percentage_action := config.get(CONF_SET_PERCENTAGE_ACTION): CONF_SET_PRESET_MODE_ACTION,
self._set_percentage_script = Script( CONF_SET_OSCILLATING_ACTION,
hass, set_percentage_action, friendly_name, DOMAIN CONF_SET_DIRECTION_ACTION,
) ):
if action_config := config.get(action_id):
self._set_preset_mode_script = None self.add_script(action_id, action_config, name, DOMAIN)
if set_preset_mode_action := config.get(CONF_SET_PRESET_MODE_ACTION):
self._set_preset_mode_script = Script(
hass, set_preset_mode_action, friendly_name, DOMAIN
)
self._set_oscillating_script = None
if set_oscillating_action := config.get(CONF_SET_OSCILLATING_ACTION):
self._set_oscillating_script = Script(
hass, set_oscillating_action, friendly_name, DOMAIN
)
self._set_direction_script = None
if set_direction_action := config.get(CONF_SET_DIRECTION_ACTION):
self._set_direction_script = Script(
hass, set_direction_action, friendly_name, DOMAIN
)
self._state: bool | None = False self._state: bool | None = False
self._percentage = None self._percentage: int | None = None
self._preset_mode = None self._preset_mode: str | None = None
self._oscillating = None self._oscillating: bool | None = None
self._direction = None self._direction: str | None = None
# Number of valid speeds # Number of valid speeds
self._speed_count = config.get(CONF_SPEED_COUNT) self._speed_count = config.get(CONF_SPEED_COUNT)
# List of valid preset modes # List of valid preset modes
self._preset_modes = config.get(CONF_PRESET_MODES) self._preset_modes: list[str] | None = config.get(CONF_PRESET_MODES)
if self._percentage_template: if self._percentage_template:
self._attr_supported_features |= FanEntityFeature.SET_SPEED self._attr_supported_features |= FanEntityFeature.SET_SPEED
@ -207,7 +192,7 @@ class TemplateFan(TemplateEntity, FanEntity):
return self._speed_count or 100 return self._speed_count or 100
@property @property
def preset_modes(self) -> list[str]: def preset_modes(self) -> list[str] | None:
"""Get the list of available preset modes.""" """Get the list of available preset modes."""
return self._preset_modes return self._preset_modes
@ -244,7 +229,7 @@ class TemplateFan(TemplateEntity, FanEntity):
) -> None: ) -> None:
"""Turn on the fan.""" """Turn on the fan."""
await self.async_run_script( await self.async_run_script(
self._on_script, self._action_scripts[CONF_ON_ACTION],
run_variables={ run_variables={
ATTR_PERCENTAGE: percentage, ATTR_PERCENTAGE: percentage,
ATTR_PRESET_MODE: preset_mode, ATTR_PRESET_MODE: preset_mode,
@ -263,7 +248,9 @@ class TemplateFan(TemplateEntity, FanEntity):
async def async_turn_off(self, **kwargs: Any) -> None: async def async_turn_off(self, **kwargs: Any) -> None:
"""Turn off the fan.""" """Turn off the fan."""
await self.async_run_script(self._off_script, context=self._context) await self.async_run_script(
self._action_scripts[CONF_OFF_ACTION], context=self._context
)
if self._template is None: if self._template is None:
self._state = False self._state = False
@ -273,9 +260,9 @@ class TemplateFan(TemplateEntity, FanEntity):
"""Set the percentage speed of the fan.""" """Set the percentage speed of the fan."""
self._percentage = percentage self._percentage = percentage
if self._set_percentage_script: if (script := self._action_scripts.get(CONF_SET_PERCENTAGE_ACTION)) is not None:
await self.async_run_script( await self.async_run_script(
self._set_percentage_script, script,
run_variables={ATTR_PERCENTAGE: self._percentage}, run_variables={ATTR_PERCENTAGE: self._percentage},
context=self._context, context=self._context,
) )
@ -288,9 +275,11 @@ class TemplateFan(TemplateEntity, FanEntity):
"""Set the preset_mode of the fan.""" """Set the preset_mode of the fan."""
self._preset_mode = preset_mode self._preset_mode = preset_mode
if self._set_preset_mode_script: if (
script := self._action_scripts.get(CONF_SET_PRESET_MODE_ACTION)
) is not None:
await self.async_run_script( await self.async_run_script(
self._set_preset_mode_script, script,
run_variables={ATTR_PRESET_MODE: self._preset_mode}, run_variables={ATTR_PRESET_MODE: self._preset_mode},
context=self._context, context=self._context,
) )
@ -301,25 +290,25 @@ class TemplateFan(TemplateEntity, FanEntity):
async def async_oscillate(self, oscillating: bool) -> None: async def async_oscillate(self, oscillating: bool) -> None:
"""Set oscillation of the fan.""" """Set oscillation of the fan."""
if self._set_oscillating_script is None: if (script := self._action_scripts.get(CONF_SET_OSCILLATING_ACTION)) is None:
return return
self._oscillating = oscillating self._oscillating = oscillating
await self.async_run_script( await self.async_run_script(
self._set_oscillating_script, script,
run_variables={ATTR_OSCILLATING: self.oscillating}, run_variables={ATTR_OSCILLATING: self.oscillating},
context=self._context, context=self._context,
) )
async def async_set_direction(self, direction: str) -> None: async def async_set_direction(self, direction: str) -> None:
"""Set the direction of the fan.""" """Set the direction of the fan."""
if self._set_direction_script is None: if (script := self._action_scripts.get(CONF_SET_DIRECTION_ACTION)) is None:
return return
if direction in _VALID_DIRECTIONS: if direction in _VALID_DIRECTIONS:
self._direction = direction self._direction = direction
await self.async_run_script( await self.async_run_script(
self._set_direction_script, script,
run_variables={ATTR_DIRECTION: direction}, run_variables={ATTR_DIRECTION: direction},
context=self._context, context=self._context,
) )

View File

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any from typing import TYPE_CHECKING, Any
import voluptuous as vol import voluptuous as vol
@ -39,7 +39,6 @@ from homeassistant.exceptions import TemplateError
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity import async_generate_entity_id from homeassistant.helpers.entity import async_generate_entity_id
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.script import Script
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.util import color as color_util from homeassistant.util import color as color_util
@ -127,7 +126,7 @@ PLATFORM_SCHEMA = vol.All(
) )
async def _async_create_entities(hass, config): async def _async_create_entities(hass: HomeAssistant, config):
"""Create the Template Lights.""" """Create the Template Lights."""
lights = [] lights = []
@ -164,11 +163,11 @@ class LightTemplate(TemplateEntity, LightEntity):
def __init__( def __init__(
self, self,
hass, hass: HomeAssistant,
object_id, object_id,
config, config: dict[str, Any],
unique_id, unique_id,
): ) -> None:
"""Initialize the light.""" """Initialize the light."""
super().__init__( super().__init__(
hass, config=config, fallback_name=object_id, unique_id=unique_id hass, config=config, fallback_name=object_id, unique_id=unique_id
@ -176,52 +175,31 @@ class LightTemplate(TemplateEntity, LightEntity):
self.entity_id = async_generate_entity_id( self.entity_id = async_generate_entity_id(
ENTITY_ID_FORMAT, object_id, hass=hass ENTITY_ID_FORMAT, object_id, hass=hass
) )
friendly_name = self._attr_name name = self._attr_name
if TYPE_CHECKING:
assert name is not None
self._template = config.get(CONF_VALUE_TEMPLATE) self._template = config.get(CONF_VALUE_TEMPLATE)
self._on_script = Script(hass, config[CONF_ON_ACTION], friendly_name, DOMAIN)
self._off_script = Script(hass, config[CONF_OFF_ACTION], friendly_name, DOMAIN)
self._level_script = None
if (level_action := config.get(CONF_LEVEL_ACTION)) is not None:
self._level_script = Script(hass, level_action, friendly_name, DOMAIN)
self._level_template = config.get(CONF_LEVEL_TEMPLATE) self._level_template = config.get(CONF_LEVEL_TEMPLATE)
self._temperature_script = None
if (temperature_action := config.get(CONF_TEMPERATURE_ACTION)) is not None:
self._temperature_script = Script(
hass, temperature_action, friendly_name, DOMAIN
)
self._temperature_template = config.get(CONF_TEMPERATURE_TEMPLATE) self._temperature_template = config.get(CONF_TEMPERATURE_TEMPLATE)
self._color_script = None
if (color_action := config.get(CONF_COLOR_ACTION)) is not None:
self._color_script = Script(hass, color_action, friendly_name, DOMAIN)
self._color_template = config.get(CONF_COLOR_TEMPLATE) self._color_template = config.get(CONF_COLOR_TEMPLATE)
self._hs_script = None
if (hs_action := config.get(CONF_HS_ACTION)) is not None:
self._hs_script = Script(hass, hs_action, friendly_name, DOMAIN)
self._hs_template = config.get(CONF_HS_TEMPLATE) self._hs_template = config.get(CONF_HS_TEMPLATE)
self._rgb_script = None
if (rgb_action := config.get(CONF_RGB_ACTION)) is not None:
self._rgb_script = Script(hass, rgb_action, friendly_name, DOMAIN)
self._rgb_template = config.get(CONF_RGB_TEMPLATE) self._rgb_template = config.get(CONF_RGB_TEMPLATE)
self._rgbw_script = None
if (rgbw_action := config.get(CONF_RGBW_ACTION)) is not None:
self._rgbw_script = Script(hass, rgbw_action, friendly_name, DOMAIN)
self._rgbw_template = config.get(CONF_RGBW_TEMPLATE) self._rgbw_template = config.get(CONF_RGBW_TEMPLATE)
self._rgbww_script = None
if (rgbww_action := config.get(CONF_RGBWW_ACTION)) is not None:
self._rgbww_script = Script(hass, rgbww_action, friendly_name, DOMAIN)
self._rgbww_template = config.get(CONF_RGBWW_TEMPLATE) self._rgbww_template = config.get(CONF_RGBWW_TEMPLATE)
self._effect_script = None
if (effect_action := config.get(CONF_EFFECT_ACTION)) is not None:
self._effect_script = Script(hass, effect_action, friendly_name, DOMAIN)
self._effect_list_template = config.get(CONF_EFFECT_LIST_TEMPLATE) self._effect_list_template = config.get(CONF_EFFECT_LIST_TEMPLATE)
self._effect_template = config.get(CONF_EFFECT_TEMPLATE) self._effect_template = config.get(CONF_EFFECT_TEMPLATE)
self._max_mireds_template = config.get(CONF_MAX_MIREDS_TEMPLATE) self._max_mireds_template = config.get(CONF_MAX_MIREDS_TEMPLATE)
self._min_mireds_template = config.get(CONF_MIN_MIREDS_TEMPLATE) self._min_mireds_template = config.get(CONF_MIN_MIREDS_TEMPLATE)
self._supports_transition_template = config.get(CONF_SUPPORTS_TRANSITION) self._supports_transition_template = config.get(CONF_SUPPORTS_TRANSITION)
for action_id in (CONF_ON_ACTION, CONF_OFF_ACTION, CONF_EFFECT_ACTION):
if action_config := config.get(action_id):
self.add_script(action_id, action_config, name, DOMAIN)
self._state = False self._state = False
self._brightness = None self._brightness = None
self._temperature = None self._temperature: int | None = None
self._hs_color = None self._hs_color = None
self._rgb_color = None self._rgb_color = None
self._rgbw_color = None self._rgbw_color = None
@ -235,21 +213,18 @@ class LightTemplate(TemplateEntity, LightEntity):
self._supported_color_modes = None self._supported_color_modes = None
color_modes = {ColorMode.ONOFF} color_modes = {ColorMode.ONOFF}
if self._level_script is not None: for action_id, color_mode in (
color_modes.add(ColorMode.BRIGHTNESS) (CONF_TEMPERATURE_ACTION, ColorMode.COLOR_TEMP),
if self._temperature_script is not None: (CONF_LEVEL_ACTION, ColorMode.BRIGHTNESS),
color_modes.add(ColorMode.COLOR_TEMP) (CONF_COLOR_ACTION, ColorMode.HS),
if self._hs_script is not None: (CONF_HS_ACTION, ColorMode.HS),
color_modes.add(ColorMode.HS) (CONF_RGB_ACTION, ColorMode.RGB),
if self._color_script is not None: (CONF_RGBW_ACTION, ColorMode.RGBW),
color_modes.add(ColorMode.HS) (CONF_RGBWW_ACTION, ColorMode.RGBWW),
if self._rgb_script is not None: ):
color_modes.add(ColorMode.RGB) if (action_config := config.get(action_id)) is not None:
if self._rgbw_script is not None: self.add_script(action_id, action_config, name, DOMAIN)
color_modes.add(ColorMode.RGBW) color_modes.add(color_mode)
if self._rgbww_script is not None:
color_modes.add(ColorMode.RGBWW)
self._supported_color_modes = filter_supported_color_modes(color_modes) self._supported_color_modes = filter_supported_color_modes(color_modes)
if len(self._supported_color_modes) > 1: if len(self._supported_color_modes) > 1:
self._color_mode = ColorMode.UNKNOWN self._color_mode = ColorMode.UNKNOWN
@ -257,7 +232,7 @@ class LightTemplate(TemplateEntity, LightEntity):
self._color_mode = next(iter(self._supported_color_modes)) self._color_mode = next(iter(self._supported_color_modes))
self._attr_supported_features = LightEntityFeature(0) self._attr_supported_features = LightEntityFeature(0)
if self._effect_script is not None: if self._action_scripts.get(CONF_EFFECT_ACTION) is not None:
self._attr_supported_features |= LightEntityFeature.EFFECT self._attr_supported_features |= LightEntityFeature.EFFECT
if self._supports_transition is True: if self._supports_transition is True:
self._attr_supported_features |= LightEntityFeature.TRANSITION self._attr_supported_features |= LightEntityFeature.TRANSITION
@ -321,12 +296,12 @@ class LightTemplate(TemplateEntity, LightEntity):
return self._effect_list return self._effect_list
@property @property
def color_mode(self): def color_mode(self) -> ColorMode | None:
"""Return current color mode.""" """Return current color mode."""
return self._color_mode return self._color_mode
@property @property
def supported_color_modes(self): def supported_color_modes(self) -> set[ColorMode] | None:
"""Flag supported color modes.""" """Flag supported color modes."""
return self._supported_color_modes return self._supported_color_modes
@ -555,17 +530,28 @@ class LightTemplate(TemplateEntity, LightEntity):
if ATTR_TRANSITION in kwargs and self._supports_transition is True: if ATTR_TRANSITION in kwargs and self._supports_transition is True:
common_params["transition"] = kwargs[ATTR_TRANSITION] common_params["transition"] = kwargs[ATTR_TRANSITION]
if ATTR_COLOR_TEMP_KELVIN in kwargs and self._temperature_script: if (
ATTR_COLOR_TEMP_KELVIN in kwargs
and (
temperature_script := self._action_scripts.get(CONF_TEMPERATURE_ACTION)
)
is not None
):
common_params["color_temp"] = color_util.color_temperature_kelvin_to_mired( common_params["color_temp"] = color_util.color_temperature_kelvin_to_mired(
kwargs[ATTR_COLOR_TEMP_KELVIN] kwargs[ATTR_COLOR_TEMP_KELVIN]
) )
await self.async_run_script( await self.async_run_script(
self._temperature_script, temperature_script,
run_variables=common_params, run_variables=common_params,
context=self._context, context=self._context,
) )
elif ATTR_EFFECT in kwargs and self._effect_script: elif (
ATTR_EFFECT in kwargs
and (effect_script := self._action_scripts.get(CONF_EFFECT_ACTION))
is not None
):
assert self._effect_list is not None
effect = kwargs[ATTR_EFFECT] effect = kwargs[ATTR_EFFECT]
if effect not in self._effect_list: if effect not in self._effect_list:
_LOGGER.error( _LOGGER.error(
@ -579,27 +565,38 @@ class LightTemplate(TemplateEntity, LightEntity):
common_params["effect"] = effect common_params["effect"] = effect
await self.async_run_script( await self.async_run_script(
self._effect_script, run_variables=common_params, context=self._context effect_script, run_variables=common_params, context=self._context
) )
elif ATTR_HS_COLOR in kwargs and self._color_script: elif (
ATTR_HS_COLOR in kwargs
and (color_script := self._action_scripts.get(CONF_COLOR_ACTION))
is not None
):
hs_value = kwargs[ATTR_HS_COLOR] hs_value = kwargs[ATTR_HS_COLOR]
common_params["hs"] = hs_value common_params["hs"] = hs_value
common_params["h"] = int(hs_value[0]) common_params["h"] = int(hs_value[0])
common_params["s"] = int(hs_value[1]) common_params["s"] = int(hs_value[1])
await self.async_run_script( await self.async_run_script(
self._color_script, run_variables=common_params, context=self._context color_script, run_variables=common_params, context=self._context
) )
elif ATTR_HS_COLOR in kwargs and self._hs_script: elif (
ATTR_HS_COLOR in kwargs
and (hs_script := self._action_scripts.get(CONF_HS_ACTION)) is not None
):
hs_value = kwargs[ATTR_HS_COLOR] hs_value = kwargs[ATTR_HS_COLOR]
common_params["hs"] = hs_value common_params["hs"] = hs_value
common_params["h"] = int(hs_value[0]) common_params["h"] = int(hs_value[0])
common_params["s"] = int(hs_value[1]) common_params["s"] = int(hs_value[1])
await self.async_run_script( await self.async_run_script(
self._hs_script, run_variables=common_params, context=self._context hs_script, run_variables=common_params, context=self._context
) )
elif ATTR_RGBWW_COLOR in kwargs and self._rgbww_script: elif (
ATTR_RGBWW_COLOR in kwargs
and (rgbww_script := self._action_scripts.get(CONF_RGBWW_ACTION))
is not None
):
rgbww_value = kwargs[ATTR_RGBWW_COLOR] rgbww_value = kwargs[ATTR_RGBWW_COLOR]
common_params["rgbww"] = rgbww_value common_params["rgbww"] = rgbww_value
common_params["rgb"] = ( common_params["rgb"] = (
@ -614,9 +611,12 @@ class LightTemplate(TemplateEntity, LightEntity):
common_params["ww"] = int(rgbww_value[4]) common_params["ww"] = int(rgbww_value[4])
await self.async_run_script( await self.async_run_script(
self._rgbww_script, run_variables=common_params, context=self._context rgbww_script, run_variables=common_params, context=self._context
) )
elif ATTR_RGBW_COLOR in kwargs and self._rgbw_script: elif (
ATTR_RGBW_COLOR in kwargs
and (rgbw_script := self._action_scripts.get(CONF_RGBW_ACTION)) is not None
):
rgbw_value = kwargs[ATTR_RGBW_COLOR] rgbw_value = kwargs[ATTR_RGBW_COLOR]
common_params["rgbw"] = rgbw_value common_params["rgbw"] = rgbw_value
common_params["rgb"] = ( common_params["rgb"] = (
@ -630,9 +630,12 @@ class LightTemplate(TemplateEntity, LightEntity):
common_params["w"] = int(rgbw_value[3]) common_params["w"] = int(rgbw_value[3])
await self.async_run_script( await self.async_run_script(
self._rgbw_script, run_variables=common_params, context=self._context rgbw_script, run_variables=common_params, context=self._context
) )
elif ATTR_RGB_COLOR in kwargs and self._rgb_script: elif (
ATTR_RGB_COLOR in kwargs
and (rgb_script := self._action_scripts.get(CONF_RGB_ACTION)) is not None
):
rgb_value = kwargs[ATTR_RGB_COLOR] rgb_value = kwargs[ATTR_RGB_COLOR]
common_params["rgb"] = rgb_value common_params["rgb"] = rgb_value
common_params["r"] = int(rgb_value[0]) common_params["r"] = int(rgb_value[0])
@ -640,15 +643,21 @@ class LightTemplate(TemplateEntity, LightEntity):
common_params["b"] = int(rgb_value[2]) common_params["b"] = int(rgb_value[2])
await self.async_run_script( await self.async_run_script(
self._rgb_script, run_variables=common_params, context=self._context rgb_script, run_variables=common_params, context=self._context
) )
elif ATTR_BRIGHTNESS in kwargs and self._level_script: elif (
ATTR_BRIGHTNESS in kwargs
and (level_script := self._action_scripts.get(CONF_LEVEL_ACTION))
is not None
):
await self.async_run_script( await self.async_run_script(
self._level_script, run_variables=common_params, context=self._context level_script, run_variables=common_params, context=self._context
) )
else: else:
await self.async_run_script( await self.async_run_script(
self._on_script, run_variables=common_params, context=self._context self._action_scripts[CONF_ON_ACTION],
run_variables=common_params,
context=self._context,
) )
if optimistic_set: if optimistic_set:
@ -656,14 +665,15 @@ class LightTemplate(TemplateEntity, LightEntity):
async def async_turn_off(self, **kwargs: Any) -> None: async def async_turn_off(self, **kwargs: Any) -> None:
"""Turn the light off.""" """Turn the light off."""
off_script = self._action_scripts[CONF_OFF_ACTION]
if ATTR_TRANSITION in kwargs and self._supports_transition is True: if ATTR_TRANSITION in kwargs and self._supports_transition is True:
await self.async_run_script( await self.async_run_script(
self._off_script, off_script,
run_variables={"transition": kwargs[ATTR_TRANSITION]}, run_variables={"transition": kwargs[ATTR_TRANSITION]},
context=self._context, context=self._context,
) )
else: else:
await self.async_run_script(self._off_script, context=self._context) await self.async_run_script(off_script, context=self._context)
if self._template is None: if self._template is None:
self._state = False self._state = False
self.async_write_ha_state() self.async_write_ha_state()
@ -1013,7 +1023,7 @@ class LightTemplate(TemplateEntity, LightEntity):
if render in (None, "None", ""): if render in (None, "None", ""):
self._supports_transition = False self._supports_transition = False
return return
self._attr_supported_features &= ~LightEntityFeature.TRANSITION self._attr_supported_features &= LightEntityFeature.EFFECT
self._supports_transition = bool(render) self._supports_transition = bool(render)
if self._supports_transition: if self._supports_transition:
self._attr_supported_features |= LightEntityFeature.TRANSITION self._attr_supported_features |= LightEntityFeature.TRANSITION

View File

@ -23,7 +23,6 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import ServiceValidationError, TemplateError from homeassistant.exceptions import ServiceValidationError, TemplateError
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.script import Script
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from .const import DOMAIN from .const import DOMAIN
@ -90,13 +89,18 @@ class TemplateLock(TemplateEntity, LockEntity):
) )
self._state: LockState | None = None self._state: LockState | None = None
name = self._attr_name name = self._attr_name
assert name if TYPE_CHECKING:
assert name is not None
self._state_template = config.get(CONF_VALUE_TEMPLATE) self._state_template = config.get(CONF_VALUE_TEMPLATE)
self._command_lock = Script(hass, config[CONF_LOCK], name, DOMAIN) for action_id, supported_feature in (
self._command_unlock = Script(hass, config[CONF_UNLOCK], name, DOMAIN) (CONF_LOCK, 0),
if CONF_OPEN in config: (CONF_UNLOCK, 0),
self._command_open = Script(hass, config[CONF_OPEN], name, DOMAIN) (CONF_OPEN, LockEntityFeature.OPEN),
self._attr_supported_features |= LockEntityFeature.OPEN ):
if action_config := config.get(action_id):
self.add_script(action_id, action_config, name, DOMAIN)
self._attr_supported_features |= supported_feature
self._code_format_template = config.get(CONF_CODE_FORMAT_TEMPLATE) self._code_format_template = config.get(CONF_CODE_FORMAT_TEMPLATE)
self._code_format: str | None = None self._code_format: str | None = None
self._code_format_template_error: TemplateError | None = None self._code_format_template_error: TemplateError | None = None
@ -210,7 +214,9 @@ class TemplateLock(TemplateEntity, LockEntity):
tpl_vars = {ATTR_CODE: kwargs.get(ATTR_CODE) if kwargs else None} tpl_vars = {ATTR_CODE: kwargs.get(ATTR_CODE) if kwargs else None}
await self.async_run_script( await self.async_run_script(
self._command_lock, run_variables=tpl_vars, context=self._context self._action_scripts[CONF_LOCK],
run_variables=tpl_vars,
context=self._context,
) )
async def async_unlock(self, **kwargs: Any) -> None: async def async_unlock(self, **kwargs: Any) -> None:
@ -226,7 +232,9 @@ class TemplateLock(TemplateEntity, LockEntity):
tpl_vars = {ATTR_CODE: kwargs.get(ATTR_CODE) if kwargs else None} tpl_vars = {ATTR_CODE: kwargs.get(ATTR_CODE) if kwargs else None}
await self.async_run_script( await self.async_run_script(
self._command_unlock, run_variables=tpl_vars, context=self._context self._action_scripts[CONF_UNLOCK],
run_variables=tpl_vars,
context=self._context,
) )
async def async_open(self, **kwargs: Any) -> None: async def async_open(self, **kwargs: Any) -> None:
@ -242,7 +250,9 @@ class TemplateLock(TemplateEntity, LockEntity):
tpl_vars = {ATTR_CODE: kwargs.get(ATTR_CODE) if kwargs else None} tpl_vars = {ATTR_CODE: kwargs.get(ATTR_CODE) if kwargs else None}
await self.async_run_script( await self.async_run_script(
self._command_open, run_variables=tpl_vars, context=self._context self._action_scripts[CONF_OPEN],
run_variables=tpl_vars,
context=self._context,
) )
def _raise_template_error_if_available(self): def _raise_template_error_if_available(self):

View File

@ -2,7 +2,7 @@
"domain": "template", "domain": "template",
"name": "Template", "name": "Template",
"after_dependencies": ["group"], "after_dependencies": ["group"],
"codeowners": ["@PhracturedBlue", "@home-assistant/core"], "codeowners": ["@Petro31", "@PhracturedBlue", "@home-assistant/core"],
"config_flow": true, "config_flow": true,
"dependencies": ["blueprint"], "dependencies": ["blueprint"],
"documentation": "https://www.home-assistant.io/integrations/template", "documentation": "https://www.home-assistant.io/integrations/template",

View File

@ -157,9 +157,7 @@ class TemplateNumber(TemplateEntity, NumberEntity):
super().__init__(hass, config=config, unique_id=unique_id) super().__init__(hass, config=config, unique_id=unique_id)
assert self._attr_name is not None assert self._attr_name is not None
self._value_template = config[CONF_STATE] self._value_template = config[CONF_STATE]
self._command_set_value = Script( self.add_script(CONF_SET_VALUE, config[CONF_SET_VALUE], self._attr_name, DOMAIN)
hass, config[CONF_SET_VALUE], self._attr_name, DOMAIN
)
self._step_template = config[CONF_STEP] self._step_template = config[CONF_STEP]
self._min_value_template = config[CONF_MIN] self._min_value_template = config[CONF_MIN]
@ -210,9 +208,9 @@ class TemplateNumber(TemplateEntity, NumberEntity):
if self._optimistic: if self._optimistic:
self._attr_native_value = value self._attr_native_value = value
self.async_write_ha_state() self.async_write_ha_state()
if self._command_set_value: if (set_value := self._action_scripts.get(CONF_SET_VALUE)) is not None:
await self.async_run_script( await self.async_run_script(
self._command_set_value, set_value,
run_variables={ATTR_VALUE: value}, run_variables={ATTR_VALUE: value},
context=self._context, context=self._context,
) )

View File

@ -143,8 +143,8 @@ class TemplateSelect(TemplateEntity, SelectEntity):
assert self._attr_name is not None assert self._attr_name is not None
self._value_template = config[CONF_STATE] self._value_template = config[CONF_STATE]
if (selection_option := config.get(CONF_SELECT_OPTION)) is not None: if (selection_option := config.get(CONF_SELECT_OPTION)) is not None:
self._command_select_option = Script( self.add_script(
hass, selection_option, self._attr_name, DOMAIN CONF_SELECT_OPTION, selection_option, self._attr_name, DOMAIN
) )
self._options_template = config[ATTR_OPTIONS] self._options_template = config[ATTR_OPTIONS]
self._attr_assumed_state = self._optimistic = config.get(CONF_OPTIMISTIC, False) self._attr_assumed_state = self._optimistic = config.get(CONF_OPTIMISTIC, False)
@ -177,9 +177,9 @@ class TemplateSelect(TemplateEntity, SelectEntity):
if self._optimistic: if self._optimistic:
self._attr_current_option = option self._attr_current_option = option
self.async_write_ha_state() self.async_write_ha_state()
if self._command_select_option: if (select_option := self._action_scripts.get(CONF_SELECT_OPTION)) is not None:
await self.async_run_script( await self.async_run_script(
self._command_select_option, select_option,
run_variables={ATTR_OPTION: option}, run_variables={ATTR_OPTION: option},
context=self._context, context=self._context,
) )

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any from typing import TYPE_CHECKING, Any
import voluptuous as vol import voluptuous as vol
@ -33,7 +33,6 @@ from homeassistant.helpers.entity_platform import (
AddEntitiesCallback, AddEntitiesCallback,
) )
from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.script import Script
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from .const import CONF_TURN_OFF, CONF_TURN_ON, DOMAIN from .const import CONF_TURN_OFF, CONF_TURN_ON, DOMAIN
@ -74,7 +73,7 @@ SWITCH_CONFIG_SCHEMA = vol.Schema(
) )
async def _async_create_entities(hass, config): async def _async_create_entities(hass: HomeAssistant, config: ConfigType):
"""Create the Template switches.""" """Create the Template switches."""
switches = [] switches = []
@ -134,11 +133,11 @@ class SwitchTemplate(TemplateEntity, SwitchEntity, RestoreEntity):
def __init__( def __init__(
self, self,
hass, hass: HomeAssistant,
object_id, object_id,
config, config: ConfigType,
unique_id, unique_id,
): ) -> None:
"""Initialize the Template switch.""" """Initialize the Template switch."""
super().__init__( super().__init__(
hass, config=config, fallback_name=object_id, unique_id=unique_id hass, config=config, fallback_name=object_id, unique_id=unique_id
@ -147,18 +146,16 @@ class SwitchTemplate(TemplateEntity, SwitchEntity, RestoreEntity):
self.entity_id = async_generate_entity_id( self.entity_id = async_generate_entity_id(
ENTITY_ID_FORMAT, object_id, hass=hass ENTITY_ID_FORMAT, object_id, hass=hass
) )
friendly_name = self._attr_name name = self._attr_name
if TYPE_CHECKING:
assert name is not None
self._template = config.get(CONF_VALUE_TEMPLATE) self._template = config.get(CONF_VALUE_TEMPLATE)
self._on_script = (
Script(hass, config.get(CONF_TURN_ON), friendly_name, DOMAIN) if on_action := config.get(CONF_TURN_ON):
if config.get(CONF_TURN_ON) is not None self.add_script(CONF_TURN_ON, on_action, name, DOMAIN)
else None if off_action := config.get(CONF_TURN_OFF):
) self.add_script(CONF_TURN_OFF, off_action, name, DOMAIN)
self._off_script = (
Script(hass, config.get(CONF_TURN_OFF), friendly_name, DOMAIN)
if config.get(CONF_TURN_OFF) is not None
else None
)
self._state: bool | None = False self._state: bool | None = False
self._attr_assumed_state = self._template is None self._attr_assumed_state = self._template is None
self._attr_device_info = async_device_info_to_link_from_device_id( self._attr_device_info = async_device_info_to_link_from_device_id(
@ -209,16 +206,16 @@ class SwitchTemplate(TemplateEntity, SwitchEntity, RestoreEntity):
async def async_turn_on(self, **kwargs: Any) -> None: async def async_turn_on(self, **kwargs: Any) -> None:
"""Fire the on action.""" """Fire the on action."""
if self._on_script: if (on_script := self._action_scripts.get(CONF_TURN_ON)) is not None:
await self.async_run_script(self._on_script, context=self._context) await self.async_run_script(on_script, context=self._context)
if self._template is None: if self._template is None:
self._state = True self._state = True
self.async_write_ha_state() self.async_write_ha_state()
async def async_turn_off(self, **kwargs: Any) -> None: async def async_turn_off(self, **kwargs: Any) -> None:
"""Fire the off action.""" """Fire the off action."""
if self._off_script: if (off_script := self._action_scripts.get(CONF_TURN_OFF)) is not None:
await self.async_run_script(self._off_script, context=self._context) await self.async_run_script(off_script, context=self._context)
if self._template is None: if self._template is None:
self._state = False self._state = False
self.async_write_ha_state() self.async_write_ha_state()

View File

@ -24,7 +24,6 @@ from homeassistant.const import (
) )
from homeassistant.core import ( from homeassistant.core import (
CALLBACK_TYPE, CALLBACK_TYPE,
Context,
Event, Event,
EventStateChangedData, EventStateChangedData,
HomeAssistant, HomeAssistant,
@ -41,7 +40,7 @@ from homeassistant.helpers.event import (
TrackTemplateResultInfo, TrackTemplateResultInfo,
async_track_template_result, async_track_template_result,
) )
from homeassistant.helpers.script import Script, _VarsType from homeassistant.helpers.script_variables import ScriptVariables
from homeassistant.helpers.start import async_at_start from homeassistant.helpers.start import async_at_start
from homeassistant.helpers.template import ( from homeassistant.helpers.template import (
Template, Template,
@ -61,6 +60,7 @@ from .const import (
CONF_AVAILABILITY_TEMPLATE, CONF_AVAILABILITY_TEMPLATE,
CONF_PICTURE, CONF_PICTURE,
) )
from .entity import AbstractTemplateEntity
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -248,7 +248,7 @@ class _TemplateAttribute:
return return
class TemplateEntity(Entity): # pylint: disable=hass-enforce-class-module class TemplateEntity(AbstractTemplateEntity): # pylint: disable=hass-enforce-class-module
"""Entity that uses templates to calculate attributes.""" """Entity that uses templates to calculate attributes."""
_attr_available = True _attr_available = True
@ -268,6 +268,7 @@ class TemplateEntity(Entity): # pylint: disable=hass-enforce-class-module
unique_id: str | None = None, unique_id: str | None = None,
) -> None: ) -> None:
"""Template Entity.""" """Template Entity."""
super().__init__(hass)
self._template_attrs: dict[Template, list[_TemplateAttribute]] = {} self._template_attrs: dict[Template, list[_TemplateAttribute]] = {}
self._template_result_info: TrackTemplateResultInfo | None = None self._template_result_info: TrackTemplateResultInfo | None = None
self._attr_extra_state_attributes = {} self._attr_extra_state_attributes = {}
@ -285,6 +286,7 @@ class TemplateEntity(Entity): # pylint: disable=hass-enforce-class-module
] ]
| None | None
) = None ) = None
self._run_variables: ScriptVariables | dict
if config is None: if config is None:
self._attribute_templates = attribute_templates self._attribute_templates = attribute_templates
self._availability_template = availability_template self._availability_template = availability_template
@ -339,18 +341,6 @@ class TemplateEntity(Entity): # pylint: disable=hass-enforce-class-module
variables=variables, parse_result=False variables=variables, parse_result=False
) )
@callback
def _render_variables(self) -> dict:
if isinstance(self._run_variables, dict):
return self._run_variables
return self._run_variables.async_render(
self.hass,
{
"this": TemplateStateFromEntityId(self.hass, self.entity_id),
},
)
@callback @callback
def _update_available(self, result: str | TemplateError) -> None: def _update_available(self, result: str | TemplateError) -> None:
if isinstance(result, TemplateError): if isinstance(result, TemplateError):
@ -387,6 +377,18 @@ class TemplateEntity(Entity): # pylint: disable=hass-enforce-class-module
return None return None
return cast(str, self._blueprint_inputs[CONF_USE_BLUEPRINT][CONF_PATH]) return cast(str, self._blueprint_inputs[CONF_USE_BLUEPRINT][CONF_PATH])
def _render_script_variables(self) -> dict[str, Any]:
"""Render configured variables."""
if isinstance(self._run_variables, dict):
return self._run_variables
return self._run_variables.async_render(
self.hass,
{
"this": TemplateStateFromEntityId(self.hass, self.entity_id),
},
)
def add_template_attribute( def add_template_attribute(
self, self,
attribute: str, attribute: str,
@ -488,7 +490,7 @@ class TemplateEntity(Entity): # pylint: disable=hass-enforce-class-module
variables = { variables = {
"this": TemplateStateFromEntityId(self.hass, self.entity_id), "this": TemplateStateFromEntityId(self.hass, self.entity_id),
**self._render_variables(), **self._render_script_variables(),
} }
for template, attributes in self._template_attrs.items(): for template, attributes in self._template_attrs.items():
@ -581,22 +583,3 @@ class TemplateEntity(Entity): # pylint: disable=hass-enforce-class-module
"""Call for forced update.""" """Call for forced update."""
assert self._template_result_info assert self._template_result_info
self._template_result_info.async_refresh() self._template_result_info.async_refresh()
async def async_run_script(
self,
script: Script,
*,
run_variables: _VarsType | None = None,
context: Context | None = None,
) -> None:
"""Run an action script."""
if run_variables is None:
run_variables = {}
await script.async_run(
run_variables={
"this": TemplateStateFromEntityId(self.hass, self.entity_id),
**self._render_variables(),
**run_variables,
},
context=context,
)

View File

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any from typing import TYPE_CHECKING, Any
import voluptuous as vol import voluptuous as vol
@ -33,7 +33,6 @@ from homeassistant.exceptions import TemplateError
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity import async_generate_entity_id from homeassistant.helpers.entity import async_generate_entity_id
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.script import Script
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from .const import DOMAIN from .const import DOMAIN
@ -90,7 +89,7 @@ PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(
) )
async def _async_create_entities(hass, config): async def _async_create_entities(hass: HomeAssistant, config: ConfigType):
"""Create the Template Vacuums.""" """Create the Template Vacuums."""
vacuums = [] vacuums = []
@ -127,11 +126,11 @@ class TemplateVacuum(TemplateEntity, StateVacuumEntity):
def __init__( def __init__(
self, self,
hass, hass: HomeAssistant,
object_id, object_id,
config, config: ConfigType,
unique_id, unique_id,
): ) -> None:
"""Initialize the vacuum.""" """Initialize the vacuum."""
super().__init__( super().__init__(
hass, config=config, fallback_name=object_id, unique_id=unique_id hass, config=config, fallback_name=object_id, unique_id=unique_id
@ -139,7 +138,9 @@ class TemplateVacuum(TemplateEntity, StateVacuumEntity):
self.entity_id = async_generate_entity_id( self.entity_id = async_generate_entity_id(
ENTITY_ID_FORMAT, object_id, hass=hass ENTITY_ID_FORMAT, object_id, hass=hass
) )
friendly_name = self._attr_name name = self._attr_name
if TYPE_CHECKING:
assert name is not None
self._template = config.get(CONF_VALUE_TEMPLATE) self._template = config.get(CONF_VALUE_TEMPLATE)
self._battery_level_template = config.get(CONF_BATTERY_LEVEL_TEMPLATE) self._battery_level_template = config.get(CONF_BATTERY_LEVEL_TEMPLATE)
@ -148,43 +149,18 @@ class TemplateVacuum(TemplateEntity, StateVacuumEntity):
VacuumEntityFeature.START | VacuumEntityFeature.STATE VacuumEntityFeature.START | VacuumEntityFeature.STATE
) )
self._start_script = Script(hass, config[SERVICE_START], friendly_name, DOMAIN) for action_id, supported_feature in (
(SERVICE_START, 0),
self._pause_script = None (SERVICE_PAUSE, VacuumEntityFeature.PAUSE),
if pause_action := config.get(SERVICE_PAUSE): (SERVICE_STOP, VacuumEntityFeature.STOP),
self._pause_script = Script(hass, pause_action, friendly_name, DOMAIN) (SERVICE_RETURN_TO_BASE, VacuumEntityFeature.RETURN_HOME),
self._attr_supported_features |= VacuumEntityFeature.PAUSE (SERVICE_CLEAN_SPOT, VacuumEntityFeature.CLEAN_SPOT),
(SERVICE_LOCATE, VacuumEntityFeature.LOCATE),
self._stop_script = None (SERVICE_SET_FAN_SPEED, VacuumEntityFeature.FAN_SPEED),
if stop_action := config.get(SERVICE_STOP): ):
self._stop_script = Script(hass, stop_action, friendly_name, DOMAIN) if action_config := config.get(action_id):
self._attr_supported_features |= VacuumEntityFeature.STOP self.add_script(action_id, action_config, name, DOMAIN)
self._attr_supported_features |= supported_feature
self._return_to_base_script = None
if return_to_base_action := config.get(SERVICE_RETURN_TO_BASE):
self._return_to_base_script = Script(
hass, return_to_base_action, friendly_name, DOMAIN
)
self._attr_supported_features |= VacuumEntityFeature.RETURN_HOME
self._clean_spot_script = None
if clean_spot_action := config.get(SERVICE_CLEAN_SPOT):
self._clean_spot_script = Script(
hass, clean_spot_action, friendly_name, DOMAIN
)
self._attr_supported_features |= VacuumEntityFeature.CLEAN_SPOT
self._locate_script = None
if locate_action := config.get(SERVICE_LOCATE):
self._locate_script = Script(hass, locate_action, friendly_name, DOMAIN)
self._attr_supported_features |= VacuumEntityFeature.LOCATE
self._set_fan_speed_script = None
if set_fan_speed_action := config.get(SERVICE_SET_FAN_SPEED):
self._set_fan_speed_script = Script(
hass, set_fan_speed_action, friendly_name, DOMAIN
)
self._attr_supported_features |= VacuumEntityFeature.FAN_SPEED
self._state = None self._state = None
self._battery_level = None self._battery_level = None
@ -203,62 +179,50 @@ class TemplateVacuum(TemplateEntity, StateVacuumEntity):
async def async_start(self) -> None: async def async_start(self) -> None:
"""Start or resume the cleaning task.""" """Start or resume the cleaning task."""
await self.async_run_script(self._start_script, context=self._context) await self.async_run_script(
self._action_scripts[SERVICE_START], context=self._context
)
async def async_pause(self) -> None: async def async_pause(self) -> None:
"""Pause the cleaning task.""" """Pause the cleaning task."""
if self._pause_script is None: if (script := self._action_scripts.get(SERVICE_PAUSE)) is not None:
return await self.async_run_script(script, context=self._context)
await self.async_run_script(self._pause_script, context=self._context)
async def async_stop(self, **kwargs: Any) -> None: async def async_stop(self, **kwargs: Any) -> None:
"""Stop the cleaning task.""" """Stop the cleaning task."""
if self._stop_script is None: if (script := self._action_scripts.get(SERVICE_STOP)) is not None:
return await self.async_run_script(script, context=self._context)
await self.async_run_script(self._stop_script, context=self._context)
async def async_return_to_base(self, **kwargs: Any) -> None: async def async_return_to_base(self, **kwargs: Any) -> None:
"""Set the vacuum cleaner to return to the dock.""" """Set the vacuum cleaner to return to the dock."""
if self._return_to_base_script is None: if (script := self._action_scripts.get(SERVICE_RETURN_TO_BASE)) is not None:
return await self.async_run_script(script, context=self._context)
await self.async_run_script(self._return_to_base_script, context=self._context)
async def async_clean_spot(self, **kwargs: Any) -> None: async def async_clean_spot(self, **kwargs: Any) -> None:
"""Perform a spot clean-up.""" """Perform a spot clean-up."""
if self._clean_spot_script is None: if (script := self._action_scripts.get(SERVICE_CLEAN_SPOT)) is not None:
return await self.async_run_script(script, context=self._context)
await self.async_run_script(self._clean_spot_script, context=self._context)
async def async_locate(self, **kwargs: Any) -> None: async def async_locate(self, **kwargs: Any) -> None:
"""Locate the vacuum cleaner.""" """Locate the vacuum cleaner."""
if self._locate_script is None: if (script := self._action_scripts.get(SERVICE_LOCATE)) is not None:
return await self.async_run_script(script, context=self._context)
await self.async_run_script(self._locate_script, context=self._context)
async def async_set_fan_speed(self, fan_speed: str, **kwargs: Any) -> None: async def async_set_fan_speed(self, fan_speed: str, **kwargs: Any) -> None:
"""Set fan speed.""" """Set fan speed."""
if self._set_fan_speed_script is None: if fan_speed not in self._attr_fan_speed_list:
return
if fan_speed in self._attr_fan_speed_list:
self._attr_fan_speed = fan_speed
await self.async_run_script(
self._set_fan_speed_script,
run_variables={ATTR_FAN_SPEED: fan_speed},
context=self._context,
)
else:
_LOGGER.error( _LOGGER.error(
"Received invalid fan speed: %s for entity %s. Expected: %s", "Received invalid fan speed: %s for entity %s. Expected: %s",
fan_speed, fan_speed,
self.entity_id, self.entity_id,
self._attr_fan_speed_list, self._attr_fan_speed_list,
) )
return
if (script := self._action_scripts.get(SERVICE_SET_FAN_SPEED)) is not None:
await self.async_run_script(
script, run_variables={ATTR_FAN_SPEED: fan_speed}, context=self._context
)
@callback @callback
def _async_setup_templates(self) -> None: def _async_setup_templates(self) -> None:

View File

@ -0,0 +1,17 @@
"""Test abstract template entity."""
import pytest
from homeassistant.components.template import entity as abstract_entity
from homeassistant.core import HomeAssistant
async def test_template_entity_not_implemented(hass: HomeAssistant) -> None:
"""Test abstract template entity raises not implemented error."""
entity = abstract_entity.AbstractTemplateEntity(None)
with pytest.raises(NotImplementedError):
_ = entity.referenced_blueprint
with pytest.raises(NotImplementedError):
entity._render_script_variables()

View File

@ -9,7 +9,7 @@ from homeassistant.helpers import template
async def test_template_entity_requires_hass_set(hass: HomeAssistant) -> None: async def test_template_entity_requires_hass_set(hass: HomeAssistant) -> None:
"""Test template entity requires hass to be set before accepting templates.""" """Test template entity requires hass to be set before accepting templates."""
entity = template_entity.TemplateEntity(hass) entity = template_entity.TemplateEntity(None)
with pytest.raises(ValueError, match="^hass cannot be None"): with pytest.raises(ValueError, match="^hass cannot be None"):
entity.add_template_attribute("_hello", template.Template("Hello")) entity.add_template_attribute("_hello", template.Template("Hello"))