diff --git a/homeassistant/components/esphome/manager.py b/homeassistant/components/esphome/manager.py index b4ae1a1d0ad..1c0f82de4ae 100644 --- a/homeassistant/components/esphome/manager.py +++ b/homeassistant/components/esphome/manager.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio from collections.abc import Coroutine +from functools import partial import logging from typing import TYPE_CHECKING, Any, NamedTuple @@ -456,12 +457,10 @@ class ESPHomeManager: self.device_id = _async_setup_device_registry(hass, entry, entry_data) entry_data.async_update_device_state(hass) - await asyncio.gather( - entry_data.async_update_static_infos( - hass, entry, entity_infos, device_info.mac_address - ), - _setup_services(hass, entry_data, services), + await entry_data.async_update_static_infos( + hass, entry, entity_infos, device_info.mac_address ) + _setup_services(hass, entry_data, services) setup_coros_with_disconnect_callbacks: list[ Coroutine[Any, Any, CALLBACK_TYPE] @@ -586,7 +585,7 @@ class ESPHomeManager: await entry_data.async_update_static_infos( hass, entry, infos, entry.unique_id.upper() ) - await _setup_services(hass, entry_data, services) + _setup_services(hass, entry_data, services) if entry_data.device_info is not None and entry_data.device_info.name: reconnect_logic.name = entry_data.device_info.name @@ -708,12 +707,27 @@ ARG_TYPE_METADATA = { } -async def _register_service( - hass: HomeAssistant, entry_data: RuntimeEntryData, service: UserService +async def execute_service( + entry_data: RuntimeEntryData, service: UserService, call: ServiceCall ) -> None: - if entry_data.device_info is None: - raise ValueError("Device Info needs to be fetched first") - service_name = f"{entry_data.device_info.name.replace('-', '_')}_{service.name}" + """Execute a service on a node.""" + await entry_data.client.execute_service(service, call.data) + + +def build_service_name(device_info: EsphomeDeviceInfo, service: UserService) -> str: + """Build a service name for a node.""" + return f"{device_info.name.replace('-', '_')}_{service.name}" + + +@callback +def _async_register_service( + hass: HomeAssistant, + entry_data: RuntimeEntryData, + device_info: EsphomeDeviceInfo, + service: UserService, +) -> None: + """Register a service on a node.""" + service_name = build_service_name(device_info, service) schema = {} fields = {} @@ -736,33 +750,36 @@ async def _register_service( "selector": metadata.selector, } - async def execute_service(call: ServiceCall) -> None: - await entry_data.client.execute_service(service, call.data) - hass.services.async_register( - DOMAIN, service_name, execute_service, vol.Schema(schema) + DOMAIN, + service_name, + partial(execute_service, entry_data, service), + vol.Schema(schema), + ) + async_set_service_schema( + hass, + DOMAIN, + service_name, + { + "description": ( + f"Calls the service {service.name} of the node {device_info.name}" + ), + "fields": fields, + }, ) - service_desc = { - "description": ( - f"Calls the service {service.name} of the node" - f" {entry_data.device_info.name}" - ), - "fields": fields, - } - async_set_service_schema(hass, DOMAIN, service_name, service_desc) - - -async def _setup_services( +@callback +def _setup_services( hass: HomeAssistant, entry_data: RuntimeEntryData, services: list[UserService] ) -> None: - if entry_data.device_info is None: + device_info = entry_data.device_info + if device_info is None: # Can happen if device has never connected or .storage cleared return old_services = entry_data.services.copy() - to_unregister = [] - to_register = [] + to_unregister: list[UserService] = [] + to_register: list[UserService] = [] for service in services: if service.key in old_services: # Already exists @@ -780,11 +797,11 @@ async def _setup_services( entry_data.services = {serv.key: serv for serv in services} for service in to_unregister: - service_name = f"{entry_data.device_info.name}_{service.name}" + service_name = build_service_name(device_info, service) hass.services.async_remove(DOMAIN, service_name) for service in to_register: - await _register_service(hass, entry_data, service) + _async_register_service(hass, entry_data, device_info, service) async def cleanup_instance(hass: HomeAssistant, entry: ConfigEntry) -> RuntimeEntryData: diff --git a/tests/components/esphome/test_manager.py b/tests/components/esphome/test_manager.py index 69ed653d75b..94820a03fc6 100644 --- a/tests/components/esphome/test_manager.py +++ b/tests/components/esphome/test_manager.py @@ -2,7 +2,15 @@ from collections.abc import Awaitable, Callable from unittest.mock import AsyncMock, call -from aioesphomeapi import APIClient, DeviceInfo, EntityInfo, EntityState, UserService +from aioesphomeapi import ( + APIClient, + DeviceInfo, + EntityInfo, + EntityState, + UserService, + UserServiceArg, + UserServiceArgType, +) import pytest from homeassistant import config_entries @@ -374,3 +382,221 @@ async def test_debug_logging( ) await hass.async_block_till_done() mock_client.set_debug.assert_has_calls([call(False)]) + + +async def test_esphome_device_with_dash_in_name_user_services( + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], +) -> None: + """Test a device with user services and a dash in the name.""" + entity_info = [] + states = [] + service1 = UserService( + name="my_service", + key=1, + args=[ + UserServiceArg(name="arg1", type=UserServiceArgType.BOOL), + UserServiceArg(name="arg2", type=UserServiceArgType.INT), + UserServiceArg(name="arg3", type=UserServiceArgType.FLOAT), + UserServiceArg(name="arg4", type=UserServiceArgType.STRING), + UserServiceArg(name="arg5", type=UserServiceArgType.BOOL_ARRAY), + UserServiceArg(name="arg6", type=UserServiceArgType.INT_ARRAY), + UserServiceArg(name="arg7", type=UserServiceArgType.FLOAT_ARRAY), + UserServiceArg(name="arg8", type=UserServiceArgType.STRING_ARRAY), + ], + ) + service2 = UserService( + name="simple_service", + key=2, + args=[ + UserServiceArg(name="arg1", type=UserServiceArgType.BOOL), + ], + ) + device = await mock_esphome_device( + mock_client=mock_client, + entity_info=entity_info, + user_service=[service1, service2], + device_info={"name": "with-dash"}, + states=states, + ) + await hass.async_block_till_done() + assert hass.services.has_service(DOMAIN, "with_dash_my_service") + assert hass.services.has_service(DOMAIN, "with_dash_simple_service") + + await hass.services.async_call(DOMAIN, "with_dash_simple_service", {"arg1": True}) + await hass.async_block_till_done() + + mock_client.execute_service.assert_has_calls( + [ + call( + UserService( + name="simple_service", + key=2, + args=[UserServiceArg(name="arg1", type=UserServiceArgType.BOOL)], + ), + {"arg1": True}, + ) + ] + ) + mock_client.execute_service.reset_mock() + + # Verify the service can be removed + mock_client.list_entities_services = AsyncMock( + return_value=(entity_info, [service1]) + ) + await device.mock_disconnect(True) + await hass.async_block_till_done() + await device.mock_connect() + await hass.async_block_till_done() + assert hass.services.has_service(DOMAIN, "with_dash_my_service") + assert not hass.services.has_service(DOMAIN, "with_dash_simple_service") + + +async def test_esphome_user_services_ignores_invalid_arg_types( + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], +) -> None: + """Test a device with user services and a dash in the name.""" + entity_info = [] + states = [] + service1 = UserService( + name="bad_service", + key=1, + args=[ + UserServiceArg(name="arg1", type="wrong"), + ], + ) + service2 = UserService( + name="simple_service", + key=2, + args=[ + UserServiceArg(name="arg1", type=UserServiceArgType.BOOL), + ], + ) + device = await mock_esphome_device( + mock_client=mock_client, + entity_info=entity_info, + user_service=[service1, service2], + device_info={"name": "with-dash"}, + states=states, + ) + await hass.async_block_till_done() + assert not hass.services.has_service(DOMAIN, "with_dash_bad_service") + assert hass.services.has_service(DOMAIN, "with_dash_simple_service") + + await hass.services.async_call(DOMAIN, "with_dash_simple_service", {"arg1": True}) + await hass.async_block_till_done() + + mock_client.execute_service.assert_has_calls( + [ + call( + UserService( + name="simple_service", + key=2, + args=[UserServiceArg(name="arg1", type=UserServiceArgType.BOOL)], + ), + {"arg1": True}, + ) + ] + ) + mock_client.execute_service.reset_mock() + + # Verify the service can be removed + mock_client.list_entities_services = AsyncMock( + return_value=(entity_info, [service2]) + ) + await device.mock_disconnect(True) + await hass.async_block_till_done() + await device.mock_connect() + await hass.async_block_till_done() + assert hass.services.has_service(DOMAIN, "with_dash_simple_service") + assert not hass.services.has_service(DOMAIN, "with_dash_bad_service") + + +async def test_esphome_user_services_changes( + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], +) -> None: + """Test a device with user services that change arguments.""" + entity_info = [] + states = [] + service1 = UserService( + name="simple_service", + key=2, + args=[ + UserServiceArg(name="arg1", type=UserServiceArgType.BOOL), + ], + ) + device = await mock_esphome_device( + mock_client=mock_client, + entity_info=entity_info, + user_service=[service1], + device_info={"name": "with-dash"}, + states=states, + ) + await hass.async_block_till_done() + assert hass.services.has_service(DOMAIN, "with_dash_simple_service") + + await hass.services.async_call(DOMAIN, "with_dash_simple_service", {"arg1": True}) + await hass.async_block_till_done() + + mock_client.execute_service.assert_has_calls( + [ + call( + UserService( + name="simple_service", + key=2, + args=[UserServiceArg(name="arg1", type=UserServiceArgType.BOOL)], + ), + {"arg1": True}, + ) + ] + ) + mock_client.execute_service.reset_mock() + + new_service1 = UserService( + name="simple_service", + key=2, + args=[ + UserServiceArg(name="arg1", type=UserServiceArgType.FLOAT), + ], + ) + + # Verify the service can be updated + mock_client.list_entities_services = AsyncMock( + return_value=(entity_info, [new_service1]) + ) + await device.mock_disconnect(True) + await hass.async_block_till_done() + await device.mock_connect() + await hass.async_block_till_done() + assert hass.services.has_service(DOMAIN, "with_dash_simple_service") + + await hass.services.async_call(DOMAIN, "with_dash_simple_service", {"arg1": 4.5}) + await hass.async_block_till_done() + + mock_client.execute_service.assert_has_calls( + [ + call( + UserService( + name="simple_service", + key=2, + args=[UserServiceArg(name="arg1", type=UserServiceArgType.FLOAT)], + ), + {"arg1": 4.5}, + ) + ] + ) + mock_client.execute_service.reset_mock()