Fix ESPHome service removal when the device name contains a dash (#107015)

* Fix ESPHome service removal when the device name contains a dash

If the device name contains a dash the service name is mutated to
replace the dash with an underscore, but the remove function did
not do the same mutation so it would fail to remove the service

* add more coverage

* more cover
This commit is contained in:
J. Nick Koston 2024-01-03 14:47:49 -10:00 committed by GitHub
parent afcf8c9718
commit 01d0031e09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 275 additions and 32 deletions

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Coroutine
from functools import partial
import logging
from typing import TYPE_CHECKING, Any, NamedTuple
@ -456,12 +457,10 @@ class ESPHomeManager:
self.device_id = _async_setup_device_registry(hass, entry, entry_data)
entry_data.async_update_device_state(hass)
await asyncio.gather(
entry_data.async_update_static_infos(
hass, entry, entity_infos, device_info.mac_address
),
_setup_services(hass, entry_data, services),
await entry_data.async_update_static_infos(
hass, entry, entity_infos, device_info.mac_address
)
_setup_services(hass, entry_data, services)
setup_coros_with_disconnect_callbacks: list[
Coroutine[Any, Any, CALLBACK_TYPE]
@ -586,7 +585,7 @@ class ESPHomeManager:
await entry_data.async_update_static_infos(
hass, entry, infos, entry.unique_id.upper()
)
await _setup_services(hass, entry_data, services)
_setup_services(hass, entry_data, services)
if entry_data.device_info is not None and entry_data.device_info.name:
reconnect_logic.name = entry_data.device_info.name
@ -708,12 +707,27 @@ ARG_TYPE_METADATA = {
}
async def _register_service(
hass: HomeAssistant, entry_data: RuntimeEntryData, service: UserService
async def execute_service(
entry_data: RuntimeEntryData, service: UserService, call: ServiceCall
) -> None:
if entry_data.device_info is None:
raise ValueError("Device Info needs to be fetched first")
service_name = f"{entry_data.device_info.name.replace('-', '_')}_{service.name}"
"""Execute a service on a node."""
await entry_data.client.execute_service(service, call.data)
def build_service_name(device_info: EsphomeDeviceInfo, service: UserService) -> str:
"""Build a service name for a node."""
return f"{device_info.name.replace('-', '_')}_{service.name}"
@callback
def _async_register_service(
hass: HomeAssistant,
entry_data: RuntimeEntryData,
device_info: EsphomeDeviceInfo,
service: UserService,
) -> None:
"""Register a service on a node."""
service_name = build_service_name(device_info, service)
schema = {}
fields = {}
@ -736,33 +750,36 @@ async def _register_service(
"selector": metadata.selector,
}
async def execute_service(call: ServiceCall) -> None:
await entry_data.client.execute_service(service, call.data)
hass.services.async_register(
DOMAIN, service_name, execute_service, vol.Schema(schema)
DOMAIN,
service_name,
partial(execute_service, entry_data, service),
vol.Schema(schema),
)
async_set_service_schema(
hass,
DOMAIN,
service_name,
{
"description": (
f"Calls the service {service.name} of the node {device_info.name}"
),
"fields": fields,
},
)
service_desc = {
"description": (
f"Calls the service {service.name} of the node"
f" {entry_data.device_info.name}"
),
"fields": fields,
}
async_set_service_schema(hass, DOMAIN, service_name, service_desc)
async def _setup_services(
@callback
def _setup_services(
hass: HomeAssistant, entry_data: RuntimeEntryData, services: list[UserService]
) -> None:
if entry_data.device_info is None:
device_info = entry_data.device_info
if device_info is None:
# Can happen if device has never connected or .storage cleared
return
old_services = entry_data.services.copy()
to_unregister = []
to_register = []
to_unregister: list[UserService] = []
to_register: list[UserService] = []
for service in services:
if service.key in old_services:
# Already exists
@ -780,11 +797,11 @@ async def _setup_services(
entry_data.services = {serv.key: serv for serv in services}
for service in to_unregister:
service_name = f"{entry_data.device_info.name}_{service.name}"
service_name = build_service_name(device_info, service)
hass.services.async_remove(DOMAIN, service_name)
for service in to_register:
await _register_service(hass, entry_data, service)
_async_register_service(hass, entry_data, device_info, service)
async def cleanup_instance(hass: HomeAssistant, entry: ConfigEntry) -> RuntimeEntryData:

View File

@ -2,7 +2,15 @@
from collections.abc import Awaitable, Callable
from unittest.mock import AsyncMock, call
from aioesphomeapi import APIClient, DeviceInfo, EntityInfo, EntityState, UserService
from aioesphomeapi import (
APIClient,
DeviceInfo,
EntityInfo,
EntityState,
UserService,
UserServiceArg,
UserServiceArgType,
)
import pytest
from homeassistant import config_entries
@ -374,3 +382,221 @@ async def test_debug_logging(
)
await hass.async_block_till_done()
mock_client.set_debug.assert_has_calls([call(False)])
async def test_esphome_device_with_dash_in_name_user_services(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test a device with user services and a dash in the name."""
entity_info = []
states = []
service1 = UserService(
name="my_service",
key=1,
args=[
UserServiceArg(name="arg1", type=UserServiceArgType.BOOL),
UserServiceArg(name="arg2", type=UserServiceArgType.INT),
UserServiceArg(name="arg3", type=UserServiceArgType.FLOAT),
UserServiceArg(name="arg4", type=UserServiceArgType.STRING),
UserServiceArg(name="arg5", type=UserServiceArgType.BOOL_ARRAY),
UserServiceArg(name="arg6", type=UserServiceArgType.INT_ARRAY),
UserServiceArg(name="arg7", type=UserServiceArgType.FLOAT_ARRAY),
UserServiceArg(name="arg8", type=UserServiceArgType.STRING_ARRAY),
],
)
service2 = UserService(
name="simple_service",
key=2,
args=[
UserServiceArg(name="arg1", type=UserServiceArgType.BOOL),
],
)
device = await mock_esphome_device(
mock_client=mock_client,
entity_info=entity_info,
user_service=[service1, service2],
device_info={"name": "with-dash"},
states=states,
)
await hass.async_block_till_done()
assert hass.services.has_service(DOMAIN, "with_dash_my_service")
assert hass.services.has_service(DOMAIN, "with_dash_simple_service")
await hass.services.async_call(DOMAIN, "with_dash_simple_service", {"arg1": True})
await hass.async_block_till_done()
mock_client.execute_service.assert_has_calls(
[
call(
UserService(
name="simple_service",
key=2,
args=[UserServiceArg(name="arg1", type=UserServiceArgType.BOOL)],
),
{"arg1": True},
)
]
)
mock_client.execute_service.reset_mock()
# Verify the service can be removed
mock_client.list_entities_services = AsyncMock(
return_value=(entity_info, [service1])
)
await device.mock_disconnect(True)
await hass.async_block_till_done()
await device.mock_connect()
await hass.async_block_till_done()
assert hass.services.has_service(DOMAIN, "with_dash_my_service")
assert not hass.services.has_service(DOMAIN, "with_dash_simple_service")
async def test_esphome_user_services_ignores_invalid_arg_types(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test a device with user services and a dash in the name."""
entity_info = []
states = []
service1 = UserService(
name="bad_service",
key=1,
args=[
UserServiceArg(name="arg1", type="wrong"),
],
)
service2 = UserService(
name="simple_service",
key=2,
args=[
UserServiceArg(name="arg1", type=UserServiceArgType.BOOL),
],
)
device = await mock_esphome_device(
mock_client=mock_client,
entity_info=entity_info,
user_service=[service1, service2],
device_info={"name": "with-dash"},
states=states,
)
await hass.async_block_till_done()
assert not hass.services.has_service(DOMAIN, "with_dash_bad_service")
assert hass.services.has_service(DOMAIN, "with_dash_simple_service")
await hass.services.async_call(DOMAIN, "with_dash_simple_service", {"arg1": True})
await hass.async_block_till_done()
mock_client.execute_service.assert_has_calls(
[
call(
UserService(
name="simple_service",
key=2,
args=[UserServiceArg(name="arg1", type=UserServiceArgType.BOOL)],
),
{"arg1": True},
)
]
)
mock_client.execute_service.reset_mock()
# Verify the service can be removed
mock_client.list_entities_services = AsyncMock(
return_value=(entity_info, [service2])
)
await device.mock_disconnect(True)
await hass.async_block_till_done()
await device.mock_connect()
await hass.async_block_till_done()
assert hass.services.has_service(DOMAIN, "with_dash_simple_service")
assert not hass.services.has_service(DOMAIN, "with_dash_bad_service")
async def test_esphome_user_services_changes(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test a device with user services that change arguments."""
entity_info = []
states = []
service1 = UserService(
name="simple_service",
key=2,
args=[
UserServiceArg(name="arg1", type=UserServiceArgType.BOOL),
],
)
device = await mock_esphome_device(
mock_client=mock_client,
entity_info=entity_info,
user_service=[service1],
device_info={"name": "with-dash"},
states=states,
)
await hass.async_block_till_done()
assert hass.services.has_service(DOMAIN, "with_dash_simple_service")
await hass.services.async_call(DOMAIN, "with_dash_simple_service", {"arg1": True})
await hass.async_block_till_done()
mock_client.execute_service.assert_has_calls(
[
call(
UserService(
name="simple_service",
key=2,
args=[UserServiceArg(name="arg1", type=UserServiceArgType.BOOL)],
),
{"arg1": True},
)
]
)
mock_client.execute_service.reset_mock()
new_service1 = UserService(
name="simple_service",
key=2,
args=[
UserServiceArg(name="arg1", type=UserServiceArgType.FLOAT),
],
)
# Verify the service can be updated
mock_client.list_entities_services = AsyncMock(
return_value=(entity_info, [new_service1])
)
await device.mock_disconnect(True)
await hass.async_block_till_done()
await device.mock_connect()
await hass.async_block_till_done()
assert hass.services.has_service(DOMAIN, "with_dash_simple_service")
await hass.services.async_call(DOMAIN, "with_dash_simple_service", {"arg1": 4.5})
await hass.async_block_till_done()
mock_client.execute_service.assert_has_calls(
[
call(
UserService(
name="simple_service",
key=2,
args=[UserServiceArg(name="arg1", type=UserServiceArgType.FLOAT)],
),
{"arg1": 4.5},
)
]
)
mock_client.execute_service.reset_mock()