From 1ff245d9c2d233466bde5dcba84ade3f6cb9d4c0 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 23 Mar 2020 12:59:36 -0700 Subject: [PATCH] =?UTF-8?q?Make=20sure=20entity=20platform=20services=20wo?= =?UTF-8?q?rk=20for=20all=20platforms=20of=20d=E2=80=A6=20(#33176)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Make sure entity platform services work for all platforms of domain * Register a bad service handler * Fix cleaning up * Tiny cleanup --- homeassistant/helpers/entity_component.py | 8 +++++- homeassistant/helpers/entity_platform.py | 35 +++++++++++++++++------ tests/helpers/test_entity_platform.py | 35 +++++++++++++++++++++++ 3 files changed, 69 insertions(+), 9 deletions(-) diff --git a/homeassistant/helpers/entity_component.py b/homeassistant/helpers/entity_component.py index 71c57dc13f1..a761273fd25 100644 --- a/homeassistant/helpers/entity_component.py +++ b/homeassistant/helpers/entity_component.py @@ -254,7 +254,13 @@ class EntityComponent: This method must be run in the event loop. """ - tasks = [platform.async_reset() for platform in self._platforms.values()] + tasks = [] + + for key, platform in self._platforms.items(): + if key == self.domain: + tasks.append(platform.async_reset()) + else: + tasks.append(platform.async_destroy()) if tasks: await asyncio.wait(tasks) diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 0c288b0ad21..0aebaff14de 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -24,6 +24,7 @@ if TYPE_CHECKING: SLOW_SETUP_WARNING = 10 SLOW_SETUP_MAX_WAIT = 60 PLATFORM_NOT_READY_RETRIES = 10 +DATA_ENTITY_PLATFORM = "entity_platform" class EntityPlatform: @@ -57,15 +58,15 @@ class EntityPlatform: self._async_cancel_retry_setup: Optional[CALLBACK_TYPE] = None self._process_updates: Optional[asyncio.Lock] = None + self.parallel_updates: Optional[asyncio.Semaphore] = None + # Platform is None for the EntityComponent "catch-all" EntityPlatform # which powers entity_component.add_entities - if platform is None: - self.parallel_updates_created = True - self.parallel_updates: Optional[asyncio.Semaphore] = None - return + self.parallel_updates_created = platform is None - self.parallel_updates_created = False - self.parallel_updates = None + hass.data.setdefault(DATA_ENTITY_PLATFORM, {}).setdefault( + self.platform_name, [] + ).append(self) @callback def _get_parallel_updates_semaphore( @@ -464,6 +465,14 @@ class EntityPlatform: self._async_unsub_polling() self._async_unsub_polling = None + async def async_destroy(self) -> None: + """Destroy an entity platform. + + Call before discarding the object. + """ + await self.async_reset() + self.hass.data[DATA_ENTITY_PLATFORM][self.platform_name].remove(self) + async def async_remove_entity(self, entity_id: str) -> None: """Remove entity id from platform.""" await self.entities[entity_id].async_remove() @@ -488,14 +497,24 @@ class EntityPlatform: @callback def async_register_entity_service(self, name, schema, func, required_features=None): - """Register an entity service.""" + """Register an entity service. + + Services will automatically be shared by all platforms of the same domain. + """ + if self.hass.services.has_service(self.platform_name, name): + return + if isinstance(schema, dict): schema = cv.make_entity_service_schema(schema) async def handle_service(call): """Handle the service.""" await service.entity_service_call( - self.hass, [self], func, call, required_features + self.hass, + self.hass.data[DATA_ENTITY_PLATFORM][self.platform_name], + func, + call, + required_features, ) self.hass.services.async_register( diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index d9cbbb31561..199284c680b 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -8,6 +8,7 @@ import asynctest import pytest from homeassistant.const import UNIT_PERCENTAGE +from homeassistant.core import callback from homeassistant.exceptions import PlatformNotReady from homeassistant.helpers import entity_platform, entity_registry from homeassistant.helpers.entity import async_generate_entity_id @@ -847,3 +848,37 @@ async def test_platform_with_no_setup(hass, caplog): "The mock-platform platform for the mock-integration integration does not support platform setup." in caplog.text ) + + +async def test_platforms_sharing_services(hass): + """Test platforms share services.""" + entity_platform1 = MockEntityPlatform( + hass, domain="mock_integration", platform_name="mock_platform", platform=None + ) + entity1 = MockEntity(entity_id="mock_integration.entity_1") + await entity_platform1.async_add_entities([entity1]) + + entity_platform2 = MockEntityPlatform( + hass, domain="mock_integration", platform_name="mock_platform", platform=None + ) + entity2 = MockEntity(entity_id="mock_integration.entity_2") + await entity_platform2.async_add_entities([entity2]) + + entities = [] + + @callback + def handle_service(entity, data): + entities.append(entity) + + entity_platform1.async_register_entity_service("hello", {}, handle_service) + entity_platform2.async_register_entity_service( + "hello", {}, Mock(side_effect=AssertionError("Should not be called")) + ) + + await hass.services.async_call( + "mock_platform", "hello", {"entity_id": "all"}, blocking=True + ) + + assert len(entities) == 2 + assert entity1 in entities + assert entity2 in entities