mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 12:17:07 +00:00
Refactor Rest Switch with ManualTriggerEntity (#97403)
* Refactor Rest Switch with ManualTriggerEntity * Fix test * Fix 2 * review comments * remove async_added_to_hass * update on startup
This commit is contained in:
parent
87b7fc6c61
commit
ed18c6a013
@ -18,7 +18,9 @@ from homeassistant.components.switch import (
|
|||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
CONF_DEVICE_CLASS,
|
CONF_DEVICE_CLASS,
|
||||||
CONF_HEADERS,
|
CONF_HEADERS,
|
||||||
|
CONF_ICON,
|
||||||
CONF_METHOD,
|
CONF_METHOD,
|
||||||
|
CONF_NAME,
|
||||||
CONF_PARAMS,
|
CONF_PARAMS,
|
||||||
CONF_PASSWORD,
|
CONF_PASSWORD,
|
||||||
CONF_RESOURCE,
|
CONF_RESOURCE,
|
||||||
@ -33,8 +35,10 @@ from homeassistant.helpers import config_validation as cv, template
|
|||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.helpers.httpx_client import get_async_client
|
from homeassistant.helpers.httpx_client import get_async_client
|
||||||
from homeassistant.helpers.template_entity import (
|
from homeassistant.helpers.template_entity import (
|
||||||
|
CONF_AVAILABILITY,
|
||||||
|
CONF_PICTURE,
|
||||||
TEMPLATE_ENTITY_BASE_SCHEMA,
|
TEMPLATE_ENTITY_BASE_SCHEMA,
|
||||||
TemplateEntity,
|
ManualTriggerEntity,
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
|
|
||||||
@ -44,6 +48,14 @@ CONF_BODY_ON = "body_on"
|
|||||||
CONF_IS_ON_TEMPLATE = "is_on_template"
|
CONF_IS_ON_TEMPLATE = "is_on_template"
|
||||||
CONF_STATE_RESOURCE = "state_resource"
|
CONF_STATE_RESOURCE = "state_resource"
|
||||||
|
|
||||||
|
TRIGGER_ENTITY_OPTIONS = (
|
||||||
|
CONF_AVAILABILITY,
|
||||||
|
CONF_DEVICE_CLASS,
|
||||||
|
CONF_ICON,
|
||||||
|
CONF_PICTURE,
|
||||||
|
CONF_UNIQUE_ID,
|
||||||
|
)
|
||||||
|
|
||||||
DEFAULT_METHOD = "post"
|
DEFAULT_METHOD = "post"
|
||||||
DEFAULT_BODY_OFF = "OFF"
|
DEFAULT_BODY_OFF = "OFF"
|
||||||
DEFAULT_BODY_ON = "ON"
|
DEFAULT_BODY_ON = "ON"
|
||||||
@ -71,6 +83,7 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
|
|||||||
vol.Inclusive(CONF_USERNAME, "authentication"): cv.string,
|
vol.Inclusive(CONF_USERNAME, "authentication"): cv.string,
|
||||||
vol.Inclusive(CONF_PASSWORD, "authentication"): cv.string,
|
vol.Inclusive(CONF_PASSWORD, "authentication"): cv.string,
|
||||||
vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean,
|
vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean,
|
||||||
|
vol.Optional(CONF_AVAILABILITY): cv.template,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -83,10 +96,17 @@ async def async_setup_platform(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Set up the RESTful switch."""
|
"""Set up the RESTful switch."""
|
||||||
resource: str = config[CONF_RESOURCE]
|
resource: str = config[CONF_RESOURCE]
|
||||||
unique_id: str | None = config.get(CONF_UNIQUE_ID)
|
name = config.get(CONF_NAME) or template.Template(DEFAULT_NAME, hass)
|
||||||
|
|
||||||
|
trigger_entity_config = {CONF_NAME: name}
|
||||||
|
|
||||||
|
for key in TRIGGER_ENTITY_OPTIONS:
|
||||||
|
if key not in config:
|
||||||
|
continue
|
||||||
|
trigger_entity_config[key] = config[key]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
switch = RestSwitch(hass, config, unique_id)
|
switch = RestSwitch(hass, config, trigger_entity_config)
|
||||||
|
|
||||||
req = await switch.get_device_state(hass)
|
req = await switch.get_device_state(hass)
|
||||||
if req.status_code >= HTTPStatus.BAD_REQUEST:
|
if req.status_code >= HTTPStatus.BAD_REQUEST:
|
||||||
@ -102,23 +122,17 @@ async def async_setup_platform(
|
|||||||
raise PlatformNotReady(f"No route to resource/endpoint: {resource}") from exc
|
raise PlatformNotReady(f"No route to resource/endpoint: {resource}") from exc
|
||||||
|
|
||||||
|
|
||||||
class RestSwitch(TemplateEntity, SwitchEntity):
|
class RestSwitch(ManualTriggerEntity, SwitchEntity):
|
||||||
"""Representation of a switch that can be toggled using REST."""
|
"""Representation of a switch that can be toggled using REST."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
config: ConfigType,
|
config: ConfigType,
|
||||||
unique_id: str | None,
|
trigger_entity_config: ConfigType,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the REST switch."""
|
"""Initialize the REST switch."""
|
||||||
TemplateEntity.__init__(
|
ManualTriggerEntity.__init__(self, hass, trigger_entity_config)
|
||||||
self,
|
|
||||||
hass,
|
|
||||||
config=config,
|
|
||||||
fallback_name=DEFAULT_NAME,
|
|
||||||
unique_id=unique_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
auth: httpx.BasicAuth | None = None
|
auth: httpx.BasicAuth | None = None
|
||||||
username: str | None = None
|
username: str | None = None
|
||||||
@ -138,8 +152,6 @@ class RestSwitch(TemplateEntity, SwitchEntity):
|
|||||||
self._timeout: int = config[CONF_TIMEOUT]
|
self._timeout: int = config[CONF_TIMEOUT]
|
||||||
self._verify_ssl: bool = config[CONF_VERIFY_SSL]
|
self._verify_ssl: bool = config[CONF_VERIFY_SSL]
|
||||||
|
|
||||||
self._attr_device_class = config.get(CONF_DEVICE_CLASS)
|
|
||||||
|
|
||||||
self._body_on.hass = hass
|
self._body_on.hass = hass
|
||||||
self._body_off.hass = hass
|
self._body_off.hass = hass
|
||||||
if (is_on_template := self._is_on_template) is not None:
|
if (is_on_template := self._is_on_template) is not None:
|
||||||
@ -148,6 +160,11 @@ class RestSwitch(TemplateEntity, SwitchEntity):
|
|||||||
template.attach(hass, self._headers)
|
template.attach(hass, self._headers)
|
||||||
template.attach(hass, self._params)
|
template.attach(hass, self._params)
|
||||||
|
|
||||||
|
async def async_added_to_hass(self) -> None:
|
||||||
|
"""Handle adding to Home Assistant."""
|
||||||
|
await super().async_added_to_hass()
|
||||||
|
await self.async_update()
|
||||||
|
|
||||||
async def async_turn_on(self, **kwargs: Any) -> None:
|
async def async_turn_on(self, **kwargs: Any) -> None:
|
||||||
"""Turn the device on."""
|
"""Turn the device on."""
|
||||||
body_on_t = self._body_on.async_render(parse_result=False)
|
body_on_t = self._body_on.async_render(parse_result=False)
|
||||||
@ -198,13 +215,18 @@ class RestSwitch(TemplateEntity, SwitchEntity):
|
|||||||
|
|
||||||
async def async_update(self) -> None:
|
async def async_update(self) -> None:
|
||||||
"""Get the current state, catching errors."""
|
"""Get the current state, catching errors."""
|
||||||
|
req = None
|
||||||
try:
|
try:
|
||||||
await self.get_device_state(self.hass)
|
req = await self.get_device_state(self.hass)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
_LOGGER.exception("Timed out while fetching data")
|
_LOGGER.exception("Timed out while fetching data")
|
||||||
except httpx.RequestError as err:
|
except httpx.RequestError as err:
|
||||||
_LOGGER.exception("Error while fetching data: %s", err)
|
_LOGGER.exception("Error while fetching data: %s", err)
|
||||||
|
|
||||||
|
if req:
|
||||||
|
self._process_manual_data(req.text)
|
||||||
|
self.async_write_ha_state()
|
||||||
|
|
||||||
async def get_device_state(self, hass: HomeAssistant) -> httpx.Response:
|
async def get_device_state(self, hass: HomeAssistant) -> httpx.Response:
|
||||||
"""Get the latest data from REST API and update the state."""
|
"""Get the latest data from REST API and update the state."""
|
||||||
websession = get_async_client(hass, self._verify_ssl)
|
websession = get_async_client(hass, self._verify_ssl)
|
||||||
|
@ -111,7 +111,7 @@ async def test_setup_minimum(hass: HomeAssistant) -> None:
|
|||||||
with assert_setup_component(1, SWITCH_DOMAIN):
|
with assert_setup_component(1, SWITCH_DOMAIN):
|
||||||
assert await async_setup_component(hass, SWITCH_DOMAIN, config)
|
assert await async_setup_component(hass, SWITCH_DOMAIN, config)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert route.call_count == 1
|
assert route.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
@ -129,7 +129,7 @@ async def test_setup_query_params(hass: HomeAssistant) -> None:
|
|||||||
assert await async_setup_component(hass, SWITCH_DOMAIN, config)
|
assert await async_setup_component(hass, SWITCH_DOMAIN, config)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert route.call_count == 1
|
assert route.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
@ -148,7 +148,7 @@ async def test_setup(hass: HomeAssistant) -> None:
|
|||||||
}
|
}
|
||||||
assert await async_setup_component(hass, SWITCH_DOMAIN, config)
|
assert await async_setup_component(hass, SWITCH_DOMAIN, config)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert route.call_count == 1
|
assert route.call_count == 2
|
||||||
assert_setup_component(1, SWITCH_DOMAIN)
|
assert_setup_component(1, SWITCH_DOMAIN)
|
||||||
|
|
||||||
|
|
||||||
@ -170,7 +170,7 @@ async def test_setup_with_state_resource(hass: HomeAssistant) -> None:
|
|||||||
}
|
}
|
||||||
assert await async_setup_component(hass, SWITCH_DOMAIN, config)
|
assert await async_setup_component(hass, SWITCH_DOMAIN, config)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert route.call_count == 1
|
assert route.call_count == 2
|
||||||
assert_setup_component(1, SWITCH_DOMAIN)
|
assert_setup_component(1, SWITCH_DOMAIN)
|
||||||
|
|
||||||
|
|
||||||
@ -195,7 +195,7 @@ async def test_setup_with_templated_headers_params(hass: HomeAssistant) -> None:
|
|||||||
}
|
}
|
||||||
assert await async_setup_component(hass, SWITCH_DOMAIN, config)
|
assert await async_setup_component(hass, SWITCH_DOMAIN, config)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert route.call_count == 1
|
assert route.call_count == 2
|
||||||
last_call = route.calls[-1]
|
last_call = route.calls[-1]
|
||||||
last_request: httpx.Request = last_call.request
|
last_request: httpx.Request = last_call.request
|
||||||
assert last_request.headers.get("Accept") == CONTENT_TYPE_JSON
|
assert last_request.headers.get("Accept") == CONTENT_TYPE_JSON
|
||||||
|
Loading…
x
Reference in New Issue
Block a user