Add button platform to template integration (#61908)

* Add button platform to template integration

* review comments

* add unique ID check
This commit is contained in:
Raman Gupta 2022-01-03 04:03:37 -05:00 committed by GitHub
parent ad7a0d799d
commit 6f8cd54ca1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 295 additions and 0 deletions

View File

@ -0,0 +1,122 @@
"""Support for buttons which integrates with other components."""
from __future__ import annotations
import contextlib
import logging
from typing import Any
import voluptuous as vol
from homeassistant.components.button import (
DEVICE_CLASSES_SCHEMA,
ButtonDeviceClass,
ButtonEntity,
)
from homeassistant.const import CONF_DEVICE_CLASS, CONF_ICON, CONF_NAME, CONF_UNIQUE_ID
from homeassistant.core import Config, HomeAssistant
from homeassistant.exceptions import PlatformNotReady
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.script import Script
from homeassistant.helpers.template import Template, TemplateError
from .const import CONF_AVAILABILITY, DOMAIN
from .template_entity import TemplateEntity
_LOGGER = logging.getLogger(__name__)
CONF_PRESS = "press"
DEFAULT_NAME = "Template Button"
DEFAULT_OPTIMISTIC = False
BUTTON_SCHEMA = vol.Schema(
{
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.template,
vol.Required(CONF_PRESS): cv.SCRIPT_SCHEMA,
vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA,
vol.Optional(CONF_AVAILABILITY): cv.template,
vol.Optional(CONF_UNIQUE_ID): cv.string,
vol.Optional(CONF_ICON): cv.template,
}
)
async def _async_create_entities(
hass: HomeAssistant, definitions: list[dict[str, Any]], unique_id_prefix: str | None
) -> list[TemplateButtonEntity]:
"""Create the Template button."""
entities = []
for definition in definitions:
unique_id = definition.get(CONF_UNIQUE_ID)
if unique_id and unique_id_prefix:
unique_id = f"{unique_id_prefix}-{unique_id}"
entities.append(
TemplateButtonEntity(
hass,
definition[CONF_NAME],
definition.get(CONF_AVAILABILITY),
definition[CONF_PRESS],
definition.get(CONF_DEVICE_CLASS),
unique_id,
definition.get(CONF_ICON),
)
)
return entities
async def async_setup_platform(
hass: HomeAssistant,
config: Config,
async_add_entities: AddEntitiesCallback,
discovery_info: dict[str, Any] | None = None,
) -> None:
"""Set up the template button."""
if "coordinator" in discovery_info:
raise PlatformNotReady(
"The template button platform doesn't support trigger entities"
)
async_add_entities(
await _async_create_entities(
hass, discovery_info["entities"], discovery_info["unique_id"]
)
)
class TemplateButtonEntity(TemplateEntity, ButtonEntity):
"""Representation of a template button."""
def __init__(
self,
hass: HomeAssistant,
name_template: Template,
availability_template: Template | None,
command_press: dict[str, Any],
device_class: ButtonDeviceClass | None,
unique_id: str | None,
icon_template: Template | None,
) -> None:
"""Initialize the button."""
super().__init__(
availability_template=availability_template, icon_template=icon_template
)
self._attr_name = DEFAULT_NAME
self._name_template = name_template
name_template.hass = hass
with contextlib.suppress(TemplateError):
self._attr_name = name_template.async_render(parse_result=False)
self._command_press = Script(hass, command_press, self._attr_name, DOMAIN)
self._attr_device_class = device_class
self._attr_unique_id = unique_id
self._attr_state = None
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
if self._name_template and not self._name_template.is_static:
self.add_template_attribute("_attr_name", self._name_template, cv.string)
await super().async_added_to_hass()
async def async_press(self) -> None:
"""Press the button."""
await self._command_press.async_run(context=self._context)

View File

@ -4,6 +4,7 @@ import logging
import voluptuous as vol
from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN
from homeassistant.components.button import DOMAIN as BUTTON_DOMAIN
from homeassistant.components.number import DOMAIN as NUMBER_DOMAIN
from homeassistant.components.select import DOMAIN as SELECT_DOMAIN
from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN
@ -14,6 +15,7 @@ from homeassistant.helpers.trigger import async_validate_trigger_config
from . import (
binary_sensor as binary_sensor_platform,
button as button_platform,
number as number_platform,
select as select_platform,
sensor as sensor_platform,
@ -44,6 +46,9 @@ CONFIG_SECTION_SCHEMA = vol.Schema(
vol.Optional(SELECT_DOMAIN): vol.All(
cv.ensure_list, [select_platform.SELECT_SCHEMA]
),
vol.Optional(BUTTON_DOMAIN): vol.All(
cv.ensure_list, [button_platform.BUTTON_SCHEMA]
),
}
)

View File

@ -13,6 +13,7 @@ PLATFORM_STORAGE_KEY = "template_platforms"
PLATFORMS = [
Platform.ALARM_CONTROL_PANEL,
Platform.BINARY_SENSOR,
Platform.BUTTON,
Platform.COVER,
Platform.FAN,
Platform.LIGHT,

View File

@ -0,0 +1,167 @@
"""The tests for the Template button platform."""
import datetime as dt
from unittest.mock import patch
import pytest
from homeassistant import setup
from homeassistant.components.button.const import DOMAIN as BUTTON_DOMAIN, SERVICE_PRESS
from homeassistant.components.template.button import DEFAULT_NAME
from homeassistant.const import (
CONF_DEVICE_CLASS,
CONF_ENTITY_ID,
CONF_FRIENDLY_NAME,
CONF_ICON,
STATE_UNKNOWN,
)
from homeassistant.helpers.entity_registry import async_get
from tests.common import assert_setup_component, async_mock_service
_TEST_BUTTON = "button.template_button"
_TEST_OPTIONS_BUTTON = "button.test"
@pytest.fixture
def calls(hass):
"""Track calls to a mock service."""
return async_mock_service(hass, "test", "automation")
async def test_missing_optional_config(hass, calls):
"""Test: missing optional template is ok."""
with assert_setup_component(1, "template"):
assert await setup.async_setup_component(
hass,
"template",
{
"template": {
"button": {
"press": {"service": "script.press"},
},
}
},
)
await hass.async_block_till_done()
await hass.async_start()
await hass.async_block_till_done()
_verify(hass, STATE_UNKNOWN)
async def test_missing_required_keys(hass, calls):
"""Test: missing required fields will fail."""
with assert_setup_component(0, "template"):
assert await setup.async_setup_component(
hass,
"template",
{"template": {"button": {}}},
)
await hass.async_block_till_done()
await hass.async_start()
await hass.async_block_till_done()
assert hass.states.async_all("button") == []
async def test_all_optional_config(hass, calls):
"""Test: including all optional templates is ok."""
with assert_setup_component(1, "template"):
assert await setup.async_setup_component(
hass,
"template",
{
"template": {
"unique_id": "test",
"button": {
"press": {"service": "test.automation"},
"device_class": "restart",
"unique_id": "test",
"name": "test",
"icon": "mdi:test",
},
}
},
)
await hass.async_block_till_done()
await hass.async_start()
await hass.async_block_till_done()
_verify(
hass,
STATE_UNKNOWN,
{
CONF_DEVICE_CLASS: "restart",
CONF_FRIENDLY_NAME: "test",
CONF_ICON: "mdi:test",
},
_TEST_OPTIONS_BUTTON,
)
now = dt.datetime.now(dt.timezone.utc)
with patch("homeassistant.util.dt.utcnow", return_value=now):
await hass.services.async_call(
BUTTON_DOMAIN,
SERVICE_PRESS,
{CONF_ENTITY_ID: _TEST_OPTIONS_BUTTON},
blocking=True,
)
assert len(calls) == 1
_verify(
hass,
now.isoformat(),
{
CONF_DEVICE_CLASS: "restart",
CONF_FRIENDLY_NAME: "test",
CONF_ICON: "mdi:test",
},
_TEST_OPTIONS_BUTTON,
)
er = async_get(hass)
assert er.async_get_entity_id("button", "template", "test-test")
async def test_unique_id(hass, calls):
"""Test: unique id is ok."""
with assert_setup_component(1, "template"):
assert await setup.async_setup_component(
hass,
"template",
{
"template": {
"unique_id": "test",
"button": {
"press": {"service": "script.press"},
"unique_id": "test",
},
}
},
)
await hass.async_block_till_done()
await hass.async_start()
await hass.async_block_till_done()
_verify(hass, STATE_UNKNOWN)
def _verify(
hass,
expected_value,
attributes=None,
entity_id=_TEST_BUTTON,
):
"""Verify button's state."""
attributes = attributes or {}
if CONF_FRIENDLY_NAME not in attributes:
attributes[CONF_FRIENDLY_NAME] = DEFAULT_NAME
state = hass.states.get(entity_id)
assert state.state == expected_value
assert state.attributes == attributes