Improve type hints in script helpers (#78364)

* Improve type hints in script helpers

* Import CONF_SERVICE_DATA from homeassistant.const

* Make data optional
This commit is contained in:
epenet 2022-09-13 23:11:29 +02:00 committed by GitHub
parent 4f963cfc64
commit d3be06906b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Sequence from collections.abc import Callable, Mapping, Sequence
from contextlib import asynccontextmanager, suppress from contextlib import asynccontextmanager, suppress
from contextvars import ContextVar from contextvars import ContextVar
from copy import copy from copy import copy
@ -49,6 +49,7 @@ from homeassistant.const import (
CONF_SCENE, CONF_SCENE,
CONF_SEQUENCE, CONF_SEQUENCE,
CONF_SERVICE, CONF_SERVICE,
CONF_SERVICE_DATA,
CONF_STOP, CONF_STOP,
CONF_TARGET, CONF_TARGET,
CONF_THEN, CONF_THEN,
@ -218,7 +219,9 @@ async def trace_action(hass, script_run, stop, variables):
trace_stack_pop(trace_stack_cv) trace_stack_pop(trace_stack_cv)
def make_script_schema(schema, default_script_mode, extra=vol.PREVENT_EXTRA): def make_script_schema(
schema: Mapping[Any, Any], default_script_mode: str, extra: int = vol.PREVENT_EXTRA
) -> vol.Schema:
"""Make a schema for a component that uses the script helper.""" """Make a schema for a component that uses the script helper."""
return vol.Schema( return vol.Schema(
{ {
@ -1109,7 +1112,9 @@ async def _async_stop_scripts_at_shutdown(hass, event):
_VarsType = Union[dict[str, Any], MappingProxyType] _VarsType = Union[dict[str, Any], MappingProxyType]
def _referenced_extract_ids(data: dict[str, Any], key: str, found: set[str]) -> None: def _referenced_extract_ids(
data: dict[str, Any] | None, key: str, found: set[str]
) -> None:
"""Extract referenced IDs.""" """Extract referenced IDs."""
if not data: if not data:
return return
@ -1275,24 +1280,26 @@ class Script:
return self.script_mode in (SCRIPT_MODE_PARALLEL, SCRIPT_MODE_QUEUED) return self.script_mode in (SCRIPT_MODE_PARALLEL, SCRIPT_MODE_QUEUED)
@property @property
def referenced_areas(self): def referenced_areas(self) -> set[str]:
"""Return a set of referenced areas.""" """Return a set of referenced areas."""
if self._referenced_areas is not None: if self._referenced_areas is not None:
return self._referenced_areas return self._referenced_areas
self._referenced_areas: set[str] = set() self._referenced_areas = set()
Script._find_referenced_areas(self._referenced_areas, self.sequence) Script._find_referenced_areas(self._referenced_areas, self.sequence)
return self._referenced_areas return self._referenced_areas
@staticmethod @staticmethod
def _find_referenced_areas(referenced, sequence): def _find_referenced_areas(
referenced: set[str], sequence: Sequence[dict[str, Any]]
) -> None:
for step in sequence: for step in sequence:
action = cv.determine_script_action(step) action = cv.determine_script_action(step)
if action == cv.SCRIPT_ACTION_CALL_SERVICE: if action == cv.SCRIPT_ACTION_CALL_SERVICE:
for data in ( for data in (
step.get(CONF_TARGET), step.get(CONF_TARGET),
step.get(service.CONF_SERVICE_DATA), step.get(CONF_SERVICE_DATA),
step.get(service.CONF_SERVICE_DATA_TEMPLATE), step.get(service.CONF_SERVICE_DATA_TEMPLATE),
): ):
_referenced_extract_ids(data, ATTR_AREA_ID, referenced) _referenced_extract_ids(data, ATTR_AREA_ID, referenced)
@ -1313,24 +1320,26 @@ class Script:
Script._find_referenced_areas(referenced, script[CONF_SEQUENCE]) Script._find_referenced_areas(referenced, script[CONF_SEQUENCE])
@property @property
def referenced_devices(self): def referenced_devices(self) -> set[str]:
"""Return a set of referenced devices.""" """Return a set of referenced devices."""
if self._referenced_devices is not None: if self._referenced_devices is not None:
return self._referenced_devices return self._referenced_devices
self._referenced_devices: set[str] = set() self._referenced_devices = set()
Script._find_referenced_devices(self._referenced_devices, self.sequence) Script._find_referenced_devices(self._referenced_devices, self.sequence)
return self._referenced_devices return self._referenced_devices
@staticmethod @staticmethod
def _find_referenced_devices(referenced, sequence): def _find_referenced_devices(
referenced: set[str], sequence: Sequence[dict[str, Any]]
) -> None:
for step in sequence: for step in sequence:
action = cv.determine_script_action(step) action = cv.determine_script_action(step)
if action == cv.SCRIPT_ACTION_CALL_SERVICE: if action == cv.SCRIPT_ACTION_CALL_SERVICE:
for data in ( for data in (
step.get(CONF_TARGET), step.get(CONF_TARGET),
step.get(service.CONF_SERVICE_DATA), step.get(CONF_SERVICE_DATA),
step.get(service.CONF_SERVICE_DATA_TEMPLATE), step.get(service.CONF_SERVICE_DATA_TEMPLATE),
): ):
_referenced_extract_ids(data, ATTR_DEVICE_ID, referenced) _referenced_extract_ids(data, ATTR_DEVICE_ID, referenced)
@ -1361,17 +1370,19 @@ class Script:
Script._find_referenced_devices(referenced, script[CONF_SEQUENCE]) Script._find_referenced_devices(referenced, script[CONF_SEQUENCE])
@property @property
def referenced_entities(self): def referenced_entities(self) -> set[str]:
"""Return a set of referenced entities.""" """Return a set of referenced entities."""
if self._referenced_entities is not None: if self._referenced_entities is not None:
return self._referenced_entities return self._referenced_entities
self._referenced_entities: set[str] = set() self._referenced_entities = set()
Script._find_referenced_entities(self._referenced_entities, self.sequence) Script._find_referenced_entities(self._referenced_entities, self.sequence)
return self._referenced_entities return self._referenced_entities
@staticmethod @staticmethod
def _find_referenced_entities(referenced, sequence): def _find_referenced_entities(
referenced: set[str], sequence: Sequence[dict[str, Any]]
) -> None:
for step in sequence: for step in sequence:
action = cv.determine_script_action(step) action = cv.determine_script_action(step)
@ -1379,7 +1390,7 @@ class Script:
for data in ( for data in (
step, step,
step.get(CONF_TARGET), step.get(CONF_TARGET),
step.get(service.CONF_SERVICE_DATA), step.get(CONF_SERVICE_DATA),
step.get(service.CONF_SERVICE_DATA_TEMPLATE), step.get(service.CONF_SERVICE_DATA_TEMPLATE),
): ):
_referenced_extract_ids(data, ATTR_ENTITY_ID, referenced) _referenced_extract_ids(data, ATTR_ENTITY_ID, referenced)