Deduplicate async_register_entity_service (#124045)

This commit is contained in:
Erik Montnemery 2024-08-16 14:06:35 +02:00 committed by GitHub
parent 738cc5095d
commit 69943af68a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 72 additions and 68 deletions

View File

@ -5,13 +5,11 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Iterable
from datetime import timedelta
from functools import partial
import logging
from types import ModuleType
from typing import Any, Generic
from typing_extensions import TypeVar
import voluptuous as vol
from homeassistant import config as conf_util
from homeassistant.config_entries import ConfigEntry
@ -265,39 +263,16 @@ class EntityComponent(Generic[_EntityT]):
supports_response: SupportsResponse = SupportsResponse.NONE,
) -> None:
"""Register an entity service."""
if schema is None or isinstance(schema, dict):
schema = cv.make_entity_service_schema(schema)
# Do a sanity check to check this is a valid entity service schema,
# the check could be extended to require All/Any to have sub schema(s)
# with all entity service fields
elif (
# Don't check All/Any
not isinstance(schema, (vol.All, vol.Any))
# Don't check All/Any wrapped in schema
and not isinstance(schema.schema, (vol.All, vol.Any))
and any(key not in schema.schema for key in cv.ENTITY_SERVICE_FIELDS)
):
raise HomeAssistantError(
"The schema does not include all required keys: "
f"{", ".join(str(key) for key in cv.ENTITY_SERVICE_FIELDS)}"
)
service_func: str | HassJob[..., Any]
service_func = func if isinstance(func, str) else HassJob(func)
self.hass.services.async_register(
service.async_register_entity_service(
self.hass,
self.domain,
name,
partial(
service.entity_service_call,
self.hass,
self._entities,
service_func,
required_features=required_features,
),
schema,
supports_response,
entities=self._entities,
func=func,
job_type=HassJobType.Coroutinefunction,
required_features=required_features,
schema=schema,
supports_response=supports_response,
)
async def async_setup_platform(

View File

@ -6,12 +6,9 @@ import asyncio
from collections.abc import Awaitable, Callable, Coroutine, Iterable
from contextvars import ContextVar
from datetime import timedelta
from functools import partial
from logging import Logger, getLogger
from typing import TYPE_CHECKING, Any, Protocol
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.const import (
ATTR_RESTORED,
@ -22,7 +19,6 @@ from homeassistant.core import (
CALLBACK_TYPE,
DOMAIN as HOMEASSISTANT_DOMAIN,
CoreState,
HassJob,
HomeAssistant,
ServiceCall,
SupportsResponse,
@ -43,7 +39,6 @@ from homeassistant.util.async_ import create_eager_task
from homeassistant.util.hass_dict import HassKey
from . import (
config_validation as cv,
device_registry as dev_reg,
entity_registry as ent_reg,
service,
@ -999,38 +994,16 @@ class EntityPlatform:
if self.hass.services.has_service(self.platform_name, name):
return
if schema is None or isinstance(schema, dict):
schema = cv.make_entity_service_schema(schema)
# Do a sanity check to check this is a valid entity service schema,
# the check could be extended to require All/Any to have sub schema(s)
# with all entity service fields
elif (
# Don't check All/Any
not isinstance(schema, (vol.All, vol.Any))
# Don't check All/Any wrapped in schema
and not isinstance(schema.schema, (vol.All, vol.Any))
and any(key not in schema.schema for key in cv.ENTITY_SERVICE_FIELDS)
):
raise HomeAssistantError(
"The schema does not include all required keys: "
f"{", ".join(str(key) for key in cv.ENTITY_SERVICE_FIELDS)}"
)
service_func: str | HassJob[..., Any]
service_func = func if isinstance(func, str) else HassJob(func)
self.hass.services.async_register(
service.async_register_entity_service(
self.hass,
self.platform_name,
name,
partial(
service.entity_service_call,
self.hass,
self.domain_platform_entities,
service_func,
required_features=required_features,
),
schema,
supports_response,
entities=self.domain_platform_entities,
func=func,
job_type=None,
required_features=required_features,
schema=schema,
supports_response=supports_response,
)
async def _async_update_entity_states(self) -> None:

View File

@ -33,6 +33,7 @@ from homeassistant.core import (
Context,
EntityServiceResponse,
HassJob,
HassJobType,
HomeAssistant,
ServiceCall,
ServiceResponse,
@ -63,7 +64,7 @@ from . import (
)
from .group import expand_entity_ids
from .selector import TargetSelector
from .typing import ConfigType, TemplateVarsType, VolSchemaType
from .typing import ConfigType, TemplateVarsType, VolDictType, VolSchemaType
if TYPE_CHECKING:
from .entity import Entity
@ -1240,3 +1241,58 @@ class ReloadServiceHelper[_T]:
self._service_running = False
self._pending_reload_targets -= reload_targets
self._service_condition.notify_all()
@callback
def async_register_entity_service(
hass: HomeAssistant,
domain: str,
name: str,
*,
entities: dict[str, Entity],
func: str | Callable[..., Any],
job_type: HassJobType | None,
required_features: Iterable[int] | None = None,
schema: VolDictType | VolSchemaType | None,
supports_response: SupportsResponse = SupportsResponse.NONE,
) -> None:
"""Help registering an entity service.
This is called by EntityComponent.async_register_entity_service and
EntityPlatform.async_register_entity_service and should not be called
directly by integrations.
"""
if schema is None or isinstance(schema, dict):
schema = cv.make_entity_service_schema(schema)
# Do a sanity check to check this is a valid entity service schema,
# the check could be extended to require All/Any to have sub schema(s)
# with all entity service fields
elif (
# Don't check All/Any
not isinstance(schema, (vol.All, vol.Any))
# Don't check All/Any wrapped in schema
and not isinstance(schema.schema, (vol.All, vol.Any))
and any(key not in schema.schema for key in cv.ENTITY_SERVICE_FIELDS)
):
raise HomeAssistantError(
"The schema does not include all required keys: "
f"{", ".join(str(key) for key in cv.ENTITY_SERVICE_FIELDS)}"
)
service_func: str | HassJob[..., Any]
service_func = func if isinstance(func, str) else HassJob(func)
hass.services.async_register(
domain,
name,
partial(
entity_service_call,
hass,
entities,
service_func,
required_features=required_features,
),
schema,
supports_response,
job_type=job_type,
)