mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 16:57:53 +00:00
Use entity.async_request_call in service helper (#31454)
* Use entity.async_request_call in service helper * Clean up semaphore handling * Address comments * Simplify call entity service helper * Fix stupid rflink test
This commit is contained in:
parent
2c439af165
commit
e970177eeb
@ -23,6 +23,8 @@ from . import (
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
PARALLEL_UPDATES = 0
|
||||
|
||||
TYPE_STANDARD = "standard"
|
||||
TYPE_INVERTED = "inverted"
|
||||
|
||||
|
@ -31,6 +31,8 @@ from . import (
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
PARALLEL_UPDATES = 0
|
||||
|
||||
TYPE_DIMMABLE = "dimmable"
|
||||
TYPE_SWITCHABLE = "switchable"
|
||||
TYPE_HYBRID = "hybrid"
|
||||
|
@ -22,6 +22,8 @@ from . import (
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
PARALLEL_UPDATES = 0
|
||||
|
||||
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
|
||||
{
|
||||
vol.Optional(
|
||||
|
@ -568,7 +568,6 @@ class Entity(ABC):
|
||||
# call an requests
|
||||
async def async_request_call(self, coro):
|
||||
"""Process request batched."""
|
||||
|
||||
if self.parallel_updates:
|
||||
await self.parallel_updates.acquire()
|
||||
|
||||
|
@ -62,22 +62,42 @@ class EntityPlatform:
|
||||
# Platform is None for the EntityComponent "catch-all" EntityPlatform
|
||||
# which powers entity_component.add_entities
|
||||
if platform is None:
|
||||
self.parallel_updates = None
|
||||
self.parallel_updates_semaphore: Optional[asyncio.Semaphore] = None
|
||||
self.parallel_updates_created = True
|
||||
self.parallel_updates: Optional[asyncio.Semaphore] = None
|
||||
return
|
||||
|
||||
self.parallel_updates = getattr(platform, "PARALLEL_UPDATES", None)
|
||||
# semaphore will be created on demand
|
||||
self.parallel_updates_semaphore = None
|
||||
self.parallel_updates_created = False
|
||||
self.parallel_updates = None
|
||||
|
||||
def _get_parallel_updates_semaphore(self) -> asyncio.Semaphore:
|
||||
"""Get or create a semaphore for parallel updates."""
|
||||
if self.parallel_updates_semaphore is None:
|
||||
self.parallel_updates_semaphore = asyncio.Semaphore(
|
||||
self.parallel_updates if self.parallel_updates else 1,
|
||||
loop=self.hass.loop,
|
||||
)
|
||||
return self.parallel_updates_semaphore
|
||||
@callback
|
||||
def _get_parallel_updates_semaphore(
|
||||
self, entity_has_async_update: bool
|
||||
) -> Optional[asyncio.Semaphore]:
|
||||
"""Get or create a semaphore for parallel updates.
|
||||
|
||||
Semaphore will be created on demand because we base it off if update method is async or not.
|
||||
|
||||
If parallel updates is set to 0, we skip the semaphore.
|
||||
If parallel updates is set to a number, we initialize the semaphore to that number.
|
||||
Default for entities with `async_update` method is 1. Otherwise it's 0.
|
||||
"""
|
||||
if self.parallel_updates_created:
|
||||
return self.parallel_updates
|
||||
|
||||
self.parallel_updates_created = True
|
||||
|
||||
parallel_updates = getattr(self.platform, "PARALLEL_UPDATES", None)
|
||||
|
||||
if parallel_updates is None and not entity_has_async_update:
|
||||
parallel_updates = 1
|
||||
|
||||
if parallel_updates == 0:
|
||||
parallel_updates = None
|
||||
|
||||
if parallel_updates is not None:
|
||||
self.parallel_updates = asyncio.Semaphore(parallel_updates)
|
||||
|
||||
return self.parallel_updates
|
||||
|
||||
async def async_setup(self, platform_config, discovery_info=None):
|
||||
"""Set up the platform from a config file."""
|
||||
@ -282,21 +302,9 @@ class EntityPlatform:
|
||||
|
||||
entity.hass = self.hass
|
||||
entity.platform = self
|
||||
|
||||
# Async entity
|
||||
# PARALLEL_UPDATES == None: entity.parallel_updates = None
|
||||
# PARALLEL_UPDATES == 0: entity.parallel_updates = None
|
||||
# PARALLEL_UPDATES > 0: entity.parallel_updates = Semaphore(p)
|
||||
# Sync entity
|
||||
# PARALLEL_UPDATES == None: entity.parallel_updates = Semaphore(1)
|
||||
# PARALLEL_UPDATES == 0: entity.parallel_updates = None
|
||||
# PARALLEL_UPDATES > 0: entity.parallel_updates = Semaphore(p)
|
||||
if hasattr(entity, "async_update") and not self.parallel_updates:
|
||||
entity.parallel_updates = None
|
||||
elif not hasattr(entity, "async_update") and self.parallel_updates == 0:
|
||||
entity.parallel_updates = None
|
||||
else:
|
||||
entity.parallel_updates = self._get_parallel_updates_semaphore()
|
||||
entity.parallel_updates = self._get_parallel_updates_semaphore(
|
||||
hasattr(entity, "async_update")
|
||||
)
|
||||
|
||||
# Update properties before we generate the entity_id
|
||||
if update_before_add:
|
||||
|
@ -316,16 +316,15 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
|
||||
|
||||
# Check the permissions
|
||||
|
||||
# A list with for each platform in platforms a list of entities to call
|
||||
# the service on.
|
||||
platforms_entities = []
|
||||
# A list with entities to call the service on.
|
||||
entity_candidates = []
|
||||
|
||||
if entity_perms is None:
|
||||
for platform in platforms:
|
||||
if target_all_entities:
|
||||
platforms_entities.append(list(platform.entities.values()))
|
||||
entity_candidates.extend(platform.entities.values())
|
||||
else:
|
||||
platforms_entities.append(
|
||||
entity_candidates.extend(
|
||||
[
|
||||
entity
|
||||
for entity in platform.entities.values()
|
||||
@ -337,7 +336,7 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
|
||||
# If we target all entities, we will select all entities the user
|
||||
# is allowed to control.
|
||||
for platform in platforms:
|
||||
platforms_entities.append(
|
||||
entity_candidates.extend(
|
||||
[
|
||||
entity
|
||||
for entity in platform.entities.values()
|
||||
@ -362,39 +361,20 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
|
||||
|
||||
platform_entities.append(entity)
|
||||
|
||||
platforms_entities.append(platform_entities)
|
||||
entity_candidates.extend(platform_entities)
|
||||
|
||||
if not target_all_entities:
|
||||
for platform_entities in platforms_entities:
|
||||
for entity in platform_entities:
|
||||
entity_ids.remove(entity.entity_id)
|
||||
for entity in entity_candidates:
|
||||
entity_ids.remove(entity.entity_id)
|
||||
|
||||
if entity_ids:
|
||||
_LOGGER.warning(
|
||||
"Unable to find referenced entities %s", ", ".join(sorted(entity_ids))
|
||||
)
|
||||
|
||||
tasks = [
|
||||
_handle_service_platform_call(
|
||||
hass, func, data, entities, call.context, required_features
|
||||
)
|
||||
for platform, entities in zip(platforms, platforms_entities)
|
||||
]
|
||||
entities = []
|
||||
|
||||
if tasks:
|
||||
done, pending = await asyncio.wait(tasks)
|
||||
assert not pending
|
||||
for future in done:
|
||||
future.result() # pop exception if have
|
||||
|
||||
|
||||
async def _handle_service_platform_call(
|
||||
hass, func, data, entities, context, required_features
|
||||
):
|
||||
"""Handle a function call."""
|
||||
tasks = []
|
||||
|
||||
for entity in entities:
|
||||
for entity in entity_candidates:
|
||||
if not entity.available:
|
||||
continue
|
||||
|
||||
@ -404,27 +384,33 @@ async def _handle_service_platform_call(
|
||||
):
|
||||
continue
|
||||
|
||||
entity.async_set_context(context)
|
||||
entities.append(entity)
|
||||
|
||||
if isinstance(func, str):
|
||||
result = hass.async_add_job(partial(getattr(entity, func), **data))
|
||||
else:
|
||||
result = hass.async_add_job(func, entity, data)
|
||||
if not entities:
|
||||
return
|
||||
|
||||
# Guard because callback functions do not return a task when passed to async_add_job.
|
||||
if result is not None:
|
||||
result = await result
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
_LOGGER.error(
|
||||
"Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to integration author.",
|
||||
func,
|
||||
entity.entity_id,
|
||||
done, pending = await asyncio.wait(
|
||||
[
|
||||
entity.async_request_call(
|
||||
_handle_entity_call(hass, entity, func, data, call.context)
|
||||
)
|
||||
await result
|
||||
for entity in entities
|
||||
]
|
||||
)
|
||||
assert not pending
|
||||
for future in done:
|
||||
future.result() # pop exception if have
|
||||
|
||||
if entity.should_poll:
|
||||
tasks.append(entity.async_update_ha_state(True))
|
||||
tasks = []
|
||||
|
||||
for entity in entities:
|
||||
if not entity.should_poll:
|
||||
continue
|
||||
|
||||
# Context expires if the turn on commands took a long time.
|
||||
# Set context again so it's there when we update
|
||||
entity.async_set_context(call.context)
|
||||
tasks.append(entity.async_update_ha_state(True))
|
||||
|
||||
if tasks:
|
||||
done, pending = await asyncio.wait(tasks)
|
||||
@ -433,6 +419,28 @@ async def _handle_service_platform_call(
|
||||
future.result() # pop exception if have
|
||||
|
||||
|
||||
async def _handle_entity_call(hass, entity, func, data, context):
|
||||
"""Handle calling service method."""
|
||||
entity.async_set_context(context)
|
||||
|
||||
if isinstance(func, str):
|
||||
result = hass.async_add_job(partial(getattr(entity, func), **data))
|
||||
else:
|
||||
result = hass.async_add_job(func, entity, data)
|
||||
|
||||
# Guard because callback functions do not return a task when passed to async_add_job.
|
||||
if result is not None:
|
||||
await result
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
_LOGGER.error(
|
||||
"Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to integration author.",
|
||||
func,
|
||||
entity.entity_id,
|
||||
)
|
||||
await result
|
||||
|
||||
|
||||
@bind_hass
|
||||
@ha.callback
|
||||
def async_register_admin_service(
|
||||
@ -474,6 +482,7 @@ def verify_domain_control(hass: HomeAssistantType, domain: str) -> Callable:
|
||||
return await service_handler(call)
|
||||
|
||||
user = await hass.auth.async_get_user(call.context.user_id)
|
||||
|
||||
if user is None:
|
||||
raise UnknownUser(
|
||||
context=call.context,
|
||||
@ -482,14 +491,12 @@ def verify_domain_control(hass: HomeAssistantType, domain: str) -> Callable:
|
||||
)
|
||||
|
||||
reg = await hass.helpers.entity_registry.async_get_registry()
|
||||
entities = [
|
||||
entity.entity_id
|
||||
for entity in reg.entities.values()
|
||||
if entity.platform == domain
|
||||
]
|
||||
|
||||
for entity_id in entities:
|
||||
if user.permissions.check_entity(entity_id, POLICY_CONTROL):
|
||||
for entity in reg.entities.values():
|
||||
if entity.platform != domain:
|
||||
continue
|
||||
|
||||
if user.permissions.check_entity(entity.entity_id, POLICY_CONTROL):
|
||||
return await service_handler(call)
|
||||
|
||||
raise Unauthorized(
|
||||
|
@ -270,8 +270,6 @@ async def test_parallel_updates_async_platform_with_constant(hass):
|
||||
|
||||
handle = list(component._platforms.values())[-1]
|
||||
|
||||
assert handle.parallel_updates == 2
|
||||
|
||||
class AsyncEntity(MockEntity):
|
||||
"""Mock entity that has async_update."""
|
||||
|
||||
@ -296,7 +294,6 @@ async def test_parallel_updates_sync_platform(hass):
|
||||
await component.async_setup({DOMAIN: {"platform": "platform"}})
|
||||
|
||||
handle = list(component._platforms.values())[-1]
|
||||
assert handle.parallel_updates is None
|
||||
|
||||
class SyncEntity(MockEntity):
|
||||
"""Mock entity that has update."""
|
||||
@ -323,7 +320,6 @@ async def test_parallel_updates_sync_platform_with_constant(hass):
|
||||
await component.async_setup({DOMAIN: {"platform": "platform"}})
|
||||
|
||||
handle = list(component._platforms.values())[-1]
|
||||
assert handle.parallel_updates == 2
|
||||
|
||||
class SyncEntity(MockEntity):
|
||||
"""Mock entity that has update."""
|
||||
|
@ -39,31 +39,29 @@ from tests.common import (
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_service_platform_call():
|
||||
def mock_handle_entity_call():
|
||||
"""Mock service platform call."""
|
||||
with patch(
|
||||
"homeassistant.helpers.service._handle_service_platform_call",
|
||||
"homeassistant.helpers.service._handle_entity_call",
|
||||
side_effect=lambda *args: mock_coro(),
|
||||
) as mock_call:
|
||||
yield mock_call
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_entities():
|
||||
def mock_entities(hass):
|
||||
"""Return mock entities in an ordered dict."""
|
||||
kitchen = Mock(
|
||||
kitchen = MockEntity(
|
||||
entity_id="light.kitchen",
|
||||
available=True,
|
||||
should_poll=False,
|
||||
supported_features=1,
|
||||
platform="test_domain",
|
||||
)
|
||||
living_room = Mock(
|
||||
living_room = MockEntity(
|
||||
entity_id="light.living_room",
|
||||
available=True,
|
||||
should_poll=False,
|
||||
supported_features=0,
|
||||
platform="test_domain",
|
||||
)
|
||||
entities = OrderedDict()
|
||||
entities[kitchen.entity_id] = kitchen
|
||||
@ -374,7 +372,7 @@ async def test_call_context_user_not_exist(hass):
|
||||
assert err.value.context.user_id == "non-existing"
|
||||
|
||||
|
||||
async def test_call_context_target_all(hass, mock_service_platform_call, mock_entities):
|
||||
async def test_call_context_target_all(hass, mock_handle_entity_call, mock_entities):
|
||||
"""Check we only target allowed entities if targeting all."""
|
||||
with patch(
|
||||
"homeassistant.auth.AuthManager.async_get_user",
|
||||
@ -398,13 +396,12 @@ async def test_call_context_target_all(hass, mock_service_platform_call, mock_en
|
||||
),
|
||||
)
|
||||
|
||||
assert len(mock_service_platform_call.mock_calls) == 1
|
||||
entities = mock_service_platform_call.mock_calls[0][1][3]
|
||||
assert entities == [mock_entities["light.kitchen"]]
|
||||
assert len(mock_handle_entity_call.mock_calls) == 1
|
||||
assert mock_handle_entity_call.mock_calls[0][1][1].entity_id == "light.kitchen"
|
||||
|
||||
|
||||
async def test_call_context_target_specific(
|
||||
hass, mock_service_platform_call, mock_entities
|
||||
hass, mock_handle_entity_call, mock_entities
|
||||
):
|
||||
"""Check targeting specific entities."""
|
||||
with patch(
|
||||
@ -429,13 +426,12 @@ async def test_call_context_target_specific(
|
||||
),
|
||||
)
|
||||
|
||||
assert len(mock_service_platform_call.mock_calls) == 1
|
||||
entities = mock_service_platform_call.mock_calls[0][1][3]
|
||||
assert entities == [mock_entities["light.kitchen"]]
|
||||
assert len(mock_handle_entity_call.mock_calls) == 1
|
||||
assert mock_handle_entity_call.mock_calls[0][1][1].entity_id == "light.kitchen"
|
||||
|
||||
|
||||
async def test_call_context_target_specific_no_auth(
|
||||
hass, mock_service_platform_call, mock_entities
|
||||
hass, mock_handle_entity_call, mock_entities
|
||||
):
|
||||
"""Check targeting specific entities without auth."""
|
||||
with pytest.raises(exceptions.Unauthorized) as err:
|
||||
@ -459,9 +455,7 @@ async def test_call_context_target_specific_no_auth(
|
||||
assert err.value.entity_id == "light.kitchen"
|
||||
|
||||
|
||||
async def test_call_no_context_target_all(
|
||||
hass, mock_service_platform_call, mock_entities
|
||||
):
|
||||
async def test_call_no_context_target_all(hass, mock_handle_entity_call, mock_entities):
|
||||
"""Check we target all if no user context given."""
|
||||
await service.entity_service_call(
|
||||
hass,
|
||||
@ -472,13 +466,14 @@ async def test_call_no_context_target_all(
|
||||
),
|
||||
)
|
||||
|
||||
assert len(mock_service_platform_call.mock_calls) == 1
|
||||
entities = mock_service_platform_call.mock_calls[0][1][3]
|
||||
assert entities == list(mock_entities.values())
|
||||
assert len(mock_handle_entity_call.mock_calls) == 2
|
||||
assert [call[1][1] for call in mock_handle_entity_call.mock_calls] == list(
|
||||
mock_entities.values()
|
||||
)
|
||||
|
||||
|
||||
async def test_call_no_context_target_specific(
|
||||
hass, mock_service_platform_call, mock_entities
|
||||
hass, mock_handle_entity_call, mock_entities
|
||||
):
|
||||
"""Check we can target specified entities."""
|
||||
await service.entity_service_call(
|
||||
@ -492,13 +487,12 @@ async def test_call_no_context_target_specific(
|
||||
),
|
||||
)
|
||||
|
||||
assert len(mock_service_platform_call.mock_calls) == 1
|
||||
entities = mock_service_platform_call.mock_calls[0][1][3]
|
||||
assert entities == [mock_entities["light.kitchen"]]
|
||||
assert len(mock_handle_entity_call.mock_calls) == 1
|
||||
assert mock_handle_entity_call.mock_calls[0][1][1].entity_id == "light.kitchen"
|
||||
|
||||
|
||||
async def test_call_with_match_all(
|
||||
hass, mock_service_platform_call, mock_entities, caplog
|
||||
hass, mock_handle_entity_call, mock_entities, caplog
|
||||
):
|
||||
"""Check we only target allowed entities if targeting all."""
|
||||
await service.entity_service_call(
|
||||
@ -508,20 +502,13 @@ async def test_call_with_match_all(
|
||||
ha.ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
|
||||
)
|
||||
|
||||
assert len(mock_service_platform_call.mock_calls) == 1
|
||||
entities = mock_service_platform_call.mock_calls[0][1][3]
|
||||
assert entities == [
|
||||
mock_entities["light.kitchen"],
|
||||
mock_entities["light.living_room"],
|
||||
]
|
||||
assert (
|
||||
"Not passing an entity ID to a service to target all entities is deprecated"
|
||||
) not in caplog.text
|
||||
assert len(mock_handle_entity_call.mock_calls) == 2
|
||||
assert [call[1][1] for call in mock_handle_entity_call.mock_calls] == list(
|
||||
mock_entities.values()
|
||||
)
|
||||
|
||||
|
||||
async def test_call_with_omit_entity_id(
|
||||
hass, mock_service_platform_call, mock_entities
|
||||
):
|
||||
async def test_call_with_omit_entity_id(hass, mock_handle_entity_call, mock_entities):
|
||||
"""Check service call if we do not pass an entity ID."""
|
||||
await service.entity_service_call(
|
||||
hass,
|
||||
@ -530,9 +517,7 @@ async def test_call_with_omit_entity_id(
|
||||
ha.ServiceCall("test_domain", "test_service"),
|
||||
)
|
||||
|
||||
assert len(mock_service_platform_call.mock_calls) == 1
|
||||
entities = mock_service_platform_call.mock_calls[0][1][3]
|
||||
assert entities == []
|
||||
assert len(mock_handle_entity_call.mock_calls) == 0
|
||||
|
||||
|
||||
async def test_register_admin_service(hass, hass_read_only_user, hass_admin_user):
|
||||
@ -644,96 +629,113 @@ async def test_domain_control_unknown(hass, mock_entities):
|
||||
assert len(calls) == 0
|
||||
|
||||
|
||||
async def test_domain_control_unauthorized(hass, hass_read_only_user, mock_entities):
|
||||
async def test_domain_control_unauthorized(hass, hass_read_only_user):
|
||||
"""Test domain verification in a service call with an unauthorized user."""
|
||||
calls = []
|
||||
|
||||
async def mock_service_log(call):
|
||||
"""Define a protected service."""
|
||||
calls.append(call)
|
||||
|
||||
with patch(
|
||||
"homeassistant.helpers.entity_registry.async_get_registry",
|
||||
return_value=mock_coro(Mock(entities=mock_entities)),
|
||||
):
|
||||
protected_mock_service = hass.helpers.service.verify_domain_control(
|
||||
"test_domain"
|
||||
)(mock_service_log)
|
||||
|
||||
hass.services.async_register(
|
||||
"test_domain", "test_service", protected_mock_service, schema=None
|
||||
)
|
||||
|
||||
with pytest.raises(exceptions.Unauthorized):
|
||||
await hass.services.async_call(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
{},
|
||||
blocking=True,
|
||||
context=ha.Context(user_id=hass_read_only_user.id),
|
||||
mock_registry(
|
||||
hass,
|
||||
{
|
||||
"light.kitchen": ent_reg.RegistryEntry(
|
||||
entity_id="light.kitchen", unique_id="kitchen", platform="test_domain",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
calls = []
|
||||
|
||||
async def mock_service_log(call):
|
||||
"""Define a protected service."""
|
||||
calls.append(call)
|
||||
|
||||
protected_mock_service = hass.helpers.service.verify_domain_control("test_domain")(
|
||||
mock_service_log
|
||||
)
|
||||
|
||||
hass.services.async_register(
|
||||
"test_domain", "test_service", protected_mock_service, schema=None
|
||||
)
|
||||
|
||||
with pytest.raises(exceptions.Unauthorized):
|
||||
await hass.services.async_call(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
{},
|
||||
blocking=True,
|
||||
context=ha.Context(user_id=hass_read_only_user.id),
|
||||
)
|
||||
|
||||
assert len(calls) == 0
|
||||
|
||||
|
||||
async def test_domain_control_admin(hass, hass_admin_user, mock_entities):
|
||||
async def test_domain_control_admin(hass, hass_admin_user):
|
||||
"""Test domain verification in a service call with an admin user."""
|
||||
mock_registry(
|
||||
hass,
|
||||
{
|
||||
"light.kitchen": ent_reg.RegistryEntry(
|
||||
entity_id="light.kitchen", unique_id="kitchen", platform="test_domain",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
calls = []
|
||||
|
||||
async def mock_service_log(call):
|
||||
"""Define a protected service."""
|
||||
calls.append(call)
|
||||
|
||||
with patch(
|
||||
"homeassistant.helpers.entity_registry.async_get_registry",
|
||||
return_value=mock_coro(Mock(entities=mock_entities)),
|
||||
):
|
||||
protected_mock_service = hass.helpers.service.verify_domain_control(
|
||||
"test_domain"
|
||||
)(mock_service_log)
|
||||
protected_mock_service = hass.helpers.service.verify_domain_control("test_domain")(
|
||||
mock_service_log
|
||||
)
|
||||
|
||||
hass.services.async_register(
|
||||
"test_domain", "test_service", protected_mock_service, schema=None
|
||||
)
|
||||
hass.services.async_register(
|
||||
"test_domain", "test_service", protected_mock_service, schema=None
|
||||
)
|
||||
|
||||
await hass.services.async_call(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
{},
|
||||
blocking=True,
|
||||
context=ha.Context(user_id=hass_admin_user.id),
|
||||
)
|
||||
await hass.services.async_call(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
{},
|
||||
blocking=True,
|
||||
context=ha.Context(user_id=hass_admin_user.id),
|
||||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert len(calls) == 1
|
||||
|
||||
|
||||
async def test_domain_control_no_user(hass, mock_entities):
|
||||
async def test_domain_control_no_user(hass):
|
||||
"""Test domain verification in a service call with no user."""
|
||||
mock_registry(
|
||||
hass,
|
||||
{
|
||||
"light.kitchen": ent_reg.RegistryEntry(
|
||||
entity_id="light.kitchen", unique_id="kitchen", platform="test_domain",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
calls = []
|
||||
|
||||
async def mock_service_log(call):
|
||||
"""Define a protected service."""
|
||||
calls.append(call)
|
||||
|
||||
with patch(
|
||||
"homeassistant.helpers.entity_registry.async_get_registry",
|
||||
return_value=mock_coro(Mock(entities=mock_entities)),
|
||||
):
|
||||
protected_mock_service = hass.helpers.service.verify_domain_control(
|
||||
"test_domain"
|
||||
)(mock_service_log)
|
||||
protected_mock_service = hass.helpers.service.verify_domain_control("test_domain")(
|
||||
mock_service_log
|
||||
)
|
||||
|
||||
hass.services.async_register(
|
||||
"test_domain", "test_service", protected_mock_service, schema=None
|
||||
)
|
||||
hass.services.async_register(
|
||||
"test_domain", "test_service", protected_mock_service, schema=None
|
||||
)
|
||||
|
||||
await hass.services.async_call(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
{},
|
||||
blocking=True,
|
||||
context=ha.Context(user_id=None),
|
||||
)
|
||||
await hass.services.async_call(
|
||||
"test_domain",
|
||||
"test_service",
|
||||
{},
|
||||
blocking=True,
|
||||
context=ha.Context(user_id=None),
|
||||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert len(calls) == 1
|
||||
|
||||
|
||||
async def test_extract_from_service_available_device(hass):
|
||||
|
Loading…
x
Reference in New Issue
Block a user