Add typing to async_register_entity_service (#50015)

This commit is contained in:
Franck Nijhof 2021-05-03 14:22:38 +02:00 committed by GitHub
parent 1ad9f1d714
commit 378cee01b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 14 deletions

View File

@ -71,7 +71,7 @@ async def async_setup_entry(
platform = entity_platform.current_platform.get() platform = entity_platform.current_platform.get()
assert platform assert platform
platform.async_register_entity_service( # type: ignore platform.async_register_entity_service(
SERVICE_SET_LOCK_USERCODE, SERVICE_SET_LOCK_USERCODE,
{ {
vol.Required(ATTR_CODE_SLOT): vol.Coerce(int), vol.Required(ATTR_CODE_SLOT): vol.Coerce(int),
@ -80,7 +80,7 @@ async def async_setup_entry(
"async_set_lock_usercode", "async_set_lock_usercode",
) )
platform.async_register_entity_service( # type: ignore platform.async_register_entity_service(
SERVICE_CLEAR_LOCK_USERCODE, SERVICE_CLEAR_LOCK_USERCODE,
{ {
vol.Required(ATTR_CODE_SLOT): vol.Coerce(int), vol.Required(ATTR_CODE_SLOT): vol.Coerce(int),

View File

@ -862,17 +862,19 @@ ENTITY_SERVICE_FIELDS = {
def make_entity_service_schema( def make_entity_service_schema(
schema: dict, *, extra: int = vol.PREVENT_EXTRA schema: dict, *, extra: int = vol.PREVENT_EXTRA
) -> vol.All: ) -> vol.Schema:
"""Create an entity service schema.""" """Create an entity service schema."""
return vol.All( return vol.Schema(
vol.Schema( vol.All(
{ vol.Schema(
**schema, {
**ENTITY_SERVICE_FIELDS, **schema,
}, **ENTITY_SERVICE_FIELDS,
extra=extra, },
), extra=extra,
has_at_least_one_key(*ENTITY_SERVICE_FIELDS), ),
has_at_least_one_key(*ENTITY_SERVICE_FIELDS),
)
) )

View File

@ -8,9 +8,10 @@ from datetime import datetime, timedelta
import logging import logging
from logging import Logger from logging import Logger
from types import ModuleType from types import ModuleType
from typing import TYPE_CHECKING, Callable from typing import TYPE_CHECKING, Any, Callable
from typing_extensions import Protocol from typing_extensions import Protocol
import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.const import ( from homeassistant.const import (
@ -625,7 +626,13 @@ class EntityPlatform:
) )
@callback @callback
def async_register_entity_service(self, name, schema, func, required_features=None): # type: ignore[no-untyped-def] def async_register_entity_service(
self,
name: str,
schema: dict | vol.Schema,
func: str | Callable[..., Any],
required_features: Iterable[int] | None = None,
) -> None:
"""Register an entity service. """Register an entity service.
Services will automatically be shared by all platforms of the same domain. Services will automatically be shared by all platforms of the same domain.