mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 03:07:37 +00:00
Make sure entity platform services work for all platforms of d… (#33176)
* Make sure entity platform services work for all platforms of domain * Register a bad service handler * Fix cleaning up * Tiny cleanup
This commit is contained in:
parent
2360fd4141
commit
1ff245d9c2
@ -254,7 +254,13 @@ class EntityComponent:
|
|||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
"""
|
"""
|
||||||
tasks = [platform.async_reset() for platform in self._platforms.values()]
|
tasks = []
|
||||||
|
|
||||||
|
for key, platform in self._platforms.items():
|
||||||
|
if key == self.domain:
|
||||||
|
tasks.append(platform.async_reset())
|
||||||
|
else:
|
||||||
|
tasks.append(platform.async_destroy())
|
||||||
|
|
||||||
if tasks:
|
if tasks:
|
||||||
await asyncio.wait(tasks)
|
await asyncio.wait(tasks)
|
||||||
|
@ -24,6 +24,7 @@ if TYPE_CHECKING:
|
|||||||
SLOW_SETUP_WARNING = 10
|
SLOW_SETUP_WARNING = 10
|
||||||
SLOW_SETUP_MAX_WAIT = 60
|
SLOW_SETUP_MAX_WAIT = 60
|
||||||
PLATFORM_NOT_READY_RETRIES = 10
|
PLATFORM_NOT_READY_RETRIES = 10
|
||||||
|
DATA_ENTITY_PLATFORM = "entity_platform"
|
||||||
|
|
||||||
|
|
||||||
class EntityPlatform:
|
class EntityPlatform:
|
||||||
@ -57,15 +58,15 @@ class EntityPlatform:
|
|||||||
self._async_cancel_retry_setup: Optional[CALLBACK_TYPE] = None
|
self._async_cancel_retry_setup: Optional[CALLBACK_TYPE] = None
|
||||||
self._process_updates: Optional[asyncio.Lock] = None
|
self._process_updates: Optional[asyncio.Lock] = None
|
||||||
|
|
||||||
|
self.parallel_updates: Optional[asyncio.Semaphore] = None
|
||||||
|
|
||||||
# Platform is None for the EntityComponent "catch-all" EntityPlatform
|
# Platform is None for the EntityComponent "catch-all" EntityPlatform
|
||||||
# which powers entity_component.add_entities
|
# which powers entity_component.add_entities
|
||||||
if platform is None:
|
self.parallel_updates_created = platform is None
|
||||||
self.parallel_updates_created = True
|
|
||||||
self.parallel_updates: Optional[asyncio.Semaphore] = None
|
|
||||||
return
|
|
||||||
|
|
||||||
self.parallel_updates_created = False
|
hass.data.setdefault(DATA_ENTITY_PLATFORM, {}).setdefault(
|
||||||
self.parallel_updates = None
|
self.platform_name, []
|
||||||
|
).append(self)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _get_parallel_updates_semaphore(
|
def _get_parallel_updates_semaphore(
|
||||||
@ -464,6 +465,14 @@ class EntityPlatform:
|
|||||||
self._async_unsub_polling()
|
self._async_unsub_polling()
|
||||||
self._async_unsub_polling = None
|
self._async_unsub_polling = None
|
||||||
|
|
||||||
|
async def async_destroy(self) -> None:
|
||||||
|
"""Destroy an entity platform.
|
||||||
|
|
||||||
|
Call before discarding the object.
|
||||||
|
"""
|
||||||
|
await self.async_reset()
|
||||||
|
self.hass.data[DATA_ENTITY_PLATFORM][self.platform_name].remove(self)
|
||||||
|
|
||||||
async def async_remove_entity(self, entity_id: str) -> None:
|
async def async_remove_entity(self, entity_id: str) -> None:
|
||||||
"""Remove entity id from platform."""
|
"""Remove entity id from platform."""
|
||||||
await self.entities[entity_id].async_remove()
|
await self.entities[entity_id].async_remove()
|
||||||
@ -488,14 +497,24 @@ class EntityPlatform:
|
|||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_register_entity_service(self, name, schema, func, required_features=None):
|
def async_register_entity_service(self, name, schema, func, required_features=None):
|
||||||
"""Register an entity service."""
|
"""Register an entity service.
|
||||||
|
|
||||||
|
Services will automatically be shared by all platforms of the same domain.
|
||||||
|
"""
|
||||||
|
if self.hass.services.has_service(self.platform_name, name):
|
||||||
|
return
|
||||||
|
|
||||||
if isinstance(schema, dict):
|
if isinstance(schema, dict):
|
||||||
schema = cv.make_entity_service_schema(schema)
|
schema = cv.make_entity_service_schema(schema)
|
||||||
|
|
||||||
async def handle_service(call):
|
async def handle_service(call):
|
||||||
"""Handle the service."""
|
"""Handle the service."""
|
||||||
await service.entity_service_call(
|
await service.entity_service_call(
|
||||||
self.hass, [self], func, call, required_features
|
self.hass,
|
||||||
|
self.hass.data[DATA_ENTITY_PLATFORM][self.platform_name],
|
||||||
|
func,
|
||||||
|
call,
|
||||||
|
required_features,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.hass.services.async_register(
|
self.hass.services.async_register(
|
||||||
|
@ -8,6 +8,7 @@ import asynctest
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.const import UNIT_PERCENTAGE
|
from homeassistant.const import UNIT_PERCENTAGE
|
||||||
|
from homeassistant.core import callback
|
||||||
from homeassistant.exceptions import PlatformNotReady
|
from homeassistant.exceptions import PlatformNotReady
|
||||||
from homeassistant.helpers import entity_platform, entity_registry
|
from homeassistant.helpers import entity_platform, entity_registry
|
||||||
from homeassistant.helpers.entity import async_generate_entity_id
|
from homeassistant.helpers.entity import async_generate_entity_id
|
||||||
@ -847,3 +848,37 @@ async def test_platform_with_no_setup(hass, caplog):
|
|||||||
"The mock-platform platform for the mock-integration integration does not support platform setup."
|
"The mock-platform platform for the mock-integration integration does not support platform setup."
|
||||||
in caplog.text
|
in caplog.text
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_platforms_sharing_services(hass):
|
||||||
|
"""Test platforms share services."""
|
||||||
|
entity_platform1 = MockEntityPlatform(
|
||||||
|
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
||||||
|
)
|
||||||
|
entity1 = MockEntity(entity_id="mock_integration.entity_1")
|
||||||
|
await entity_platform1.async_add_entities([entity1])
|
||||||
|
|
||||||
|
entity_platform2 = MockEntityPlatform(
|
||||||
|
hass, domain="mock_integration", platform_name="mock_platform", platform=None
|
||||||
|
)
|
||||||
|
entity2 = MockEntity(entity_id="mock_integration.entity_2")
|
||||||
|
await entity_platform2.async_add_entities([entity2])
|
||||||
|
|
||||||
|
entities = []
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def handle_service(entity, data):
|
||||||
|
entities.append(entity)
|
||||||
|
|
||||||
|
entity_platform1.async_register_entity_service("hello", {}, handle_service)
|
||||||
|
entity_platform2.async_register_entity_service(
|
||||||
|
"hello", {}, Mock(side_effect=AssertionError("Should not be called"))
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.services.async_call(
|
||||||
|
"mock_platform", "hello", {"entity_id": "all"}, blocking=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(entities) == 2
|
||||||
|
assert entity1 in entities
|
||||||
|
assert entity2 in entities
|
||||||
|
Loading…
x
Reference in New Issue
Block a user