Make sure entity platform services work for all platforms of d… (#33176)

* Make sure entity platform services work for all platforms of domain

* Register a bad service handler

* Fix cleaning up

* Tiny cleanup
This commit is contained in:
Paulus Schoutsen 2020-03-23 12:59:36 -07:00 committed by GitHub
parent 2360fd4141
commit 1ff245d9c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 69 additions and 9 deletions

View File

@ -254,7 +254,13 @@ class EntityComponent:
This method must be run in the event loop.
"""
tasks = [platform.async_reset() for platform in self._platforms.values()]
tasks = []
for key, platform in self._platforms.items():
if key == self.domain:
tasks.append(platform.async_reset())
else:
tasks.append(platform.async_destroy())
if tasks:
await asyncio.wait(tasks)

View File

@ -24,6 +24,7 @@ if TYPE_CHECKING:
SLOW_SETUP_WARNING = 10
SLOW_SETUP_MAX_WAIT = 60
PLATFORM_NOT_READY_RETRIES = 10
DATA_ENTITY_PLATFORM = "entity_platform"
class EntityPlatform:
@ -57,15 +58,15 @@ class EntityPlatform:
self._async_cancel_retry_setup: Optional[CALLBACK_TYPE] = None
self._process_updates: Optional[asyncio.Lock] = None
self.parallel_updates: Optional[asyncio.Semaphore] = None
# Platform is None for the EntityComponent "catch-all" EntityPlatform
# which powers entity_component.add_entities
if platform is None:
self.parallel_updates_created = True
self.parallel_updates: Optional[asyncio.Semaphore] = None
return
self.parallel_updates_created = platform is None
self.parallel_updates_created = False
self.parallel_updates = None
hass.data.setdefault(DATA_ENTITY_PLATFORM, {}).setdefault(
self.platform_name, []
).append(self)
@callback
def _get_parallel_updates_semaphore(
@ -464,6 +465,14 @@ class EntityPlatform:
self._async_unsub_polling()
self._async_unsub_polling = None
async def async_destroy(self) -> None:
"""Destroy an entity platform.
Call before discarding the object.
"""
await self.async_reset()
self.hass.data[DATA_ENTITY_PLATFORM][self.platform_name].remove(self)
async def async_remove_entity(self, entity_id: str) -> None:
"""Remove entity id from platform."""
await self.entities[entity_id].async_remove()
@ -488,14 +497,24 @@ class EntityPlatform:
@callback
def async_register_entity_service(self, name, schema, func, required_features=None):
"""Register an entity service."""
"""Register an entity service.
Services will automatically be shared by all platforms of the same domain.
"""
if self.hass.services.has_service(self.platform_name, name):
return
if isinstance(schema, dict):
schema = cv.make_entity_service_schema(schema)
async def handle_service(call):
"""Handle the service."""
await service.entity_service_call(
self.hass, [self], func, call, required_features
self.hass,
self.hass.data[DATA_ENTITY_PLATFORM][self.platform_name],
func,
call,
required_features,
)
self.hass.services.async_register(

View File

@ -8,6 +8,7 @@ import asynctest
import pytest
from homeassistant.const import UNIT_PERCENTAGE
from homeassistant.core import callback
from homeassistant.exceptions import PlatformNotReady
from homeassistant.helpers import entity_platform, entity_registry
from homeassistant.helpers.entity import async_generate_entity_id
@ -847,3 +848,37 @@ async def test_platform_with_no_setup(hass, caplog):
"The mock-platform platform for the mock-integration integration does not support platform setup."
in caplog.text
)
async def test_platforms_sharing_services(hass):
"""Test platforms share services."""
entity_platform1 = MockEntityPlatform(
hass, domain="mock_integration", platform_name="mock_platform", platform=None
)
entity1 = MockEntity(entity_id="mock_integration.entity_1")
await entity_platform1.async_add_entities([entity1])
entity_platform2 = MockEntityPlatform(
hass, domain="mock_integration", platform_name="mock_platform", platform=None
)
entity2 = MockEntity(entity_id="mock_integration.entity_2")
await entity_platform2.async_add_entities([entity2])
entities = []
@callback
def handle_service(entity, data):
entities.append(entity)
entity_platform1.async_register_entity_service("hello", {}, handle_service)
entity_platform2.async_register_entity_service(
"hello", {}, Mock(side_effect=AssertionError("Should not be called"))
)
await hass.services.async_call(
"mock_platform", "hello", {"entity_id": "all"}, blocking=True
)
assert len(entities) == 2
assert entity1 in entities
assert entity2 in entities