mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 01:38:02 +00:00
Deduplicate async_register_entity_service (#124045)
This commit is contained in:
parent
738cc5095d
commit
69943af68a
@ -5,13 +5,11 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import Callable, Iterable
|
||||
from datetime import timedelta
|
||||
from functools import partial
|
||||
import logging
|
||||
from types import ModuleType
|
||||
from typing import Any, Generic
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config as conf_util
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
@ -265,39 +263,16 @@ class EntityComponent(Generic[_EntityT]):
|
||||
supports_response: SupportsResponse = SupportsResponse.NONE,
|
||||
) -> None:
|
||||
"""Register an entity service."""
|
||||
if schema is None or isinstance(schema, dict):
|
||||
schema = cv.make_entity_service_schema(schema)
|
||||
# Do a sanity check to check this is a valid entity service schema,
|
||||
# the check could be extended to require All/Any to have sub schema(s)
|
||||
# with all entity service fields
|
||||
elif (
|
||||
# Don't check All/Any
|
||||
not isinstance(schema, (vol.All, vol.Any))
|
||||
# Don't check All/Any wrapped in schema
|
||||
and not isinstance(schema.schema, (vol.All, vol.Any))
|
||||
and any(key not in schema.schema for key in cv.ENTITY_SERVICE_FIELDS)
|
||||
):
|
||||
raise HomeAssistantError(
|
||||
"The schema does not include all required keys: "
|
||||
f"{", ".join(str(key) for key in cv.ENTITY_SERVICE_FIELDS)}"
|
||||
)
|
||||
|
||||
service_func: str | HassJob[..., Any]
|
||||
service_func = func if isinstance(func, str) else HassJob(func)
|
||||
|
||||
self.hass.services.async_register(
|
||||
service.async_register_entity_service(
|
||||
self.hass,
|
||||
self.domain,
|
||||
name,
|
||||
partial(
|
||||
service.entity_service_call,
|
||||
self.hass,
|
||||
self._entities,
|
||||
service_func,
|
||||
required_features=required_features,
|
||||
),
|
||||
schema,
|
||||
supports_response,
|
||||
entities=self._entities,
|
||||
func=func,
|
||||
job_type=HassJobType.Coroutinefunction,
|
||||
required_features=required_features,
|
||||
schema=schema,
|
||||
supports_response=supports_response,
|
||||
)
|
||||
|
||||
async def async_setup_platform(
|
||||
|
@ -6,12 +6,9 @@ import asyncio
|
||||
from collections.abc import Awaitable, Callable, Coroutine, Iterable
|
||||
from contextvars import ContextVar
|
||||
from datetime import timedelta
|
||||
from functools import partial
|
||||
from logging import Logger, getLogger
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.const import (
|
||||
ATTR_RESTORED,
|
||||
@ -22,7 +19,6 @@ from homeassistant.core import (
|
||||
CALLBACK_TYPE,
|
||||
DOMAIN as HOMEASSISTANT_DOMAIN,
|
||||
CoreState,
|
||||
HassJob,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
SupportsResponse,
|
||||
@ -43,7 +39,6 @@ from homeassistant.util.async_ import create_eager_task
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
from . import (
|
||||
config_validation as cv,
|
||||
device_registry as dev_reg,
|
||||
entity_registry as ent_reg,
|
||||
service,
|
||||
@ -999,38 +994,16 @@ class EntityPlatform:
|
||||
if self.hass.services.has_service(self.platform_name, name):
|
||||
return
|
||||
|
||||
if schema is None or isinstance(schema, dict):
|
||||
schema = cv.make_entity_service_schema(schema)
|
||||
# Do a sanity check to check this is a valid entity service schema,
|
||||
# the check could be extended to require All/Any to have sub schema(s)
|
||||
# with all entity service fields
|
||||
elif (
|
||||
# Don't check All/Any
|
||||
not isinstance(schema, (vol.All, vol.Any))
|
||||
# Don't check All/Any wrapped in schema
|
||||
and not isinstance(schema.schema, (vol.All, vol.Any))
|
||||
and any(key not in schema.schema for key in cv.ENTITY_SERVICE_FIELDS)
|
||||
):
|
||||
raise HomeAssistantError(
|
||||
"The schema does not include all required keys: "
|
||||
f"{", ".join(str(key) for key in cv.ENTITY_SERVICE_FIELDS)}"
|
||||
)
|
||||
|
||||
service_func: str | HassJob[..., Any]
|
||||
service_func = func if isinstance(func, str) else HassJob(func)
|
||||
|
||||
self.hass.services.async_register(
|
||||
service.async_register_entity_service(
|
||||
self.hass,
|
||||
self.platform_name,
|
||||
name,
|
||||
partial(
|
||||
service.entity_service_call,
|
||||
self.hass,
|
||||
self.domain_platform_entities,
|
||||
service_func,
|
||||
required_features=required_features,
|
||||
),
|
||||
schema,
|
||||
supports_response,
|
||||
entities=self.domain_platform_entities,
|
||||
func=func,
|
||||
job_type=None,
|
||||
required_features=required_features,
|
||||
schema=schema,
|
||||
supports_response=supports_response,
|
||||
)
|
||||
|
||||
async def _async_update_entity_states(self) -> None:
|
||||
|
@ -33,6 +33,7 @@ from homeassistant.core import (
|
||||
Context,
|
||||
EntityServiceResponse,
|
||||
HassJob,
|
||||
HassJobType,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
ServiceResponse,
|
||||
@ -63,7 +64,7 @@ from . import (
|
||||
)
|
||||
from .group import expand_entity_ids
|
||||
from .selector import TargetSelector
|
||||
from .typing import ConfigType, TemplateVarsType, VolSchemaType
|
||||
from .typing import ConfigType, TemplateVarsType, VolDictType, VolSchemaType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .entity import Entity
|
||||
@ -1240,3 +1241,58 @@ class ReloadServiceHelper[_T]:
|
||||
self._service_running = False
|
||||
self._pending_reload_targets -= reload_targets
|
||||
self._service_condition.notify_all()
|
||||
|
||||
|
||||
@callback
|
||||
def async_register_entity_service(
|
||||
hass: HomeAssistant,
|
||||
domain: str,
|
||||
name: str,
|
||||
*,
|
||||
entities: dict[str, Entity],
|
||||
func: str | Callable[..., Any],
|
||||
job_type: HassJobType | None,
|
||||
required_features: Iterable[int] | None = None,
|
||||
schema: VolDictType | VolSchemaType | None,
|
||||
supports_response: SupportsResponse = SupportsResponse.NONE,
|
||||
) -> None:
|
||||
"""Help registering an entity service.
|
||||
|
||||
This is called by EntityComponent.async_register_entity_service and
|
||||
EntityPlatform.async_register_entity_service and should not be called
|
||||
directly by integrations.
|
||||
"""
|
||||
if schema is None or isinstance(schema, dict):
|
||||
schema = cv.make_entity_service_schema(schema)
|
||||
# Do a sanity check to check this is a valid entity service schema,
|
||||
# the check could be extended to require All/Any to have sub schema(s)
|
||||
# with all entity service fields
|
||||
elif (
|
||||
# Don't check All/Any
|
||||
not isinstance(schema, (vol.All, vol.Any))
|
||||
# Don't check All/Any wrapped in schema
|
||||
and not isinstance(schema.schema, (vol.All, vol.Any))
|
||||
and any(key not in schema.schema for key in cv.ENTITY_SERVICE_FIELDS)
|
||||
):
|
||||
raise HomeAssistantError(
|
||||
"The schema does not include all required keys: "
|
||||
f"{", ".join(str(key) for key in cv.ENTITY_SERVICE_FIELDS)}"
|
||||
)
|
||||
|
||||
service_func: str | HassJob[..., Any]
|
||||
service_func = func if isinstance(func, str) else HassJob(func)
|
||||
|
||||
hass.services.async_register(
|
||||
domain,
|
||||
name,
|
||||
partial(
|
||||
entity_service_call,
|
||||
hass,
|
||||
entities,
|
||||
service_func,
|
||||
required_features=required_features,
|
||||
),
|
||||
schema,
|
||||
supports_response,
|
||||
job_type=job_type,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user