Don't always set first thread dataset as preferred (#108278)

* Don't always set first thread dataset as preferred

* Update tests

* Make clarifying comments clearer

* Call asyncio.wait with return_when=ALL_COMPLETED

* Update otbr test

* Update homeassistant/components/thread/dataset_store.py

Co-authored-by: Stefan Agner <stefan@agner.ch>

* Update homeassistant/components/thread/dataset_store.py

---------

Co-authored-by: Stefan Agner <stefan@agner.ch>
This commit is contained in:
Erik Montnemery 2024-01-18 16:32:29 +01:00 committed by GitHub
parent bfe21b33f0
commit cdb798bec0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 426 additions and 6 deletions

View File

@ -1,6 +1,7 @@
"""Persistently store thread datasets.""" """Persistently store thread datasets."""
from __future__ import annotations from __future__ import annotations
from asyncio import Event, Task, wait
import dataclasses import dataclasses
from datetime import datetime from datetime import datetime
import logging import logging
@ -16,6 +17,9 @@ from homeassistant.helpers.singleton import singleton
from homeassistant.helpers.storage import Store from homeassistant.helpers.storage import Store
from homeassistant.util import dt as dt_util, ulid as ulid_util from homeassistant.util import dt as dt_util, ulid as ulid_util
from . import discovery
BORDER_AGENT_DISCOVERY_TIMEOUT = 30
DATA_STORE = "thread.datasets" DATA_STORE = "thread.datasets"
STORAGE_KEY = "thread.datasets" STORAGE_KEY = "thread.datasets"
STORAGE_VERSION_MAJOR = 1 STORAGE_VERSION_MAJOR = 1
@ -177,6 +181,7 @@ class DatasetStore:
self.hass = hass self.hass = hass
self.datasets: dict[str, DatasetEntry] = {} self.datasets: dict[str, DatasetEntry] = {}
self._preferred_dataset: str | None = None self._preferred_dataset: str | None = None
self._set_preferred_dataset_task: Task | None = None
self._store: Store[dict[str, Any]] = DatasetStoreStore( self._store: Store[dict[str, Any]] = DatasetStoreStore(
hass, hass,
STORAGE_VERSION_MAJOR, STORAGE_VERSION_MAJOR,
@ -267,11 +272,21 @@ class DatasetStore:
preferred_border_agent_id=preferred_border_agent_id, source=source, tlv=tlv preferred_border_agent_id=preferred_border_agent_id, source=source, tlv=tlv
) )
self.datasets[entry.id] = entry self.datasets[entry.id] = entry
# Set to preferred if there is no preferred dataset
if self._preferred_dataset is None:
self._preferred_dataset = entry.id
self.async_schedule_save() self.async_schedule_save()
# Set the new network as preferred if there is no preferred dataset and there is
# no other router present. We only attempt this once.
if (
self._preferred_dataset is None
and preferred_border_agent_id
and not self._set_preferred_dataset_task
):
self._set_preferred_dataset_task = self.hass.async_create_task(
self._set_preferred_dataset_if_only_network(
entry.id, preferred_border_agent_id
)
)
@callback @callback
def async_delete(self, dataset_id: str) -> None: def async_delete(self, dataset_id: str) -> None:
"""Delete dataset.""" """Delete dataset."""
@ -310,6 +325,62 @@ class DatasetStore:
self._preferred_dataset = dataset_id self._preferred_dataset = dataset_id
self.async_schedule_save() self.async_schedule_save()
async def _set_preferred_dataset_if_only_network(
self, dataset_id: str, border_agent_id: str
) -> None:
"""Set the preferred dataset, unless there are other routers present."""
_LOGGER.debug(
"_set_preferred_dataset_if_only_network called for router %s",
border_agent_id,
)
own_router_evt = Event()
other_router_evt = Event()
@callback
def router_discovered(
key: str, data: discovery.ThreadRouterDiscoveryData
) -> None:
"""Handle router discovered."""
_LOGGER.debug("discovered router with id %s", data.border_agent_id)
if data.border_agent_id == border_agent_id:
own_router_evt.set()
return
other_router_evt.set()
# Start Thread router discovery
thread_discovery = discovery.ThreadRouterDiscovery(
self.hass, router_discovered, lambda key: None
)
await thread_discovery.async_start()
found_own_router = self.hass.async_create_task(own_router_evt.wait())
found_other_router = self.hass.async_create_task(other_router_evt.wait())
pending = {found_own_router, found_other_router}
(done, pending) = await wait(pending, timeout=BORDER_AGENT_DISCOVERY_TIMEOUT)
if found_other_router in done:
# We found another router on the network, don't set the dataset
# as preferred
_LOGGER.debug("Other router found, do not set dataset as default")
# Note that asyncio.wait does not raise TimeoutError, it instead returns
# the jobs which did not finish in the pending-set.
elif found_own_router in pending:
# Either the router is not there, or mDNS is not working. In any case,
# don't set the router as preferred.
_LOGGER.debug("Own router not found, do not set dataset as default")
else:
# We've discovered the router connected to the dataset, but we did not
# find any other router on the network - mark the dataset as preferred.
_LOGGER.debug("No other router found, set dataset as default")
self.preferred_dataset = dataset_id
for task in pending:
task.cancel()
await thread_discovery.async_stop()
async def async_load(self) -> None: async def async_load(self) -> None:
"""Load the datasets.""" """Load the datasets."""
data = await self._store.async_load() data = await self._store.async_load()

View File

@ -28,3 +28,32 @@ DATASET_INSECURE_PASSPHRASE = bytes.fromhex(
) )
TEST_BORDER_AGENT_ID = bytes.fromhex("230C6A1AC57F6F4BE262ACF32E5EF52C") TEST_BORDER_AGENT_ID = bytes.fromhex("230C6A1AC57F6F4BE262ACF32E5EF52C")
ROUTER_DISCOVERY_HASS = {
"type_": "_meshcop._udp.local.",
"name": "HomeAssistant OpenThreadBorderRouter #0BBF._meshcop._udp.local.",
"addresses": [b"\xc0\xa8\x00s"],
"port": 49153,
"weight": 0,
"priority": 0,
"server": "core-silabs-multiprotocol.local.",
"properties": {
b"rv": b"1",
b"id": b"#\x0cj\x1a\xc5\x7foK\xe2b\xac\xf3.^\xf5,",
b"vn": b"HomeAssistant",
b"mn": b"OpenThreadBorderRouter",
b"nn": b"OpenThread HC",
b"xp": b"\xe6\x0f\xc7\xc1\x86!,\xe5",
b"tv": b"1.3.0",
b"xa": b"\xae\xeb/YKW\x0b\xbf",
b"sb": b"\x00\x00\x01\xb1",
b"at": b"\x00\x00\x00\x00\x00\x01\x00\x00",
b"pt": b"\x8f\x06Q~",
b"sq": b"3",
b"bb": b"\xf0\xbf",
b"dn": b"DefaultDomain",
b"omr": b"@\xfd \xbe\x89IZ\x00\x01",
},
"interface_index": None,
}

View File

@ -1,13 +1,16 @@
"""Test the Open Thread Border Router integration.""" """Test the Open Thread Border Router integration."""
import asyncio import asyncio
from http import HTTPStatus from http import HTTPStatus
from typing import Any
from unittest.mock import ANY, AsyncMock, MagicMock, patch from unittest.mock import ANY, AsyncMock, MagicMock, patch
import aiohttp import aiohttp
import pytest import pytest
import python_otbr_api import python_otbr_api
from zeroconf.asyncio import AsyncServiceInfo
from homeassistant.components import otbr, thread from homeassistant.components import otbr, thread
from homeassistant.components.thread import discovery
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import issue_registry as ir from homeassistant.helpers import issue_registry as ir
@ -21,6 +24,7 @@ from . import (
DATASET_CH16, DATASET_CH16,
DATASET_INSECURE_NW_KEY, DATASET_INSECURE_NW_KEY,
DATASET_INSECURE_PASSPHRASE, DATASET_INSECURE_PASSPHRASE,
ROUTER_DISCOVERY_HASS,
TEST_BORDER_AGENT_ID, TEST_BORDER_AGENT_ID,
) )
@ -34,8 +38,19 @@ DATASET_NO_CHANNEL = bytes.fromhex(
) )
async def test_import_dataset(hass: HomeAssistant) -> None: async def test_import_dataset(hass: HomeAssistant, mock_async_zeroconf: None) -> None:
"""Test the active dataset is imported at setup.""" """Test the active dataset is imported at setup."""
add_service_listener_called = asyncio.Event()
async def mock_add_service_listener(type_: str, listener: Any):
add_service_listener_called.set()
mock_async_zeroconf.async_add_service_listener = AsyncMock(
side_effect=mock_add_service_listener
)
mock_async_zeroconf.async_remove_service_listener = AsyncMock()
mock_async_zeroconf.async_get_service_info = AsyncMock()
issue_registry = ir.async_get(hass) issue_registry = ir.async_get(hass)
assert await thread.async_get_preferred_dataset(hass) is None assert await thread.async_get_preferred_dataset(hass) is None
@ -46,13 +61,37 @@ async def test_import_dataset(hass: HomeAssistant) -> None:
title="My OTBR", title="My OTBR",
) )
config_entry.add_to_hass(hass) config_entry.add_to_hass(hass)
with patch( with patch(
"python_otbr_api.OTBR.get_active_dataset_tlvs", return_value=DATASET_CH16 "python_otbr_api.OTBR.get_active_dataset_tlvs", return_value=DATASET_CH16
), patch( ), patch(
"python_otbr_api.OTBR.get_border_agent_id", return_value=TEST_BORDER_AGENT_ID "python_otbr_api.OTBR.get_border_agent_id", return_value=TEST_BORDER_AGENT_ID
), patch(
"homeassistant.components.thread.dataset_store.BORDER_AGENT_DISCOVERY_TIMEOUT",
0.1,
): ):
assert await hass.config_entries.async_setup(config_entry.entry_id) assert await hass.config_entries.async_setup(config_entry.entry_id)
# Wait for Thread router discovery to start
await add_service_listener_called.wait()
mock_async_zeroconf.async_add_service_listener.assert_called_once_with(
"_meshcop._udp.local.", ANY
)
# Discover a service matching our router
listener: discovery.ThreadRouterDiscovery.ThreadServiceListener = (
mock_async_zeroconf.async_add_service_listener.mock_calls[0][1][1]
)
mock_async_zeroconf.async_get_service_info.return_value = AsyncServiceInfo(
**ROUTER_DISCOVERY_HASS
)
listener.add_service(
None, ROUTER_DISCOVERY_HASS["type_"], ROUTER_DISCOVERY_HASS["name"]
)
# Wait for discovery of other routers to time out
await hass.async_block_till_done()
dataset_store = await thread.dataset_store.async_get_store(hass) dataset_store = await thread.dataset_store.async_get_store(hass)
assert ( assert (
list(dataset_store.datasets.values())[0].preferred_border_agent_id list(dataset_store.datasets.values())[0].preferred_border_agent_id

View File

@ -18,6 +18,7 @@ DATASET_3 = (
"0212340410445F2B5CA6F2A93A55CE570A70EFEECB0C0402A0F7F8" "0212340410445F2B5CA6F2A93A55CE570A70EFEECB0C0402A0F7F8"
) )
TEST_BORDER_AGENT_ID = bytes.fromhex("230C6A1AC57F6F4BE262ACF32E5EF52C")
ROUTER_DISCOVERY_GOOGLE_1 = { ROUTER_DISCOVERY_GOOGLE_1 = {
"type_": "_meshcop._udp.local.", "type_": "_meshcop._udp.local.",

View File

@ -1,14 +1,24 @@
"""Test the thread dataset store.""" """Test the thread dataset store."""
import asyncio
from typing import Any from typing import Any
from unittest.mock import ANY, AsyncMock, patch
import pytest import pytest
from python_otbr_api.tlv_parser import TLVError from python_otbr_api.tlv_parser import TLVError
from zeroconf.asyncio import AsyncServiceInfo
from homeassistant.components.thread import dataset_store from homeassistant.components.thread import dataset_store, discovery
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from . import DATASET_1, DATASET_2, DATASET_3 from . import (
DATASET_1,
DATASET_2,
DATASET_3,
ROUTER_DISCOVERY_GOOGLE_1,
ROUTER_DISCOVERY_HASS,
TEST_BORDER_AGENT_ID,
)
from tests.common import flush_store from tests.common import flush_store
@ -107,6 +117,7 @@ async def test_delete_preferred_dataset(hass: HomeAssistant) -> None:
store = await dataset_store.async_get_store(hass) store = await dataset_store.async_get_store(hass)
dataset_id = list(store.datasets.values())[0].id dataset_id = list(store.datasets.values())[0].id
store.preferred_dataset = dataset_id
with pytest.raises(HomeAssistantError, match="attempt to remove preferred dataset"): with pytest.raises(HomeAssistantError, match="attempt to remove preferred dataset"):
store.async_delete(dataset_id) store.async_delete(dataset_id)
@ -130,6 +141,10 @@ async def test_get_preferred_dataset(hass: HomeAssistant) -> None:
await dataset_store.async_add_dataset(hass, "source", DATASET_1) await dataset_store.async_add_dataset(hass, "source", DATASET_1)
store = await dataset_store.async_get_store(hass)
dataset_id = list(store.datasets.values())[0].id
store.preferred_dataset = dataset_id
assert (await dataset_store.async_get_preferred_dataset(hass)) == DATASET_1 assert (await dataset_store.async_get_preferred_dataset(hass)) == DATASET_1
@ -256,6 +271,8 @@ async def test_load_datasets(hass: HomeAssistant) -> None:
for dataset in datasets: for dataset in datasets:
store1.async_add(dataset["source"], dataset["tlv"], None) store1.async_add(dataset["source"], dataset["tlv"], None)
assert len(store1.datasets) == 3 assert len(store1.datasets) == 3
dataset_id = list(store1.datasets.values())[0].id
store1.preferred_dataset = dataset_id
for dataset in store1.datasets.values(): for dataset in store1.datasets.values():
if dataset.source == "Google": if dataset.source == "Google":
@ -575,3 +592,252 @@ async def test_set_preferred_border_agent_id(hass: HomeAssistant) -> None:
hass, "source", DATASET_1_LARGER_TIMESTAMP, preferred_border_agent_id="blah" hass, "source", DATASET_1_LARGER_TIMESTAMP, preferred_border_agent_id="blah"
) )
assert list(store.datasets.values())[1].preferred_border_agent_id == "blah" assert list(store.datasets.values())[1].preferred_border_agent_id == "blah"
async def test_automatically_set_preferred_dataset(
hass: HomeAssistant, mock_async_zeroconf: None
) -> None:
"""Test automatically setting the first dataset as the preferred dataset."""
add_service_listener_called = asyncio.Event()
remove_service_listener_called = asyncio.Event()
async def mock_add_service_listener(type_: str, listener: Any):
add_service_listener_called.set()
async def mock_remove_service_listener(listener: Any):
remove_service_listener_called.set()
mock_async_zeroconf.async_add_service_listener = AsyncMock(
side_effect=mock_add_service_listener
)
mock_async_zeroconf.async_remove_service_listener = AsyncMock(
side_effect=mock_remove_service_listener
)
mock_async_zeroconf.async_get_service_info = AsyncMock()
with patch(
"homeassistant.components.thread.dataset_store.BORDER_AGENT_DISCOVERY_TIMEOUT",
0.1,
):
await dataset_store.async_add_dataset(
hass,
"source",
DATASET_1,
preferred_border_agent_id=TEST_BORDER_AGENT_ID.hex(),
)
# Wait for discovery to start
await add_service_listener_called.wait()
mock_async_zeroconf.async_add_service_listener.assert_called_once_with(
"_meshcop._udp.local.", ANY
)
# Discover a service matching our router
listener: discovery.ThreadRouterDiscovery.ThreadServiceListener = (
mock_async_zeroconf.async_add_service_listener.mock_calls[0][1][1]
)
mock_async_zeroconf.async_get_service_info.return_value = AsyncServiceInfo(
**ROUTER_DISCOVERY_HASS
)
listener.add_service(
None, ROUTER_DISCOVERY_HASS["type_"], ROUTER_DISCOVERY_HASS["name"]
)
# Wait for discovery of other routers to time out and discovery to stop
await remove_service_listener_called.wait()
store = await dataset_store.async_get_store(hass)
assert (
list(store.datasets.values())[0].preferred_border_agent_id
== TEST_BORDER_AGENT_ID.hex()
)
assert await dataset_store.async_get_preferred_dataset(hass) == DATASET_1
async def test_automatically_set_preferred_dataset_own_and_other_router(
hass: HomeAssistant, mock_async_zeroconf: None
) -> None:
"""Test automatically setting the first dataset as the preferred dataset.
In this test case both our own and another router are found.
"""
add_service_listener_called = asyncio.Event()
remove_service_listener_called = asyncio.Event()
async def mock_add_service_listener(type_: str, listener: Any):
add_service_listener_called.set()
async def mock_remove_service_listener(listener: Any):
remove_service_listener_called.set()
mock_async_zeroconf.async_add_service_listener = AsyncMock(
side_effect=mock_add_service_listener
)
mock_async_zeroconf.async_remove_service_listener = AsyncMock(
side_effect=mock_remove_service_listener
)
mock_async_zeroconf.async_get_service_info = AsyncMock()
with patch(
"homeassistant.components.thread.dataset_store.BORDER_AGENT_DISCOVERY_TIMEOUT",
0.1,
):
await dataset_store.async_add_dataset(
hass,
"source",
DATASET_1,
preferred_border_agent_id=TEST_BORDER_AGENT_ID.hex(),
)
# Wait for discovery to start
await add_service_listener_called.wait()
mock_async_zeroconf.async_add_service_listener.assert_called_once_with(
"_meshcop._udp.local.", ANY
)
# Discover a service matching our router
listener: discovery.ThreadRouterDiscovery.ThreadServiceListener = (
mock_async_zeroconf.async_add_service_listener.mock_calls[0][1][1]
)
mock_async_zeroconf.async_get_service_info.return_value = AsyncServiceInfo(
**ROUTER_DISCOVERY_HASS
)
listener.add_service(
None, ROUTER_DISCOVERY_HASS["type_"], ROUTER_DISCOVERY_HASS["name"]
)
# Discover another router
listener: discovery.ThreadRouterDiscovery.ThreadServiceListener = (
mock_async_zeroconf.async_add_service_listener.mock_calls[0][1][1]
)
mock_async_zeroconf.async_get_service_info.return_value = AsyncServiceInfo(
**ROUTER_DISCOVERY_GOOGLE_1
)
listener.add_service(
None, ROUTER_DISCOVERY_GOOGLE_1["type_"], ROUTER_DISCOVERY_GOOGLE_1["name"]
)
# Wait for discovery to stop
await remove_service_listener_called.wait()
store = await dataset_store.async_get_store(hass)
assert (
list(store.datasets.values())[0].preferred_border_agent_id
== TEST_BORDER_AGENT_ID.hex()
)
assert await dataset_store.async_get_preferred_dataset(hass) is None
async def test_automatically_set_preferred_dataset_other_router(
hass: HomeAssistant, mock_async_zeroconf: None
) -> None:
"""Test automatically setting the first dataset as the preferred dataset.
In this test case another router is found.
"""
add_service_listener_called = asyncio.Event()
remove_service_listener_called = asyncio.Event()
async def mock_add_service_listener(type_: str, listener: Any):
add_service_listener_called.set()
async def mock_remove_service_listener(listener: Any):
remove_service_listener_called.set()
mock_async_zeroconf.async_add_service_listener = AsyncMock(
side_effect=mock_add_service_listener
)
mock_async_zeroconf.async_remove_service_listener = AsyncMock(
side_effect=mock_remove_service_listener
)
mock_async_zeroconf.async_get_service_info = AsyncMock()
with patch(
"homeassistant.components.thread.dataset_store.BORDER_AGENT_DISCOVERY_TIMEOUT",
0.1,
):
await dataset_store.async_add_dataset(
hass,
"source",
DATASET_1,
preferred_border_agent_id=TEST_BORDER_AGENT_ID.hex(),
)
# Wait for discovery to start
await add_service_listener_called.wait()
mock_async_zeroconf.async_add_service_listener.assert_called_once_with(
"_meshcop._udp.local.", ANY
)
# Discover another router
listener: discovery.ThreadRouterDiscovery.ThreadServiceListener = (
mock_async_zeroconf.async_add_service_listener.mock_calls[0][1][1]
)
mock_async_zeroconf.async_get_service_info.return_value = AsyncServiceInfo(
**ROUTER_DISCOVERY_GOOGLE_1
)
listener.add_service(
None, ROUTER_DISCOVERY_GOOGLE_1["type_"], ROUTER_DISCOVERY_GOOGLE_1["name"]
)
# Wait for discovery to stop
await remove_service_listener_called.wait()
store = await dataset_store.async_get_store(hass)
assert (
list(store.datasets.values())[0].preferred_border_agent_id
== TEST_BORDER_AGENT_ID.hex()
)
assert await dataset_store.async_get_preferred_dataset(hass) is None
async def test_automatically_set_preferred_dataset_no_router(
hass: HomeAssistant, mock_async_zeroconf: None
) -> None:
"""Test automatically setting the first dataset as the preferred dataset.
In this test case no routers are found.
"""
add_service_listener_called = asyncio.Event()
remove_service_listener_called = asyncio.Event()
async def mock_add_service_listener(type_: str, listener: Any):
add_service_listener_called.set()
async def mock_remove_service_listener(listener: Any):
remove_service_listener_called.set()
mock_async_zeroconf.async_add_service_listener = AsyncMock(
side_effect=mock_add_service_listener
)
mock_async_zeroconf.async_remove_service_listener = AsyncMock(
side_effect=mock_remove_service_listener
)
mock_async_zeroconf.async_get_service_info = AsyncMock()
with patch(
"homeassistant.components.thread.dataset_store.BORDER_AGENT_DISCOVERY_TIMEOUT",
0.1,
):
await dataset_store.async_add_dataset(
hass,
"source",
DATASET_1,
preferred_border_agent_id=TEST_BORDER_AGENT_ID.hex(),
)
# Wait for discovery to start
await add_service_listener_called.wait()
mock_async_zeroconf.async_add_service_listener.assert_called_once_with(
"_meshcop._udp.local.", ANY
)
# Wait for discovery of other routers to time out and discovery to stop
await remove_service_listener_called.wait()
store = await dataset_store.async_get_store(hass)
assert (
list(store.datasets.values())[0].preferred_border_agent_id
== TEST_BORDER_AGENT_ID.hex()
)
assert await dataset_store.async_get_preferred_dataset(hass) is None

View File

@ -86,6 +86,17 @@ async def test_delete_dataset(
assert msg["success"] assert msg["success"]
datasets = msg["result"]["datasets"] datasets = msg["result"]["datasets"]
# Set the first dataset as preferred
await client.send_json_auto_id(
{
"type": "thread/set_preferred_dataset",
"dataset_id": datasets[0]["dataset_id"],
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] is None
# Try deleting the preferred dataset # Try deleting the preferred dataset
await client.send_json_auto_id( await client.send_json_auto_id(
{"type": "thread/delete_dataset", "dataset_id": datasets[0]["dataset_id"]} {"type": "thread/delete_dataset", "dataset_id": datasets[0]["dataset_id"]}
@ -139,6 +150,9 @@ async def test_list_get_dataset(
await dataset_store.async_add_dataset(hass, dataset["source"], dataset["tlv"]) await dataset_store.async_add_dataset(hass, dataset["source"], dataset["tlv"])
store = await dataset_store.async_get_store(hass) store = await dataset_store.async_get_store(hass)
dataset_id = list(store.datasets.values())[0].id
store.preferred_dataset = dataset_id
for dataset in store.datasets.values(): for dataset in store.datasets.values():
if dataset.source == "Google": if dataset.source == "Google":
dataset_1 = dataset dataset_1 = dataset