diff --git a/homeassistant/components/rest/switch.py b/homeassistant/components/rest/switch.py index 827f4bad0b3..0a220204997 100644 --- a/homeassistant/components/rest/switch.py +++ b/homeassistant/components/rest/switch.py @@ -18,7 +18,9 @@ from homeassistant.components.switch import ( from homeassistant.const import ( CONF_DEVICE_CLASS, CONF_HEADERS, + CONF_ICON, CONF_METHOD, + CONF_NAME, CONF_PARAMS, CONF_PASSWORD, CONF_RESOURCE, @@ -33,8 +35,10 @@ from homeassistant.helpers import config_validation as cv, template from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.httpx_client import get_async_client from homeassistant.helpers.template_entity import ( + CONF_AVAILABILITY, + CONF_PICTURE, TEMPLATE_ENTITY_BASE_SCHEMA, - TemplateEntity, + ManualTriggerEntity, ) from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType @@ -44,6 +48,14 @@ CONF_BODY_ON = "body_on" CONF_IS_ON_TEMPLATE = "is_on_template" CONF_STATE_RESOURCE = "state_resource" +TRIGGER_ENTITY_OPTIONS = ( + CONF_AVAILABILITY, + CONF_DEVICE_CLASS, + CONF_ICON, + CONF_PICTURE, + CONF_UNIQUE_ID, +) + DEFAULT_METHOD = "post" DEFAULT_BODY_OFF = "OFF" DEFAULT_BODY_ON = "ON" @@ -71,6 +83,7 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( vol.Inclusive(CONF_USERNAME, "authentication"): cv.string, vol.Inclusive(CONF_PASSWORD, "authentication"): cv.string, vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean, + vol.Optional(CONF_AVAILABILITY): cv.template, } ) @@ -83,10 +96,17 @@ async def async_setup_platform( ) -> None: """Set up the RESTful switch.""" resource: str = config[CONF_RESOURCE] - unique_id: str | None = config.get(CONF_UNIQUE_ID) + name = config.get(CONF_NAME) or template.Template(DEFAULT_NAME, hass) + + trigger_entity_config = {CONF_NAME: name} + + for key in TRIGGER_ENTITY_OPTIONS: + if key not in config: + continue + trigger_entity_config[key] = config[key] try: - switch = RestSwitch(hass, config, unique_id) + switch = RestSwitch(hass, config, trigger_entity_config) req = await switch.get_device_state(hass) if req.status_code >= HTTPStatus.BAD_REQUEST: @@ -102,23 +122,17 @@ async def async_setup_platform( raise PlatformNotReady(f"No route to resource/endpoint: {resource}") from exc -class RestSwitch(TemplateEntity, SwitchEntity): +class RestSwitch(ManualTriggerEntity, SwitchEntity): """Representation of a switch that can be toggled using REST.""" def __init__( self, hass: HomeAssistant, config: ConfigType, - unique_id: str | None, + trigger_entity_config: ConfigType, ) -> None: """Initialize the REST switch.""" - TemplateEntity.__init__( - self, - hass, - config=config, - fallback_name=DEFAULT_NAME, - unique_id=unique_id, - ) + ManualTriggerEntity.__init__(self, hass, trigger_entity_config) auth: httpx.BasicAuth | None = None username: str | None = None @@ -138,8 +152,6 @@ class RestSwitch(TemplateEntity, SwitchEntity): self._timeout: int = config[CONF_TIMEOUT] self._verify_ssl: bool = config[CONF_VERIFY_SSL] - self._attr_device_class = config.get(CONF_DEVICE_CLASS) - self._body_on.hass = hass self._body_off.hass = hass if (is_on_template := self._is_on_template) is not None: @@ -148,6 +160,11 @@ class RestSwitch(TemplateEntity, SwitchEntity): template.attach(hass, self._headers) template.attach(hass, self._params) + async def async_added_to_hass(self) -> None: + """Handle adding to Home Assistant.""" + await super().async_added_to_hass() + await self.async_update() + async def async_turn_on(self, **kwargs: Any) -> None: """Turn the device on.""" body_on_t = self._body_on.async_render(parse_result=False) @@ -198,13 +215,18 @@ class RestSwitch(TemplateEntity, SwitchEntity): async def async_update(self) -> None: """Get the current state, catching errors.""" + req = None try: - await self.get_device_state(self.hass) + req = await self.get_device_state(self.hass) except asyncio.TimeoutError: _LOGGER.exception("Timed out while fetching data") except httpx.RequestError as err: _LOGGER.exception("Error while fetching data: %s", err) + if req: + self._process_manual_data(req.text) + self.async_write_ha_state() + async def get_device_state(self, hass: HomeAssistant) -> httpx.Response: """Get the latest data from REST API and update the state.""" websession = get_async_client(hass, self._verify_ssl) diff --git a/tests/components/rest/test_switch.py b/tests/components/rest/test_switch.py index a6895183d4e..8bd13550960 100644 --- a/tests/components/rest/test_switch.py +++ b/tests/components/rest/test_switch.py @@ -111,7 +111,7 @@ async def test_setup_minimum(hass: HomeAssistant) -> None: with assert_setup_component(1, SWITCH_DOMAIN): assert await async_setup_component(hass, SWITCH_DOMAIN, config) await hass.async_block_till_done() - assert route.call_count == 1 + assert route.call_count == 2 @respx.mock @@ -129,7 +129,7 @@ async def test_setup_query_params(hass: HomeAssistant) -> None: assert await async_setup_component(hass, SWITCH_DOMAIN, config) await hass.async_block_till_done() - assert route.call_count == 1 + assert route.call_count == 2 @respx.mock @@ -148,7 +148,7 @@ async def test_setup(hass: HomeAssistant) -> None: } assert await async_setup_component(hass, SWITCH_DOMAIN, config) await hass.async_block_till_done() - assert route.call_count == 1 + assert route.call_count == 2 assert_setup_component(1, SWITCH_DOMAIN) @@ -170,7 +170,7 @@ async def test_setup_with_state_resource(hass: HomeAssistant) -> None: } assert await async_setup_component(hass, SWITCH_DOMAIN, config) await hass.async_block_till_done() - assert route.call_count == 1 + assert route.call_count == 2 assert_setup_component(1, SWITCH_DOMAIN) @@ -195,7 +195,7 @@ async def test_setup_with_templated_headers_params(hass: HomeAssistant) -> None: } assert await async_setup_component(hass, SWITCH_DOMAIN, config) await hass.async_block_till_done() - assert route.call_count == 1 + assert route.call_count == 2 last_call = route.calls[-1] last_request: httpx.Request = last_call.request assert last_request.headers.get("Accept") == CONTENT_TYPE_JSON