mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 09:47:52 +00:00
Make sure entity platform services work for all platforms of d… (#33176)
* Make sure entity platform services work for all platforms of domain * Register a bad service handler * Fix cleaning up * Tiny cleanup
This commit is contained in:
parent
2360fd4141
commit
1ff245d9c2
@ -254,7 +254,13 @@ class EntityComponent:
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
tasks = [platform.async_reset() for platform in self._platforms.values()]
|
||||
tasks = []
|
||||
|
||||
for key, platform in self._platforms.items():
|
||||
if key == self.domain:
|
||||
tasks.append(platform.async_reset())
|
||||
else:
|
||||
tasks.append(platform.async_destroy())
|
||||
|
||||
if tasks:
|
||||
await asyncio.wait(tasks)
|
||||
|
@ -24,6 +24,7 @@ if TYPE_CHECKING:
|
||||
SLOW_SETUP_WARNING = 10
|
||||
SLOW_SETUP_MAX_WAIT = 60
|
||||
PLATFORM_NOT_READY_RETRIES = 10
|
||||
DATA_ENTITY_PLATFORM = "entity_platform"
|
||||
|
||||
|
||||
class EntityPlatform:
|
||||
@ -57,15 +58,15 @@ class EntityPlatform:
|
||||
self._async_cancel_retry_setup: Optional[CALLBACK_TYPE] = None
|
||||
self._process_updates: Optional[asyncio.Lock] = None
|
||||
|
||||
self.parallel_updates: Optional[asyncio.Semaphore] = None
|
||||
|
||||
# Platform is None for the EntityComponent "catch-all" EntityPlatform
|
||||
# which powers entity_component.add_entities
|
||||
if platform is None:
|
||||
self.parallel_updates_created = True
|
||||
self.parallel_updates: Optional[asyncio.Semaphore] = None
|
||||
return
|
||||
self.parallel_updates_created = platform is None
|
||||
|
||||
self.parallel_updates_created = False
|
||||
self.parallel_updates = None
|
||||
hass.data.setdefault(DATA_ENTITY_PLATFORM, {}).setdefault(
|
||||
self.platform_name, []
|
||||
).append(self)
|
||||
|
||||
@callback
|
||||
def _get_parallel_updates_semaphore(
|
||||
@ -464,6 +465,14 @@ class EntityPlatform:
|
||||
self._async_unsub_polling()
|
||||
self._async_unsub_polling = None
|
||||
|
||||
async def async_destroy(self) -> None:
|
||||
"""Destroy an entity platform.
|
||||
|
||||
Call before discarding the object.
|
||||
"""
|
||||
await self.async_reset()
|
||||
self.hass.data[DATA_ENTITY_PLATFORM][self.platform_name].remove(self)
|
||||
|
||||
async def async_remove_entity(self, entity_id: str) -> None:
|
||||
"""Remove entity id from platform."""
|
||||
await self.entities[entity_id].async_remove()
|
||||
@ -488,14 +497,24 @@ class EntityPlatform:
|
||||
|
||||
@callback
|
||||
def async_register_entity_service(self, name, schema, func, required_features=None):
|
||||
"""Register an entity service."""
|
||||
"""Register an entity service.
|
||||
|
||||
Services will automatically be shared by all platforms of the same domain.
|
||||
"""
|
||||
if self.hass.services.has_service(self.platform_name, name):
|
||||
return
|
||||
|
||||
if isinstance(schema, dict):
|
||||
schema = cv.make_entity_service_schema(schema)
|
||||
|
||||
async def handle_service(call):
|
||||
"""Handle the service."""
|
||||
await service.entity_service_call(
|
||||
self.hass, [self], func, call, required_features
|
||||
self.hass,
|
||||
self.hass.data[DATA_ENTITY_PLATFORM][self.platform_name],
|
||||
func,
|
||||
call,
|
||||
required_features,
|
||||
)
|
||||
|
||||
self.hass.services.async_register(
|
||||
|
@ -8,6 +8,7 @@ import asynctest
|
||||
import pytest
|
||||
|
||||
from homeassistant.const import UNIT_PERCENTAGE
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.exceptions import PlatformNotReady
|
||||
from homeassistant.helpers import entity_platform, entity_registry
|
||||
from homeassistant.helpers.entity import async_generate_entity_id
|
||||
@ -847,3 +848,37 @@ async def test_platform_with_no_setup(hass, caplog):
|
||||
"The mock-platform platform for the mock-integration integration does not support platform setup."
|
||||
in caplog.text
|
||||
)
|
||||
|
||||
|
||||
async def test_platforms_sharing_services(hass):
|
||||
"""Test platforms share services."""
|
||||
entity_platform1 = MockEntityPlatform(
|
||||
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
||||
)
|
||||
entity1 = MockEntity(entity_id="mock_integration.entity_1")
|
||||
await entity_platform1.async_add_entities([entity1])
|
||||
|
||||
entity_platform2 = MockEntityPlatform(
|
||||
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
||||
)
|
||||
entity2 = MockEntity(entity_id="mock_integration.entity_2")
|
||||
await entity_platform2.async_add_entities([entity2])
|
||||
|
||||
entities = []
|
||||
|
||||
@callback
|
||||
def handle_service(entity, data):
|
||||
entities.append(entity)
|
||||
|
||||
entity_platform1.async_register_entity_service("hello", {}, handle_service)
|
||||
entity_platform2.async_register_entity_service(
|
||||
"hello", {}, Mock(side_effect=AssertionError("Should not be called"))
|
||||
)
|
||||
|
||||
await hass.services.async_call(
|
||||
"mock_platform", "hello", {"entity_id": "all"}, blocking=True
|
||||
)
|
||||
|
||||
assert len(entities) == 2
|
||||
assert entity1 in entities
|
||||
assert entity2 in entities
|
||||
|
Loading…
x
Reference in New Issue
Block a user