mirror of
https://github.com/home-assistant/core.git
synced 2025-07-18 18:57:06 +00:00
Improve type hints in script helpers (#78364)
* Improve type hints in script helpers * Import CONF_SERVICE_DATA from homeassistant.const * Make data optional
This commit is contained in:
parent
4f963cfc64
commit
d3be06906b
@ -2,7 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Mapping, Sequence
|
||||||
from contextlib import asynccontextmanager, suppress
|
from contextlib import asynccontextmanager, suppress
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from copy import copy
|
from copy import copy
|
||||||
@ -49,6 +49,7 @@ from homeassistant.const import (
|
|||||||
CONF_SCENE,
|
CONF_SCENE,
|
||||||
CONF_SEQUENCE,
|
CONF_SEQUENCE,
|
||||||
CONF_SERVICE,
|
CONF_SERVICE,
|
||||||
|
CONF_SERVICE_DATA,
|
||||||
CONF_STOP,
|
CONF_STOP,
|
||||||
CONF_TARGET,
|
CONF_TARGET,
|
||||||
CONF_THEN,
|
CONF_THEN,
|
||||||
@ -218,7 +219,9 @@ async def trace_action(hass, script_run, stop, variables):
|
|||||||
trace_stack_pop(trace_stack_cv)
|
trace_stack_pop(trace_stack_cv)
|
||||||
|
|
||||||
|
|
||||||
def make_script_schema(schema, default_script_mode, extra=vol.PREVENT_EXTRA):
|
def make_script_schema(
|
||||||
|
schema: Mapping[Any, Any], default_script_mode: str, extra: int = vol.PREVENT_EXTRA
|
||||||
|
) -> vol.Schema:
|
||||||
"""Make a schema for a component that uses the script helper."""
|
"""Make a schema for a component that uses the script helper."""
|
||||||
return vol.Schema(
|
return vol.Schema(
|
||||||
{
|
{
|
||||||
@ -1109,7 +1112,9 @@ async def _async_stop_scripts_at_shutdown(hass, event):
|
|||||||
_VarsType = Union[dict[str, Any], MappingProxyType]
|
_VarsType = Union[dict[str, Any], MappingProxyType]
|
||||||
|
|
||||||
|
|
||||||
def _referenced_extract_ids(data: dict[str, Any], key: str, found: set[str]) -> None:
|
def _referenced_extract_ids(
|
||||||
|
data: dict[str, Any] | None, key: str, found: set[str]
|
||||||
|
) -> None:
|
||||||
"""Extract referenced IDs."""
|
"""Extract referenced IDs."""
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
@ -1275,24 +1280,26 @@ class Script:
|
|||||||
return self.script_mode in (SCRIPT_MODE_PARALLEL, SCRIPT_MODE_QUEUED)
|
return self.script_mode in (SCRIPT_MODE_PARALLEL, SCRIPT_MODE_QUEUED)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def referenced_areas(self):
|
def referenced_areas(self) -> set[str]:
|
||||||
"""Return a set of referenced areas."""
|
"""Return a set of referenced areas."""
|
||||||
if self._referenced_areas is not None:
|
if self._referenced_areas is not None:
|
||||||
return self._referenced_areas
|
return self._referenced_areas
|
||||||
|
|
||||||
self._referenced_areas: set[str] = set()
|
self._referenced_areas = set()
|
||||||
Script._find_referenced_areas(self._referenced_areas, self.sequence)
|
Script._find_referenced_areas(self._referenced_areas, self.sequence)
|
||||||
return self._referenced_areas
|
return self._referenced_areas
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _find_referenced_areas(referenced, sequence):
|
def _find_referenced_areas(
|
||||||
|
referenced: set[str], sequence: Sequence[dict[str, Any]]
|
||||||
|
) -> None:
|
||||||
for step in sequence:
|
for step in sequence:
|
||||||
action = cv.determine_script_action(step)
|
action = cv.determine_script_action(step)
|
||||||
|
|
||||||
if action == cv.SCRIPT_ACTION_CALL_SERVICE:
|
if action == cv.SCRIPT_ACTION_CALL_SERVICE:
|
||||||
for data in (
|
for data in (
|
||||||
step.get(CONF_TARGET),
|
step.get(CONF_TARGET),
|
||||||
step.get(service.CONF_SERVICE_DATA),
|
step.get(CONF_SERVICE_DATA),
|
||||||
step.get(service.CONF_SERVICE_DATA_TEMPLATE),
|
step.get(service.CONF_SERVICE_DATA_TEMPLATE),
|
||||||
):
|
):
|
||||||
_referenced_extract_ids(data, ATTR_AREA_ID, referenced)
|
_referenced_extract_ids(data, ATTR_AREA_ID, referenced)
|
||||||
@ -1313,24 +1320,26 @@ class Script:
|
|||||||
Script._find_referenced_areas(referenced, script[CONF_SEQUENCE])
|
Script._find_referenced_areas(referenced, script[CONF_SEQUENCE])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def referenced_devices(self):
|
def referenced_devices(self) -> set[str]:
|
||||||
"""Return a set of referenced devices."""
|
"""Return a set of referenced devices."""
|
||||||
if self._referenced_devices is not None:
|
if self._referenced_devices is not None:
|
||||||
return self._referenced_devices
|
return self._referenced_devices
|
||||||
|
|
||||||
self._referenced_devices: set[str] = set()
|
self._referenced_devices = set()
|
||||||
Script._find_referenced_devices(self._referenced_devices, self.sequence)
|
Script._find_referenced_devices(self._referenced_devices, self.sequence)
|
||||||
return self._referenced_devices
|
return self._referenced_devices
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _find_referenced_devices(referenced, sequence):
|
def _find_referenced_devices(
|
||||||
|
referenced: set[str], sequence: Sequence[dict[str, Any]]
|
||||||
|
) -> None:
|
||||||
for step in sequence:
|
for step in sequence:
|
||||||
action = cv.determine_script_action(step)
|
action = cv.determine_script_action(step)
|
||||||
|
|
||||||
if action == cv.SCRIPT_ACTION_CALL_SERVICE:
|
if action == cv.SCRIPT_ACTION_CALL_SERVICE:
|
||||||
for data in (
|
for data in (
|
||||||
step.get(CONF_TARGET),
|
step.get(CONF_TARGET),
|
||||||
step.get(service.CONF_SERVICE_DATA),
|
step.get(CONF_SERVICE_DATA),
|
||||||
step.get(service.CONF_SERVICE_DATA_TEMPLATE),
|
step.get(service.CONF_SERVICE_DATA_TEMPLATE),
|
||||||
):
|
):
|
||||||
_referenced_extract_ids(data, ATTR_DEVICE_ID, referenced)
|
_referenced_extract_ids(data, ATTR_DEVICE_ID, referenced)
|
||||||
@ -1361,17 +1370,19 @@ class Script:
|
|||||||
Script._find_referenced_devices(referenced, script[CONF_SEQUENCE])
|
Script._find_referenced_devices(referenced, script[CONF_SEQUENCE])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def referenced_entities(self):
|
def referenced_entities(self) -> set[str]:
|
||||||
"""Return a set of referenced entities."""
|
"""Return a set of referenced entities."""
|
||||||
if self._referenced_entities is not None:
|
if self._referenced_entities is not None:
|
||||||
return self._referenced_entities
|
return self._referenced_entities
|
||||||
|
|
||||||
self._referenced_entities: set[str] = set()
|
self._referenced_entities = set()
|
||||||
Script._find_referenced_entities(self._referenced_entities, self.sequence)
|
Script._find_referenced_entities(self._referenced_entities, self.sequence)
|
||||||
return self._referenced_entities
|
return self._referenced_entities
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _find_referenced_entities(referenced, sequence):
|
def _find_referenced_entities(
|
||||||
|
referenced: set[str], sequence: Sequence[dict[str, Any]]
|
||||||
|
) -> None:
|
||||||
for step in sequence:
|
for step in sequence:
|
||||||
action = cv.determine_script_action(step)
|
action = cv.determine_script_action(step)
|
||||||
|
|
||||||
@ -1379,7 +1390,7 @@ class Script:
|
|||||||
for data in (
|
for data in (
|
||||||
step,
|
step,
|
||||||
step.get(CONF_TARGET),
|
step.get(CONF_TARGET),
|
||||||
step.get(service.CONF_SERVICE_DATA),
|
step.get(CONF_SERVICE_DATA),
|
||||||
step.get(service.CONF_SERVICE_DATA_TEMPLATE),
|
step.get(service.CONF_SERVICE_DATA_TEMPLATE),
|
||||||
):
|
):
|
||||||
_referenced_extract_ids(data, ATTR_ENTITY_ID, referenced)
|
_referenced_extract_ids(data, ATTR_ENTITY_ID, referenced)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user