Improve entity_platform helper typing (#75464)

* Improve entity_platform helper typing

* Add protocol class

* Apply suggestions from code review

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Marc Mueller 2022-07-20 05:45:57 +02:00 committed by GitHub
parent 0f81d1d14a
commit 8a48d54951
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 12 deletions

View File

@ -20,6 +20,7 @@ homeassistant.helpers.deprecation
homeassistant.helpers.discovery homeassistant.helpers.discovery
homeassistant.helpers.dispatcher homeassistant.helpers.dispatcher
homeassistant.helpers.entity homeassistant.helpers.entity
homeassistant.helpers.entity_platform
homeassistant.helpers.entity_values homeassistant.helpers.entity_values
homeassistant.helpers.event homeassistant.helpers.event
homeassistant.helpers.reload homeassistant.helpers.reload

View File

@ -2,11 +2,10 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Coroutine, Iterable from collections.abc import Awaitable, Callable, Coroutine, Iterable
from contextvars import ContextVar from contextvars import ContextVar
from datetime import datetime, timedelta from datetime import datetime, timedelta
from logging import Logger, getLogger from logging import Logger, getLogger
from types import ModuleType
from typing import TYPE_CHECKING, Any, Protocol from typing import TYPE_CHECKING, Any, Protocol
from urllib.parse import urlparse from urllib.parse import urlparse
@ -71,6 +70,36 @@ class AddEntitiesCallback(Protocol):
"""Define add_entities type.""" """Define add_entities type."""
class EntityPlatformModule(Protocol):
"""Protocol type for entity platform modules."""
async def async_setup_platform(
self,
hass: HomeAssistant,
config: ConfigType,
async_add_entities: AddEntitiesCallback,
discovery_info: DiscoveryInfoType | None = None,
) -> None:
"""Set up an integration platform async."""
def setup_platform(
self,
hass: HomeAssistant,
config: ConfigType,
add_entities: AddEntitiesCallback,
discovery_info: DiscoveryInfoType | None = None,
) -> None:
"""Set up an integration platform."""
async def async_setup_entry(
self,
hass: HomeAssistant,
entry: config_entries.ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up an integration platform from a config entry."""
class EntityPlatform: class EntityPlatform:
"""Manage the entities for a single platform.""" """Manage the entities for a single platform."""
@ -81,7 +110,7 @@ class EntityPlatform:
logger: Logger, logger: Logger,
domain: str, domain: str,
platform_name: str, platform_name: str,
platform: ModuleType | None, platform: EntityPlatformModule | None,
scan_interval: timedelta, scan_interval: timedelta,
entity_namespace: str | None, entity_namespace: str | None,
) -> None: ) -> None:
@ -95,7 +124,7 @@ class EntityPlatform:
self.entity_namespace = entity_namespace self.entity_namespace = entity_namespace
self.config_entry: config_entries.ConfigEntry | None = None self.config_entry: config_entries.ConfigEntry | None = None
self.entities: dict[str, Entity] = {} self.entities: dict[str, Entity] = {}
self._tasks: list[asyncio.Future] = [] self._tasks: list[asyncio.Task[None]] = []
# Stop tracking tasks after setup is completed # Stop tracking tasks after setup is completed
self._setup_complete = False self._setup_complete = False
# Method to cancel the state change listener # Method to cancel the state change listener
@ -169,10 +198,12 @@ class EntityPlatform:
return return
@callback @callback
def async_create_setup_task() -> Coroutine: def async_create_setup_task() -> Coroutine[
Any, Any, None
] | asyncio.Future[None]:
"""Get task to set up platform.""" """Get task to set up platform."""
if getattr(platform, "async_setup_platform", None): if getattr(platform, "async_setup_platform", None):
return platform.async_setup_platform( # type: ignore[no-any-return,union-attr] return platform.async_setup_platform( # type: ignore[union-attr]
hass, hass,
platform_config, platform_config,
self._async_schedule_add_entities, self._async_schedule_add_entities,
@ -181,7 +212,7 @@ class EntityPlatform:
# This should not be replaced with hass.async_add_job because # This should not be replaced with hass.async_add_job because
# we don't want to track this task in case it blocks startup. # we don't want to track this task in case it blocks startup.
return hass.loop.run_in_executor( # type: ignore[return-value] return hass.loop.run_in_executor(
None, None,
platform.setup_platform, # type: ignore[union-attr] platform.setup_platform, # type: ignore[union-attr]
hass, hass,
@ -211,18 +242,18 @@ class EntityPlatform:
platform = self.platform platform = self.platform
@callback @callback
def async_create_setup_task() -> Coroutine: def async_create_setup_task() -> Coroutine[Any, Any, None]:
"""Get task to set up platform.""" """Get task to set up platform."""
config_entries.current_entry.set(config_entry) config_entries.current_entry.set(config_entry)
return platform.async_setup_entry( # type: ignore[no-any-return,union-attr] return platform.async_setup_entry( # type: ignore[union-attr]
self.hass, config_entry, self._async_schedule_add_entities_for_entry self.hass, config_entry, self._async_schedule_add_entities_for_entry
) )
return await self._async_setup_platform(async_create_setup_task) return await self._async_setup_platform(async_create_setup_task)
async def _async_setup_platform( async def _async_setup_platform(
self, async_create_setup_task: Callable[[], Coroutine], tries: int = 0 self, async_create_setup_task: Callable[[], Awaitable[None]], tries: int = 0
) -> bool: ) -> bool:
"""Set up a platform via config file or config entry. """Set up a platform via config file or config entry.
@ -701,7 +732,7 @@ class EntityPlatform:
def async_register_entity_service( def async_register_entity_service(
self, self,
name: str, name: str,
schema: dict | vol.Schema, schema: dict[str, Any] | vol.Schema,
func: str | Callable[..., Any], func: str | Callable[..., Any],
required_features: Iterable[int] | None = None, required_features: Iterable[int] | None = None,
) -> None: ) -> None:
@ -753,7 +784,7 @@ class EntityPlatform:
return return
async with self._process_updates: async with self._process_updates:
tasks = [] tasks: list[Coroutine[Any, Any, None]] = []
for entity in self.entities.values(): for entity in self.entities.values():
if not entity.should_poll: if not entity.should_poll:
continue continue

View File

@ -72,6 +72,9 @@ disallow_any_generics = true
[mypy-homeassistant.helpers.entity] [mypy-homeassistant.helpers.entity]
disallow_any_generics = true disallow_any_generics = true
[mypy-homeassistant.helpers.entity_platform]
disallow_any_generics = true
[mypy-homeassistant.helpers.entity_values] [mypy-homeassistant.helpers.entity_values]
disallow_any_generics = true disallow_any_generics = true