diff --git a/homeassistant/components/thread/dataset_store.py b/homeassistant/components/thread/dataset_store.py index 9c5d79cc0e0..9dc4ad31217 100644 --- a/homeassistant/components/thread/dataset_store.py +++ b/homeassistant/components/thread/dataset_store.py @@ -1,6 +1,7 @@ """Persistently store thread datasets.""" from __future__ import annotations +from asyncio import Event, Task, wait import dataclasses from datetime import datetime import logging @@ -16,6 +17,9 @@ from homeassistant.helpers.singleton import singleton from homeassistant.helpers.storage import Store from homeassistant.util import dt as dt_util, ulid as ulid_util +from . import discovery + +BORDER_AGENT_DISCOVERY_TIMEOUT = 30 DATA_STORE = "thread.datasets" STORAGE_KEY = "thread.datasets" STORAGE_VERSION_MAJOR = 1 @@ -177,6 +181,7 @@ class DatasetStore: self.hass = hass self.datasets: dict[str, DatasetEntry] = {} self._preferred_dataset: str | None = None + self._set_preferred_dataset_task: Task | None = None self._store: Store[dict[str, Any]] = DatasetStoreStore( hass, STORAGE_VERSION_MAJOR, @@ -267,11 +272,21 @@ class DatasetStore: preferred_border_agent_id=preferred_border_agent_id, source=source, tlv=tlv ) self.datasets[entry.id] = entry - # Set to preferred if there is no preferred dataset - if self._preferred_dataset is None: - self._preferred_dataset = entry.id self.async_schedule_save() + # Set the new network as preferred if there is no preferred dataset and there is + # no other router present. We only attempt this once. + if ( + self._preferred_dataset is None + and preferred_border_agent_id + and not self._set_preferred_dataset_task + ): + self._set_preferred_dataset_task = self.hass.async_create_task( + self._set_preferred_dataset_if_only_network( + entry.id, preferred_border_agent_id + ) + ) + @callback def async_delete(self, dataset_id: str) -> None: """Delete dataset.""" @@ -310,6 +325,62 @@ class DatasetStore: self._preferred_dataset = dataset_id self.async_schedule_save() + async def _set_preferred_dataset_if_only_network( + self, dataset_id: str, border_agent_id: str + ) -> None: + """Set the preferred dataset, unless there are other routers present.""" + _LOGGER.debug( + "_set_preferred_dataset_if_only_network called for router %s", + border_agent_id, + ) + + own_router_evt = Event() + other_router_evt = Event() + + @callback + def router_discovered( + key: str, data: discovery.ThreadRouterDiscoveryData + ) -> None: + """Handle router discovered.""" + _LOGGER.debug("discovered router with id %s", data.border_agent_id) + if data.border_agent_id == border_agent_id: + own_router_evt.set() + return + + other_router_evt.set() + + # Start Thread router discovery + thread_discovery = discovery.ThreadRouterDiscovery( + self.hass, router_discovered, lambda key: None + ) + await thread_discovery.async_start() + + found_own_router = self.hass.async_create_task(own_router_evt.wait()) + found_other_router = self.hass.async_create_task(other_router_evt.wait()) + pending = {found_own_router, found_other_router} + (done, pending) = await wait(pending, timeout=BORDER_AGENT_DISCOVERY_TIMEOUT) + if found_other_router in done: + # We found another router on the network, don't set the dataset + # as preferred + _LOGGER.debug("Other router found, do not set dataset as default") + + # Note that asyncio.wait does not raise TimeoutError, it instead returns + # the jobs which did not finish in the pending-set. + elif found_own_router in pending: + # Either the router is not there, or mDNS is not working. In any case, + # don't set the router as preferred. + _LOGGER.debug("Own router not found, do not set dataset as default") + + else: + # We've discovered the router connected to the dataset, but we did not + # find any other router on the network - mark the dataset as preferred. + _LOGGER.debug("No other router found, set dataset as default") + self.preferred_dataset = dataset_id + + for task in pending: + task.cancel() + await thread_discovery.async_stop() + async def async_load(self) -> None: """Load the datasets.""" data = await self._store.async_load() diff --git a/tests/components/otbr/__init__.py b/tests/components/otbr/__init__.py index a30275d3569..e72849aa5a1 100644 --- a/tests/components/otbr/__init__.py +++ b/tests/components/otbr/__init__.py @@ -28,3 +28,32 @@ DATASET_INSECURE_PASSPHRASE = bytes.fromhex( ) TEST_BORDER_AGENT_ID = bytes.fromhex("230C6A1AC57F6F4BE262ACF32E5EF52C") + + +ROUTER_DISCOVERY_HASS = { + "type_": "_meshcop._udp.local.", + "name": "HomeAssistant OpenThreadBorderRouter #0BBF._meshcop._udp.local.", + "addresses": [b"\xc0\xa8\x00s"], + "port": 49153, + "weight": 0, + "priority": 0, + "server": "core-silabs-multiprotocol.local.", + "properties": { + b"rv": b"1", + b"id": b"#\x0cj\x1a\xc5\x7foK\xe2b\xac\xf3.^\xf5,", + b"vn": b"HomeAssistant", + b"mn": b"OpenThreadBorderRouter", + b"nn": b"OpenThread HC", + b"xp": b"\xe6\x0f\xc7\xc1\x86!,\xe5", + b"tv": b"1.3.0", + b"xa": b"\xae\xeb/YKW\x0b\xbf", + b"sb": b"\x00\x00\x01\xb1", + b"at": b"\x00\x00\x00\x00\x00\x01\x00\x00", + b"pt": b"\x8f\x06Q~", + b"sq": b"3", + b"bb": b"\xf0\xbf", + b"dn": b"DefaultDomain", + b"omr": b"@\xfd \xbe\x89IZ\x00\x01", + }, + "interface_index": None, +} diff --git a/tests/components/otbr/test_init.py b/tests/components/otbr/test_init.py index 1b5c1e8b60a..496427c083a 100644 --- a/tests/components/otbr/test_init.py +++ b/tests/components/otbr/test_init.py @@ -1,13 +1,16 @@ """Test the Open Thread Border Router integration.""" import asyncio from http import HTTPStatus +from typing import Any from unittest.mock import ANY, AsyncMock, MagicMock, patch import aiohttp import pytest import python_otbr_api +from zeroconf.asyncio import AsyncServiceInfo from homeassistant.components import otbr, thread +from homeassistant.components.thread import discovery from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import issue_registry as ir @@ -21,6 +24,7 @@ from . import ( DATASET_CH16, DATASET_INSECURE_NW_KEY, DATASET_INSECURE_PASSPHRASE, + ROUTER_DISCOVERY_HASS, TEST_BORDER_AGENT_ID, ) @@ -34,8 +38,19 @@ DATASET_NO_CHANNEL = bytes.fromhex( ) -async def test_import_dataset(hass: HomeAssistant) -> None: +async def test_import_dataset(hass: HomeAssistant, mock_async_zeroconf: None) -> None: """Test the active dataset is imported at setup.""" + add_service_listener_called = asyncio.Event() + + async def mock_add_service_listener(type_: str, listener: Any): + add_service_listener_called.set() + + mock_async_zeroconf.async_add_service_listener = AsyncMock( + side_effect=mock_add_service_listener + ) + mock_async_zeroconf.async_remove_service_listener = AsyncMock() + mock_async_zeroconf.async_get_service_info = AsyncMock() + issue_registry = ir.async_get(hass) assert await thread.async_get_preferred_dataset(hass) is None @@ -46,13 +61,37 @@ async def test_import_dataset(hass: HomeAssistant) -> None: title="My OTBR", ) config_entry.add_to_hass(hass) + with patch( "python_otbr_api.OTBR.get_active_dataset_tlvs", return_value=DATASET_CH16 ), patch( "python_otbr_api.OTBR.get_border_agent_id", return_value=TEST_BORDER_AGENT_ID + ), patch( + "homeassistant.components.thread.dataset_store.BORDER_AGENT_DISCOVERY_TIMEOUT", + 0.1, ): assert await hass.config_entries.async_setup(config_entry.entry_id) + # Wait for Thread router discovery to start + await add_service_listener_called.wait() + mock_async_zeroconf.async_add_service_listener.assert_called_once_with( + "_meshcop._udp.local.", ANY + ) + + # Discover a service matching our router + listener: discovery.ThreadRouterDiscovery.ThreadServiceListener = ( + mock_async_zeroconf.async_add_service_listener.mock_calls[0][1][1] + ) + mock_async_zeroconf.async_get_service_info.return_value = AsyncServiceInfo( + **ROUTER_DISCOVERY_HASS + ) + listener.add_service( + None, ROUTER_DISCOVERY_HASS["type_"], ROUTER_DISCOVERY_HASS["name"] + ) + + # Wait for discovery of other routers to time out + await hass.async_block_till_done() + dataset_store = await thread.dataset_store.async_get_store(hass) assert ( list(dataset_store.datasets.values())[0].preferred_border_agent_id diff --git a/tests/components/thread/__init__.py b/tests/components/thread/__init__.py index 7ca6cbaf2ed..155e46a8ee0 100644 --- a/tests/components/thread/__init__.py +++ b/tests/components/thread/__init__.py @@ -18,6 +18,7 @@ DATASET_3 = ( "0212340410445F2B5CA6F2A93A55CE570A70EFEECB0C0402A0F7F8" ) +TEST_BORDER_AGENT_ID = bytes.fromhex("230C6A1AC57F6F4BE262ACF32E5EF52C") ROUTER_DISCOVERY_GOOGLE_1 = { "type_": "_meshcop._udp.local.", diff --git a/tests/components/thread/test_dataset_store.py b/tests/components/thread/test_dataset_store.py index d8822a7d536..246fb88f3ef 100644 --- a/tests/components/thread/test_dataset_store.py +++ b/tests/components/thread/test_dataset_store.py @@ -1,14 +1,24 @@ """Test the thread dataset store.""" +import asyncio from typing import Any +from unittest.mock import ANY, AsyncMock, patch import pytest from python_otbr_api.tlv_parser import TLVError +from zeroconf.asyncio import AsyncServiceInfo -from homeassistant.components.thread import dataset_store +from homeassistant.components.thread import dataset_store, discovery from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError -from . import DATASET_1, DATASET_2, DATASET_3 +from . import ( + DATASET_1, + DATASET_2, + DATASET_3, + ROUTER_DISCOVERY_GOOGLE_1, + ROUTER_DISCOVERY_HASS, + TEST_BORDER_AGENT_ID, +) from tests.common import flush_store @@ -107,6 +117,7 @@ async def test_delete_preferred_dataset(hass: HomeAssistant) -> None: store = await dataset_store.async_get_store(hass) dataset_id = list(store.datasets.values())[0].id + store.preferred_dataset = dataset_id with pytest.raises(HomeAssistantError, match="attempt to remove preferred dataset"): store.async_delete(dataset_id) @@ -130,6 +141,10 @@ async def test_get_preferred_dataset(hass: HomeAssistant) -> None: await dataset_store.async_add_dataset(hass, "source", DATASET_1) + store = await dataset_store.async_get_store(hass) + dataset_id = list(store.datasets.values())[0].id + store.preferred_dataset = dataset_id + assert (await dataset_store.async_get_preferred_dataset(hass)) == DATASET_1 @@ -256,6 +271,8 @@ async def test_load_datasets(hass: HomeAssistant) -> None: for dataset in datasets: store1.async_add(dataset["source"], dataset["tlv"], None) assert len(store1.datasets) == 3 + dataset_id = list(store1.datasets.values())[0].id + store1.preferred_dataset = dataset_id for dataset in store1.datasets.values(): if dataset.source == "Google": @@ -575,3 +592,252 @@ async def test_set_preferred_border_agent_id(hass: HomeAssistant) -> None: hass, "source", DATASET_1_LARGER_TIMESTAMP, preferred_border_agent_id="blah" ) assert list(store.datasets.values())[1].preferred_border_agent_id == "blah" + + +async def test_automatically_set_preferred_dataset( + hass: HomeAssistant, mock_async_zeroconf: None +) -> None: + """Test automatically setting the first dataset as the preferred dataset.""" + add_service_listener_called = asyncio.Event() + remove_service_listener_called = asyncio.Event() + + async def mock_add_service_listener(type_: str, listener: Any): + add_service_listener_called.set() + + async def mock_remove_service_listener(listener: Any): + remove_service_listener_called.set() + + mock_async_zeroconf.async_add_service_listener = AsyncMock( + side_effect=mock_add_service_listener + ) + mock_async_zeroconf.async_remove_service_listener = AsyncMock( + side_effect=mock_remove_service_listener + ) + mock_async_zeroconf.async_get_service_info = AsyncMock() + + with patch( + "homeassistant.components.thread.dataset_store.BORDER_AGENT_DISCOVERY_TIMEOUT", + 0.1, + ): + await dataset_store.async_add_dataset( + hass, + "source", + DATASET_1, + preferred_border_agent_id=TEST_BORDER_AGENT_ID.hex(), + ) + + # Wait for discovery to start + await add_service_listener_called.wait() + mock_async_zeroconf.async_add_service_listener.assert_called_once_with( + "_meshcop._udp.local.", ANY + ) + + # Discover a service matching our router + listener: discovery.ThreadRouterDiscovery.ThreadServiceListener = ( + mock_async_zeroconf.async_add_service_listener.mock_calls[0][1][1] + ) + mock_async_zeroconf.async_get_service_info.return_value = AsyncServiceInfo( + **ROUTER_DISCOVERY_HASS + ) + listener.add_service( + None, ROUTER_DISCOVERY_HASS["type_"], ROUTER_DISCOVERY_HASS["name"] + ) + + # Wait for discovery of other routers to time out and discovery to stop + await remove_service_listener_called.wait() + + store = await dataset_store.async_get_store(hass) + assert ( + list(store.datasets.values())[0].preferred_border_agent_id + == TEST_BORDER_AGENT_ID.hex() + ) + assert await dataset_store.async_get_preferred_dataset(hass) == DATASET_1 + + +async def test_automatically_set_preferred_dataset_own_and_other_router( + hass: HomeAssistant, mock_async_zeroconf: None +) -> None: + """Test automatically setting the first dataset as the preferred dataset. + + In this test case both our own and another router are found. + """ + add_service_listener_called = asyncio.Event() + remove_service_listener_called = asyncio.Event() + + async def mock_add_service_listener(type_: str, listener: Any): + add_service_listener_called.set() + + async def mock_remove_service_listener(listener: Any): + remove_service_listener_called.set() + + mock_async_zeroconf.async_add_service_listener = AsyncMock( + side_effect=mock_add_service_listener + ) + mock_async_zeroconf.async_remove_service_listener = AsyncMock( + side_effect=mock_remove_service_listener + ) + mock_async_zeroconf.async_get_service_info = AsyncMock() + + with patch( + "homeassistant.components.thread.dataset_store.BORDER_AGENT_DISCOVERY_TIMEOUT", + 0.1, + ): + await dataset_store.async_add_dataset( + hass, + "source", + DATASET_1, + preferred_border_agent_id=TEST_BORDER_AGENT_ID.hex(), + ) + + # Wait for discovery to start + await add_service_listener_called.wait() + mock_async_zeroconf.async_add_service_listener.assert_called_once_with( + "_meshcop._udp.local.", ANY + ) + + # Discover a service matching our router + listener: discovery.ThreadRouterDiscovery.ThreadServiceListener = ( + mock_async_zeroconf.async_add_service_listener.mock_calls[0][1][1] + ) + mock_async_zeroconf.async_get_service_info.return_value = AsyncServiceInfo( + **ROUTER_DISCOVERY_HASS + ) + listener.add_service( + None, ROUTER_DISCOVERY_HASS["type_"], ROUTER_DISCOVERY_HASS["name"] + ) + + # Discover another router + listener: discovery.ThreadRouterDiscovery.ThreadServiceListener = ( + mock_async_zeroconf.async_add_service_listener.mock_calls[0][1][1] + ) + mock_async_zeroconf.async_get_service_info.return_value = AsyncServiceInfo( + **ROUTER_DISCOVERY_GOOGLE_1 + ) + listener.add_service( + None, ROUTER_DISCOVERY_GOOGLE_1["type_"], ROUTER_DISCOVERY_GOOGLE_1["name"] + ) + + # Wait for discovery to stop + await remove_service_listener_called.wait() + + store = await dataset_store.async_get_store(hass) + assert ( + list(store.datasets.values())[0].preferred_border_agent_id + == TEST_BORDER_AGENT_ID.hex() + ) + assert await dataset_store.async_get_preferred_dataset(hass) is None + + +async def test_automatically_set_preferred_dataset_other_router( + hass: HomeAssistant, mock_async_zeroconf: None +) -> None: + """Test automatically setting the first dataset as the preferred dataset. + + In this test case another router is found. + """ + add_service_listener_called = asyncio.Event() + remove_service_listener_called = asyncio.Event() + + async def mock_add_service_listener(type_: str, listener: Any): + add_service_listener_called.set() + + async def mock_remove_service_listener(listener: Any): + remove_service_listener_called.set() + + mock_async_zeroconf.async_add_service_listener = AsyncMock( + side_effect=mock_add_service_listener + ) + mock_async_zeroconf.async_remove_service_listener = AsyncMock( + side_effect=mock_remove_service_listener + ) + mock_async_zeroconf.async_get_service_info = AsyncMock() + + with patch( + "homeassistant.components.thread.dataset_store.BORDER_AGENT_DISCOVERY_TIMEOUT", + 0.1, + ): + await dataset_store.async_add_dataset( + hass, + "source", + DATASET_1, + preferred_border_agent_id=TEST_BORDER_AGENT_ID.hex(), + ) + + # Wait for discovery to start + await add_service_listener_called.wait() + mock_async_zeroconf.async_add_service_listener.assert_called_once_with( + "_meshcop._udp.local.", ANY + ) + + # Discover another router + listener: discovery.ThreadRouterDiscovery.ThreadServiceListener = ( + mock_async_zeroconf.async_add_service_listener.mock_calls[0][1][1] + ) + mock_async_zeroconf.async_get_service_info.return_value = AsyncServiceInfo( + **ROUTER_DISCOVERY_GOOGLE_1 + ) + listener.add_service( + None, ROUTER_DISCOVERY_GOOGLE_1["type_"], ROUTER_DISCOVERY_GOOGLE_1["name"] + ) + + # Wait for discovery to stop + await remove_service_listener_called.wait() + + store = await dataset_store.async_get_store(hass) + assert ( + list(store.datasets.values())[0].preferred_border_agent_id + == TEST_BORDER_AGENT_ID.hex() + ) + assert await dataset_store.async_get_preferred_dataset(hass) is None + + +async def test_automatically_set_preferred_dataset_no_router( + hass: HomeAssistant, mock_async_zeroconf: None +) -> None: + """Test automatically setting the first dataset as the preferred dataset. + + In this test case no routers are found. + """ + add_service_listener_called = asyncio.Event() + remove_service_listener_called = asyncio.Event() + + async def mock_add_service_listener(type_: str, listener: Any): + add_service_listener_called.set() + + async def mock_remove_service_listener(listener: Any): + remove_service_listener_called.set() + + mock_async_zeroconf.async_add_service_listener = AsyncMock( + side_effect=mock_add_service_listener + ) + mock_async_zeroconf.async_remove_service_listener = AsyncMock( + side_effect=mock_remove_service_listener + ) + mock_async_zeroconf.async_get_service_info = AsyncMock() + + with patch( + "homeassistant.components.thread.dataset_store.BORDER_AGENT_DISCOVERY_TIMEOUT", + 0.1, + ): + await dataset_store.async_add_dataset( + hass, + "source", + DATASET_1, + preferred_border_agent_id=TEST_BORDER_AGENT_ID.hex(), + ) + + # Wait for discovery to start + await add_service_listener_called.wait() + mock_async_zeroconf.async_add_service_listener.assert_called_once_with( + "_meshcop._udp.local.", ANY + ) + + # Wait for discovery of other routers to time out and discovery to stop + await remove_service_listener_called.wait() + + store = await dataset_store.async_get_store(hass) + assert ( + list(store.datasets.values())[0].preferred_border_agent_id + == TEST_BORDER_AGENT_ID.hex() + ) + assert await dataset_store.async_get_preferred_dataset(hass) is None diff --git a/tests/components/thread/test_websocket_api.py b/tests/components/thread/test_websocket_api.py index 75e1b313132..3b05586a1db 100644 --- a/tests/components/thread/test_websocket_api.py +++ b/tests/components/thread/test_websocket_api.py @@ -86,6 +86,17 @@ async def test_delete_dataset( assert msg["success"] datasets = msg["result"]["datasets"] + # Set the first dataset as preferred + await client.send_json_auto_id( + { + "type": "thread/set_preferred_dataset", + "dataset_id": datasets[0]["dataset_id"], + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] is None + # Try deleting the preferred dataset await client.send_json_auto_id( {"type": "thread/delete_dataset", "dataset_id": datasets[0]["dataset_id"]} @@ -139,6 +150,9 @@ async def test_list_get_dataset( await dataset_store.async_add_dataset(hass, dataset["source"], dataset["tlv"]) store = await dataset_store.async_get_store(hass) + dataset_id = list(store.datasets.values())[0].id + store.preferred_dataset = dataset_id + for dataset in store.datasets.values(): if dataset.source == "Google": dataset_1 = dataset