Add Roborock buttons for starting routines (#139845)

This commit is contained in:
Regev Brody 2025-03-06 19:18:37 +02:00 committed by GitHub
parent 93dfbb4166
commit df1563daaf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 252 additions and 20 deletions

View File

@ -83,7 +83,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: RoborockConfigEntry) ->
# Get a Coordinator if the device is available or if we have connected to the device before # Get a Coordinator if the device is available or if we have connected to the device before
coordinators = await asyncio.gather( coordinators = await asyncio.gather(
*build_setup_functions( *build_setup_functions(
hass, entry, device_map, user_data, product_info, home_data.rooms hass,
entry,
device_map,
user_data,
product_info,
home_data.rooms,
api_client,
), ),
return_exceptions=True, return_exceptions=True,
) )
@ -135,6 +141,7 @@ def build_setup_functions(
user_data: UserData, user_data: UserData,
product_info: dict[str, HomeDataProduct], product_info: dict[str, HomeDataProduct],
home_data_rooms: list[HomeDataRoom], home_data_rooms: list[HomeDataRoom],
api_client: RoborockApiClient,
) -> list[ ) -> list[
Coroutine[ Coroutine[
Any, Any,
@ -151,6 +158,7 @@ def build_setup_functions(
device, device,
product_info[device.product_id], product_info[device.product_id],
home_data_rooms, home_data_rooms,
api_client,
) )
for device in device_map.values() for device in device_map.values()
] ]
@ -163,11 +171,12 @@ async def setup_device(
device: HomeDataDevice, device: HomeDataDevice,
product_info: HomeDataProduct, product_info: HomeDataProduct,
home_data_rooms: list[HomeDataRoom], home_data_rooms: list[HomeDataRoom],
api_client: RoborockApiClient,
) -> RoborockDataUpdateCoordinator | RoborockDataUpdateCoordinatorA01 | None: ) -> RoborockDataUpdateCoordinator | RoborockDataUpdateCoordinatorA01 | None:
"""Set up a coordinator for a given device.""" """Set up a coordinator for a given device."""
if device.pv == "1.0": if device.pv == "1.0":
return await setup_device_v1( return await setup_device_v1(
hass, entry, user_data, device, product_info, home_data_rooms hass, entry, user_data, device, product_info, home_data_rooms, api_client
) )
if device.pv == "A01": if device.pv == "A01":
return await setup_device_a01(hass, entry, user_data, device, product_info) return await setup_device_a01(hass, entry, user_data, device, product_info)
@ -187,6 +196,7 @@ async def setup_device_v1(
device: HomeDataDevice, device: HomeDataDevice,
product_info: HomeDataProduct, product_info: HomeDataProduct,
home_data_rooms: list[HomeDataRoom], home_data_rooms: list[HomeDataRoom],
api_client: RoborockApiClient,
) -> RoborockDataUpdateCoordinator | None: ) -> RoborockDataUpdateCoordinator | None:
"""Set up a device Coordinator.""" """Set up a device Coordinator."""
mqtt_client = await hass.async_add_executor_job( mqtt_client = await hass.async_add_executor_job(
@ -208,7 +218,15 @@ async def setup_device_v1(
await mqtt_client.async_release() await mqtt_client.async_release()
raise raise
coordinator = RoborockDataUpdateCoordinator( coordinator = RoborockDataUpdateCoordinator(
hass, entry, device, networking, product_info, mqtt_client, home_data_rooms hass,
entry,
device,
networking,
product_info,
mqtt_client,
home_data_rooms,
api_client,
user_data,
) )
try: try:
await coordinator.async_config_entry_first_refresh() await coordinator.async_config_entry_first_refresh()

View File

@ -2,7 +2,10 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from dataclasses import dataclass from dataclasses import dataclass
import itertools
from typing import Any
from roborock.roborock_typing import RoborockCommand from roborock.roborock_typing import RoborockCommand
@ -12,7 +15,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from .coordinator import RoborockConfigEntry, RoborockDataUpdateCoordinator from .coordinator import RoborockConfigEntry, RoborockDataUpdateCoordinator
from .entity import RoborockEntityV1 from .entity import RoborockEntity, RoborockEntityV1
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
@ -65,14 +68,34 @@ async def async_setup_entry(
async_add_entities: AddConfigEntryEntitiesCallback, async_add_entities: AddConfigEntryEntitiesCallback,
) -> None: ) -> None:
"""Set up Roborock button platform.""" """Set up Roborock button platform."""
routines_lists = await asyncio.gather(
*[coordinator.get_routines() for coordinator in config_entry.runtime_data.v1],
)
async_add_entities( async_add_entities(
RoborockButtonEntity( itertools.chain(
coordinator, (
description, RoborockButtonEntity(
coordinator,
description,
)
for coordinator in config_entry.runtime_data.v1
for description in CONSUMABLE_BUTTON_DESCRIPTIONS
if isinstance(coordinator, RoborockDataUpdateCoordinator)
),
(
RoborockRoutineButtonEntity(
coordinator,
ButtonEntityDescription(
key=str(routine.id),
name=routine.name,
),
)
for coordinator, routines in zip(
config_entry.runtime_data.v1, routines_lists, strict=True
)
for routine in routines
),
) )
for coordinator in config_entry.runtime_data.v1
for description in CONSUMABLE_BUTTON_DESCRIPTIONS
if isinstance(coordinator, RoborockDataUpdateCoordinator)
) )
@ -97,3 +120,28 @@ class RoborockButtonEntity(RoborockEntityV1, ButtonEntity):
async def async_press(self) -> None: async def async_press(self) -> None:
"""Press the button.""" """Press the button."""
await self.send(self.entity_description.command, self.entity_description.param) await self.send(self.entity_description.command, self.entity_description.param)
class RoborockRoutineButtonEntity(RoborockEntity, ButtonEntity):
"""A class to define Roborock routines button entities."""
entity_description: ButtonEntityDescription
def __init__(
self,
coordinator: RoborockDataUpdateCoordinator,
entity_description: ButtonEntityDescription,
) -> None:
"""Create a button entity."""
super().__init__(
f"{entity_description.key}_{coordinator.duid_slug}",
coordinator.device_info,
coordinator.api,
)
self._routine_id = int(entity_description.key)
self._coordinator = coordinator
self.entity_description = entity_description
async def async_press(self, **kwargs: Any) -> None:
"""Press the button."""
await self._coordinator.execute_routines(self._routine_id)

View File

@ -10,17 +10,26 @@ import logging
from propcache.api import cached_property from propcache.api import cached_property
from roborock import HomeDataRoom from roborock import HomeDataRoom
from roborock.code_mappings import RoborockCategory from roborock.code_mappings import RoborockCategory
from roborock.containers import DeviceData, HomeDataDevice, HomeDataProduct, NetworkInfo from roborock.containers import (
DeviceData,
HomeDataDevice,
HomeDataProduct,
HomeDataScene,
NetworkInfo,
UserData,
)
from roborock.exceptions import RoborockException from roborock.exceptions import RoborockException
from roborock.roborock_message import RoborockDyadDataProtocol, RoborockZeoProtocol from roborock.roborock_message import RoborockDyadDataProtocol, RoborockZeoProtocol
from roborock.roborock_typing import DeviceProp from roborock.roborock_typing import DeviceProp
from roborock.version_1_apis.roborock_local_client_v1 import RoborockLocalClientV1 from roborock.version_1_apis.roborock_local_client_v1 import RoborockLocalClientV1
from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1 from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1
from roborock.version_a01_apis import RoborockClientA01 from roborock.version_a01_apis import RoborockClientA01
from roborock.web_api import RoborockApiClient
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_CONNECTIONS from homeassistant.const import ATTR_CONNECTIONS
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.device_registry import DeviceInfo from homeassistant.helpers.device_registry import DeviceInfo
from homeassistant.helpers.typing import StateType from homeassistant.helpers.typing import StateType
@ -67,6 +76,8 @@ class RoborockDataUpdateCoordinator(DataUpdateCoordinator[DeviceProp]):
product_info: HomeDataProduct, product_info: HomeDataProduct,
cloud_api: RoborockMqttClientV1, cloud_api: RoborockMqttClientV1,
home_data_rooms: list[HomeDataRoom], home_data_rooms: list[HomeDataRoom],
api_client: RoborockApiClient,
user_data: UserData,
) -> None: ) -> None:
"""Initialize.""" """Initialize."""
super().__init__( super().__init__(
@ -89,7 +100,7 @@ class RoborockDataUpdateCoordinator(DataUpdateCoordinator[DeviceProp]):
self.cloud_api = cloud_api self.cloud_api = cloud_api
self.device_info = DeviceInfo( self.device_info = DeviceInfo(
name=self.roborock_device_info.device.name, name=self.roborock_device_info.device.name,
identifiers={(DOMAIN, self.roborock_device_info.device.duid)}, identifiers={(DOMAIN, self.duid)},
manufacturer="Roborock", manufacturer="Roborock",
model=self.roborock_device_info.product.model, model=self.roborock_device_info.product.model,
model_id=self.roborock_device_info.product.model, model_id=self.roborock_device_info.product.model,
@ -103,8 +114,10 @@ class RoborockDataUpdateCoordinator(DataUpdateCoordinator[DeviceProp]):
self.maps: dict[int, RoborockMapInfo] = {} self.maps: dict[int, RoborockMapInfo] = {}
self._home_data_rooms = {str(room.id): room.name for room in home_data_rooms} self._home_data_rooms = {str(room.id): room.name for room in home_data_rooms}
self.map_storage = RoborockMapStorage( self.map_storage = RoborockMapStorage(
hass, self.config_entry.entry_id, slugify(self.duid) hass, self.config_entry.entry_id, self.duid_slug
) )
self._user_data = user_data
self._api_client = api_client
async def _async_setup(self) -> None: async def _async_setup(self) -> None:
"""Set up the coordinator.""" """Set up the coordinator."""
@ -134,7 +147,7 @@ class RoborockDataUpdateCoordinator(DataUpdateCoordinator[DeviceProp]):
except RoborockException: except RoborockException:
_LOGGER.warning( _LOGGER.warning(
"Using the cloud API for device %s. This is not recommended as it can lead to rate limiting. We recommend making your vacuum accessible by your Home Assistant instance", "Using the cloud API for device %s. This is not recommended as it can lead to rate limiting. We recommend making your vacuum accessible by your Home Assistant instance",
self.roborock_device_info.device.duid, self.duid,
) )
await self.api.async_disconnect() await self.api.async_disconnect()
# We use the cloud api if the local api fails to connect. # We use the cloud api if the local api fails to connect.
@ -194,6 +207,34 @@ class RoborockDataUpdateCoordinator(DataUpdateCoordinator[DeviceProp]):
for room in room_mapping or () for room in room_mapping or ()
} }
async def get_routines(self) -> list[HomeDataScene]:
"""Get routines."""
try:
return await self._api_client.get_scenes(self._user_data, self.duid)
except RoborockException as err:
_LOGGER.error("Failed to get routines %s", err)
raise HomeAssistantError(
translation_domain=DOMAIN,
translation_key="command_failed",
translation_placeholders={
"command": "get_scenes",
},
) from err
async def execute_routines(self, routine_id: int) -> None:
"""Execute routines."""
try:
await self._api_client.execute_scene(self._user_data, routine_id)
except RoborockException as err:
_LOGGER.error("Failed to execute routines %s %s", routine_id, err)
raise HomeAssistantError(
translation_domain=DOMAIN,
translation_key="command_failed",
translation_placeholders={
"command": "execute_scene",
},
) from err
@cached_property @cached_property
def duid(self) -> str: def duid(self) -> str:
"""Get the unique id of the device as specified by Roborock.""" """Get the unique id of the device as specified by Roborock."""

View File

@ -30,6 +30,7 @@ from .mock_data import (
MULTI_MAP_LIST, MULTI_MAP_LIST,
NETWORK_INFO, NETWORK_INFO,
PROP, PROP,
SCENES,
USER_DATA, USER_DATA,
USER_EMAIL, USER_EMAIL,
) )
@ -67,8 +68,24 @@ class A01Mock(RoborockMqttClientA01):
return {prot: self.protocol_responses[prot] for prot in dyad_data_protocols} return {prot: self.protocol_responses[prot] for prot in dyad_data_protocols}
@pytest.fixture(name="bypass_api_client_fixture")
def bypass_api_client_fixture() -> None:
"""Skip calls to the API client."""
with (
patch(
"homeassistant.components.roborock.RoborockApiClient.get_home_data_v2",
return_value=HOME_DATA,
),
patch(
"homeassistant.components.roborock.RoborockApiClient.get_scenes",
return_value=SCENES,
),
):
yield
@pytest.fixture(name="bypass_api_fixture") @pytest.fixture(name="bypass_api_fixture")
def bypass_api_fixture() -> None: def bypass_api_fixture(bypass_api_client_fixture: Any) -> None:
"""Skip calls to the API.""" """Skip calls to the API."""
with ( with (
patch("homeassistant.components.roborock.RoborockMqttClientV1.async_connect"), patch("homeassistant.components.roborock.RoborockMqttClientV1.async_connect"),
@ -76,10 +93,6 @@ def bypass_api_fixture() -> None:
patch( patch(
"homeassistant.components.roborock.coordinator.RoborockMqttClientV1._send_command" "homeassistant.components.roborock.coordinator.RoborockMqttClientV1._send_command"
), ),
patch(
"homeassistant.components.roborock.RoborockApiClient.get_home_data_v2",
return_value=HOME_DATA,
),
patch( patch(
"homeassistant.components.roborock.RoborockMqttClientV1.get_networking", "homeassistant.components.roborock.RoborockMqttClientV1.get_networking",
return_value=NETWORK_INFO, return_value=NETWORK_INFO,

View File

@ -9,6 +9,7 @@ from roborock.containers import (
Consumable, Consumable,
DnDTimer, DnDTimer,
HomeData, HomeData,
HomeDataScene,
MultiMapsList, MultiMapsList,
NetworkInfo, NetworkInfo,
S7Status, S7Status,
@ -1150,3 +1151,19 @@ MAP_DATA = MapData(0, 0)
MAP_DATA.image = ImageData( MAP_DATA.image = ImageData(
100, 10, 10, 10, 10, ImageConfig(), Image.new("RGB", (1, 1)), lambda p: p 100, 10, 10, 10, 10, ImageConfig(), Image.new("RGB", (1, 1)), lambda p: p
) )
SCENES = [
HomeDataScene.from_dict(
{
"name": "sc1",
"id": 12,
},
),
HomeDataScene.from_dict(
{
"name": "sc2",
"id": 24,
},
),
]

View File

@ -1,9 +1,10 @@
"""Test Roborock Button platform.""" """Test Roborock Button platform."""
from unittest.mock import patch from unittest.mock import ANY, patch
import pytest import pytest
import roborock import roborock
from roborock import RoborockException
from homeassistant.components.button import SERVICE_PRESS from homeassistant.components.button import SERVICE_PRESS
from homeassistant.const import Platform from homeassistant.const import Platform
@ -13,6 +14,18 @@ from homeassistant.exceptions import HomeAssistantError
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@pytest.fixture
def bypass_api_client_get_scenes_fixture(bypass_api_fixture) -> None:
"""Fixture to raise when getting scenes."""
with (
patch(
"homeassistant.components.roborock.RoborockApiClient.get_scenes",
side_effect=RoborockException(),
),
):
yield
@pytest.fixture @pytest.fixture
def platforms() -> list[Platform]: def platforms() -> list[Platform]:
"""Fixture to set platforms used in the test.""" """Fixture to set platforms used in the test."""
@ -84,3 +97,85 @@ async def test_update_failure(
) )
assert mock_send_message.assert_called_once assert mock_send_message.assert_called_once
assert hass.states.get(entity_id).state == "2023-10-30T08:50:00+00:00" assert hass.states.get(entity_id).state == "2023-10-30T08:50:00+00:00"
@pytest.mark.parametrize(
("entity_id"),
[
("button.roborock_s7_maxv_sc1"),
("button.roborock_s7_maxv_sc2"),
],
)
@pytest.mark.usefixtures("entity_registry_enabled_by_default")
async def test_get_button_routines_failure(
hass: HomeAssistant,
bypass_api_client_get_scenes_fixture,
setup_entry: MockConfigEntry,
entity_id: str,
) -> None:
"""Test that if routine retrieval fails, no entity is being created."""
# Ensure that the entity does not exist
assert hass.states.get(entity_id) is None
@pytest.mark.parametrize(
("entity_id", "routine_id"),
[
("button.roborock_s7_maxv_sc1", 12),
("button.roborock_s7_maxv_sc2", 24),
],
)
@pytest.mark.freeze_time("2023-10-30 08:50:00")
@pytest.mark.usefixtures("entity_registry_enabled_by_default")
async def test_press_routine_button_success(
hass: HomeAssistant,
bypass_api_fixture,
setup_entry: MockConfigEntry,
entity_id: str,
routine_id: int,
) -> None:
"""Test pressing the button entities."""
with patch(
"homeassistant.components.roborock.RoborockApiClient.execute_scene"
) as mock_execute_scene:
await hass.services.async_call(
"button",
SERVICE_PRESS,
blocking=True,
target={"entity_id": entity_id},
)
mock_execute_scene.assert_called_once_with(ANY, routine_id)
assert hass.states.get(entity_id).state == "2023-10-30T08:50:00+00:00"
@pytest.mark.parametrize(
("entity_id", "routine_id"),
[
("button.roborock_s7_maxv_sc1", 12),
],
)
@pytest.mark.freeze_time("2023-10-30 08:50:00")
@pytest.mark.usefixtures("entity_registry_enabled_by_default")
async def test_press_routine_button_failure(
hass: HomeAssistant,
bypass_api_fixture,
setup_entry: MockConfigEntry,
entity_id: str,
routine_id: int,
) -> None:
"""Test failure while pressing the button entity."""
with (
patch(
"homeassistant.components.roborock.RoborockApiClient.execute_scene",
side_effect=RoborockException,
) as mock_execute_scene,
pytest.raises(HomeAssistantError, match="Error while calling execute_scene"),
):
await hass.services.async_call(
"button",
SERVICE_PRESS,
blocking=True,
target={"entity_id": entity_id},
)
mock_execute_scene.assert_called_once_with(ANY, routine_id)
assert hass.states.get(entity_id).state == "2023-10-30T08:50:00+00:00"