mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Fix entity services targeting entities outside the platform when using areas/devices (#109810)
This commit is contained in:
parent
674e4ceb2c
commit
09c609459d
@ -57,6 +57,7 @@ SLOW_ADD_MIN_TIMEOUT = 500
|
|||||||
PLATFORM_NOT_READY_RETRIES = 10
|
PLATFORM_NOT_READY_RETRIES = 10
|
||||||
DATA_ENTITY_PLATFORM = "entity_platform"
|
DATA_ENTITY_PLATFORM = "entity_platform"
|
||||||
DATA_DOMAIN_ENTITIES = "domain_entities"
|
DATA_DOMAIN_ENTITIES = "domain_entities"
|
||||||
|
DATA_DOMAIN_PLATFORM_ENTITIES = "domain_platform_entities"
|
||||||
PLATFORM_NOT_READY_BASE_WAIT_TIME = 30 # seconds
|
PLATFORM_NOT_READY_BASE_WAIT_TIME = 30 # seconds
|
||||||
|
|
||||||
_LOGGER = getLogger(__name__)
|
_LOGGER = getLogger(__name__)
|
||||||
@ -124,6 +125,8 @@ class EntityPlatform:
|
|||||||
self.scan_interval = scan_interval
|
self.scan_interval = scan_interval
|
||||||
self.entity_namespace = entity_namespace
|
self.entity_namespace = entity_namespace
|
||||||
self.config_entry: config_entries.ConfigEntry | None = None
|
self.config_entry: config_entries.ConfigEntry | None = None
|
||||||
|
# Storage for entities for this specific platform only
|
||||||
|
# which are indexed by entity_id
|
||||||
self.entities: dict[str, Entity] = {}
|
self.entities: dict[str, Entity] = {}
|
||||||
self.component_translations: dict[str, Any] = {}
|
self.component_translations: dict[str, Any] = {}
|
||||||
self.platform_translations: dict[str, Any] = {}
|
self.platform_translations: dict[str, Any] = {}
|
||||||
@ -145,9 +148,24 @@ class EntityPlatform:
|
|||||||
# which powers entity_component.add_entities
|
# which powers entity_component.add_entities
|
||||||
self.parallel_updates_created = platform is None
|
self.parallel_updates_created = platform is None
|
||||||
|
|
||||||
self.domain_entities: dict[str, Entity] = hass.data.setdefault(
|
# Storage for entities indexed by domain
|
||||||
|
# with the child dict indexed by entity_id
|
||||||
|
#
|
||||||
|
# This is usually media_player, light, switch, etc.
|
||||||
|
domain_entities: dict[str, dict[str, Entity]] = hass.data.setdefault(
|
||||||
DATA_DOMAIN_ENTITIES, {}
|
DATA_DOMAIN_ENTITIES, {}
|
||||||
).setdefault(domain, {})
|
)
|
||||||
|
self.domain_entities = domain_entities.setdefault(domain, {})
|
||||||
|
|
||||||
|
# Storage for entities indexed by domain and platform
|
||||||
|
# with the child dict indexed by entity_id
|
||||||
|
#
|
||||||
|
# This is usually media_player.yamaha, light.hue, switch.tplink, etc.
|
||||||
|
domain_platform_entities: dict[
|
||||||
|
tuple[str, str], dict[str, Entity]
|
||||||
|
] = hass.data.setdefault(DATA_DOMAIN_PLATFORM_ENTITIES, {})
|
||||||
|
key = (domain, platform_name)
|
||||||
|
self.domain_platform_entities = domain_platform_entities.setdefault(key, {})
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""Represent an EntityPlatform."""
|
"""Represent an EntityPlatform."""
|
||||||
@ -743,6 +761,7 @@ class EntityPlatform:
|
|||||||
entity_id = entity.entity_id
|
entity_id = entity.entity_id
|
||||||
self.entities[entity_id] = entity
|
self.entities[entity_id] = entity
|
||||||
self.domain_entities[entity_id] = entity
|
self.domain_entities[entity_id] = entity
|
||||||
|
self.domain_platform_entities[entity_id] = entity
|
||||||
|
|
||||||
if not restored:
|
if not restored:
|
||||||
# Reserve the state in the state machine
|
# Reserve the state in the state machine
|
||||||
@ -756,6 +775,7 @@ class EntityPlatform:
|
|||||||
"""Remove entity from entities dict."""
|
"""Remove entity from entities dict."""
|
||||||
self.entities.pop(entity_id)
|
self.entities.pop(entity_id)
|
||||||
self.domain_entities.pop(entity_id)
|
self.domain_entities.pop(entity_id)
|
||||||
|
self.domain_platform_entities.pop(entity_id)
|
||||||
|
|
||||||
entity.async_on_remove(remove_entity_cb)
|
entity.async_on_remove(remove_entity_cb)
|
||||||
|
|
||||||
@ -852,7 +872,7 @@ class EntityPlatform:
|
|||||||
partial(
|
partial(
|
||||||
service.entity_service_call,
|
service.entity_service_call,
|
||||||
self.hass,
|
self.hass,
|
||||||
self.domain_entities,
|
self.domain_platform_entities,
|
||||||
service_func,
|
service_func,
|
||||||
required_features=required_features,
|
required_features=required_features,
|
||||||
),
|
),
|
||||||
|
@ -19,6 +19,7 @@ from homeassistant.core import (
|
|||||||
)
|
)
|
||||||
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
|
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
|
area_registry as ar,
|
||||||
device_registry as dr,
|
device_registry as dr,
|
||||||
entity_platform,
|
entity_platform,
|
||||||
entity_registry as er,
|
entity_registry as er,
|
||||||
@ -1628,6 +1629,87 @@ async def test_register_entity_service_response_data_multiple_matches_raises(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_register_entity_service_limited_to_matching_platforms(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
|
area_registry: ar.AreaRegistry,
|
||||||
|
) -> None:
|
||||||
|
"""Test an entity services only targets entities for the platform and domain."""
|
||||||
|
|
||||||
|
mock_area = area_registry.async_get_or_create("mock_area")
|
||||||
|
|
||||||
|
entity1_entry = entity_registry.async_get_or_create(
|
||||||
|
"base_platform", "mock_platform", "1234", suggested_object_id="entity1"
|
||||||
|
)
|
||||||
|
entity_registry.async_update_entity(entity1_entry.entity_id, area_id=mock_area.id)
|
||||||
|
entity2_entry = entity_registry.async_get_or_create(
|
||||||
|
"base_platform", "mock_platform", "5678", suggested_object_id="entity2"
|
||||||
|
)
|
||||||
|
entity_registry.async_update_entity(entity2_entry.entity_id, area_id=mock_area.id)
|
||||||
|
entity3_entry = entity_registry.async_get_or_create(
|
||||||
|
"base_platform", "other_mock_platform", "7891", suggested_object_id="entity3"
|
||||||
|
)
|
||||||
|
entity_registry.async_update_entity(entity3_entry.entity_id, area_id=mock_area.id)
|
||||||
|
entity4_entry = entity_registry.async_get_or_create(
|
||||||
|
"base_platform", "other_mock_platform", "1433", suggested_object_id="entity4"
|
||||||
|
)
|
||||||
|
entity_registry.async_update_entity(entity4_entry.entity_id, area_id=mock_area.id)
|
||||||
|
|
||||||
|
async def generate_response(
|
||||||
|
target: MockEntity, call: ServiceCall
|
||||||
|
) -> ServiceResponse:
|
||||||
|
assert call.return_response
|
||||||
|
return {"response-key": f"response-value-{target.entity_id}"}
|
||||||
|
|
||||||
|
entity_platform = MockEntityPlatform(
|
||||||
|
hass, domain="base_platform", platform_name="mock_platform", platform=None
|
||||||
|
)
|
||||||
|
entity1 = MockEntity(
|
||||||
|
entity_id=entity1_entry.entity_id, unique_id=entity1_entry.unique_id
|
||||||
|
)
|
||||||
|
entity2 = MockEntity(
|
||||||
|
entity_id=entity2_entry.entity_id, unique_id=entity2_entry.unique_id
|
||||||
|
)
|
||||||
|
await entity_platform.async_add_entities([entity1, entity2])
|
||||||
|
|
||||||
|
other_entity_platform = MockEntityPlatform(
|
||||||
|
hass, domain="base_platform", platform_name="other_mock_platform", platform=None
|
||||||
|
)
|
||||||
|
entity3 = MockEntity(
|
||||||
|
entity_id=entity3_entry.entity_id, unique_id=entity3_entry.unique_id
|
||||||
|
)
|
||||||
|
entity4 = MockEntity(
|
||||||
|
entity_id=entity4_entry.entity_id, unique_id=entity4_entry.unique_id
|
||||||
|
)
|
||||||
|
await other_entity_platform.async_add_entities([entity3, entity4])
|
||||||
|
|
||||||
|
entity_platform.async_register_entity_service(
|
||||||
|
"hello",
|
||||||
|
{"some": str},
|
||||||
|
generate_response,
|
||||||
|
supports_response=SupportsResponse.ONLY,
|
||||||
|
)
|
||||||
|
|
||||||
|
response_data = await hass.services.async_call(
|
||||||
|
"mock_platform",
|
||||||
|
"hello",
|
||||||
|
service_data={"some": "data"},
|
||||||
|
target={"area_id": [mock_area.id]},
|
||||||
|
blocking=True,
|
||||||
|
return_response=True,
|
||||||
|
)
|
||||||
|
# We should not target entity3 and entity4 even though they are in the area
|
||||||
|
# because they are only part of the domain and not the platform
|
||||||
|
assert response_data == {
|
||||||
|
"base_platform.entity1": {
|
||||||
|
"response-key": "response-value-base_platform.entity1"
|
||||||
|
},
|
||||||
|
"base_platform.entity2": {
|
||||||
|
"response-key": "response-value-base_platform.entity2"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def test_invalid_entity_id(hass: HomeAssistant) -> None:
|
async def test_invalid_entity_id(hass: HomeAssistant) -> None:
|
||||||
"""Test specifying an invalid entity id."""
|
"""Test specifying an invalid entity id."""
|
||||||
platform = MockEntityPlatform(hass)
|
platform = MockEntityPlatform(hass)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user