Improve typing for device_automation (#74282)

This commit is contained in:
J. Nick Koston 2022-07-01 01:23:00 -05:00 committed by GitHub
parent 273e9b287f
commit 57b63db567
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Iterable, Mapping from collections.abc import Awaitable, Callable, Coroutine, Iterable, Mapping
from enum import Enum from enum import Enum
from functools import wraps from functools import wraps
import logging import logging
@ -13,6 +13,7 @@ import voluptuous as vol
import voluptuous_serialize import voluptuous_serialize
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.components.websocket_api.connection import ActiveConnection
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
CONF_DEVICE_ID, CONF_DEVICE_ID,
@ -43,10 +44,9 @@ if TYPE_CHECKING:
DeviceAutomationActionProtocol, DeviceAutomationActionProtocol,
] ]
# mypy: allow-untyped-calls, allow-untyped-defs
DOMAIN = "device_automation" DOMAIN = "device_automation"
DEVICE_TRIGGER_BASE_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend( DEVICE_TRIGGER_BASE_SCHEMA: vol.Schema = cv.TRIGGER_BASE_SCHEMA.extend(
{ {
vol.Required(CONF_PLATFORM): "device", vol.Required(CONF_PLATFORM): "device",
vol.Required(CONF_DOMAIN): str, vol.Required(CONF_DOMAIN): str,
@ -310,11 +310,17 @@ async def _async_get_device_automation_capabilities(
return capabilities # type: ignore[no-any-return] return capabilities # type: ignore[no-any-return]
def handle_device_errors(func): def handle_device_errors(
func: Callable[[HomeAssistant, ActiveConnection, dict[str, Any]], Awaitable[None]]
) -> Callable[
[HomeAssistant, ActiveConnection, dict[str, Any]], Coroutine[Any, Any, None]
]:
"""Handle device automation errors.""" """Handle device automation errors."""
@wraps(func) @wraps(func)
async def with_error_handling(hass, connection, msg): async def with_error_handling(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
try: try:
await func(hass, connection, msg) await func(hass, connection, msg)
except DeviceNotFound: except DeviceNotFound:
@ -333,7 +339,9 @@ def handle_device_errors(func):
) )
@websocket_api.async_response @websocket_api.async_response
@handle_device_errors @handle_device_errors
async def websocket_device_automation_list_actions(hass, connection, msg): async def websocket_device_automation_list_actions(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle request for device actions.""" """Handle request for device actions."""
device_id = msg["device_id"] device_id = msg["device_id"]
actions = ( actions = (
@ -352,7 +360,9 @@ async def websocket_device_automation_list_actions(hass, connection, msg):
) )
@websocket_api.async_response @websocket_api.async_response
@handle_device_errors @handle_device_errors
async def websocket_device_automation_list_conditions(hass, connection, msg): async def websocket_device_automation_list_conditions(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle request for device conditions.""" """Handle request for device conditions."""
device_id = msg["device_id"] device_id = msg["device_id"]
conditions = ( conditions = (
@ -371,7 +381,9 @@ async def websocket_device_automation_list_conditions(hass, connection, msg):
) )
@websocket_api.async_response @websocket_api.async_response
@handle_device_errors @handle_device_errors
async def websocket_device_automation_list_triggers(hass, connection, msg): async def websocket_device_automation_list_triggers(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle request for device triggers.""" """Handle request for device triggers."""
device_id = msg["device_id"] device_id = msg["device_id"]
triggers = ( triggers = (
@ -390,7 +402,9 @@ async def websocket_device_automation_list_triggers(hass, connection, msg):
) )
@websocket_api.async_response @websocket_api.async_response
@handle_device_errors @handle_device_errors
async def websocket_device_automation_get_action_capabilities(hass, connection, msg): async def websocket_device_automation_get_action_capabilities(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle request for device action capabilities.""" """Handle request for device action capabilities."""
action = msg["action"] action = msg["action"]
capabilities = await _async_get_device_automation_capabilities( capabilities = await _async_get_device_automation_capabilities(
@ -409,7 +423,9 @@ async def websocket_device_automation_get_action_capabilities(hass, connection,
) )
@websocket_api.async_response @websocket_api.async_response
@handle_device_errors @handle_device_errors
async def websocket_device_automation_get_condition_capabilities(hass, connection, msg): async def websocket_device_automation_get_condition_capabilities(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle request for device condition capabilities.""" """Handle request for device condition capabilities."""
condition = msg["condition"] condition = msg["condition"]
capabilities = await _async_get_device_automation_capabilities( capabilities = await _async_get_device_automation_capabilities(
@ -428,7 +444,9 @@ async def websocket_device_automation_get_condition_capabilities(hass, connectio
) )
@websocket_api.async_response @websocket_api.async_response
@handle_device_errors @handle_device_errors
async def websocket_device_automation_get_trigger_capabilities(hass, connection, msg): async def websocket_device_automation_get_trigger_capabilities(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle request for device trigger capabilities.""" """Handle request for device trigger capabilities."""
trigger = msg["trigger"] trigger = msg["trigger"]
capabilities = await _async_get_device_automation_capabilities( capabilities = await _async_get_device_automation_capabilities(