From 318b8adbed54bfa591337af80042c4cb3f3feb2f Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Mon, 14 Aug 2023 13:40:32 +0200 Subject: [PATCH] Set preferred router when importing OTBR dataset (#98379) --- homeassistant/components/otbr/__init__.py | 21 ++++++++++++++++++++- homeassistant/components/otbr/util.py | 5 +++++ tests/components/otbr/__init__.py | 2 ++ tests/components/otbr/conftest.py | 11 ++++++++++- tests/components/otbr/test_init.py | 15 +++++++++++---- 5 files changed, 48 insertions(+), 6 deletions(-) diff --git a/homeassistant/components/otbr/__init__.py b/homeassistant/components/otbr/__init__.py index 8685282acec..09a4499b60f 100644 --- a/homeassistant/components/otbr/__init__.py +++ b/homeassistant/components/otbr/__init__.py @@ -2,11 +2,17 @@ from __future__ import annotations import asyncio +import contextlib import aiohttp import python_otbr_api -from homeassistant.components.thread import async_add_dataset +from homeassistant.components.thread import ( + async_add_dataset, + async_get_preferred_border_agent_id, + async_get_preferred_dataset, + async_set_preferred_border_agent_id, +) from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError @@ -46,6 +52,19 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: if dataset_tlvs: await update_issues(hass, otbrdata, dataset_tlvs) await async_add_dataset(hass, DOMAIN, dataset_tlvs.hex()) + # If this OTBR's dataset is the preferred one, and there is no preferred router, + # make this the preferred router + border_agent_id: bytes | None = None + with contextlib.suppress( + HomeAssistantError, aiohttp.ClientError, asyncio.TimeoutError + ): + border_agent_id = await otbrdata.get_border_agent_id() + if ( + await async_get_preferred_dataset(hass) == dataset_tlvs.hex() + and await async_get_preferred_border_agent_id(hass) is None + and border_agent_id + ): + await async_set_preferred_border_agent_id(hass, border_agent_id.hex()) entry.async_on_unload(entry.add_update_listener(async_reload_entry)) diff --git a/homeassistant/components/otbr/util.py b/homeassistant/components/otbr/util.py index 4d6efb9a9f0..67f36c09246 100644 --- a/homeassistant/components/otbr/util.py +++ b/homeassistant/components/otbr/util.py @@ -82,6 +82,11 @@ class OTBRData: ) await self.delete_active_dataset() + @_handle_otbr_error + async def get_border_agent_id(self) -> bytes: + """Get the border agent ID.""" + return await self.api.get_border_agent_id() + @_handle_otbr_error async def set_enabled(self, enabled: bool) -> None: """Enable or disable the router.""" diff --git a/tests/components/otbr/__init__.py b/tests/components/otbr/__init__.py index 9f2fd4a4355..a30275d3569 100644 --- a/tests/components/otbr/__init__.py +++ b/tests/components/otbr/__init__.py @@ -26,3 +26,5 @@ DATASET_INSECURE_PASSPHRASE = bytes.fromhex( "0A336069051000112233445566778899AABBCCDDEEFA030E4F70656E54687265616444656D6F01" "0212340410445F2B5CA6F2A93A55CE570A70EFEECB0C0402A0F7F8" ) + +TEST_BORDER_AGENT_ID = bytes.fromhex("230C6A1AC57F6F4BE262ACF32E5EF52C") diff --git a/tests/components/otbr/conftest.py b/tests/components/otbr/conftest.py index e7d5ac8980e..75922e99aa0 100644 --- a/tests/components/otbr/conftest.py +++ b/tests/components/otbr/conftest.py @@ -6,7 +6,12 @@ import pytest from homeassistant.components import otbr from homeassistant.core import HomeAssistant -from . import CONFIG_ENTRY_DATA_MULTIPAN, CONFIG_ENTRY_DATA_THREAD, DATASET_CH16 +from . import ( + CONFIG_ENTRY_DATA_MULTIPAN, + CONFIG_ENTRY_DATA_THREAD, + DATASET_CH16, + TEST_BORDER_AGENT_ID, +) from tests.common import MockConfigEntry @@ -23,6 +28,8 @@ async def otbr_config_entry_multipan_fixture(hass): config_entry.add_to_hass(hass) with patch( "python_otbr_api.OTBR.get_active_dataset_tlvs", return_value=DATASET_CH16 + ), patch( + "python_otbr_api.OTBR.get_border_agent_id", return_value=TEST_BORDER_AGENT_ID ), patch( "homeassistant.components.otbr.util.compute_pskc" ): # Patch to speed up tests @@ -41,6 +48,8 @@ async def otbr_config_entry_thread_fixture(hass): config_entry.add_to_hass(hass) with patch( "python_otbr_api.OTBR.get_active_dataset_tlvs", return_value=DATASET_CH16 + ), patch( + "python_otbr_api.OTBR.get_border_agent_id", return_value=TEST_BORDER_AGENT_ID ), patch( "homeassistant.components.otbr.util.compute_pskc" ): # Patch to speed up tests diff --git a/tests/components/otbr/test_init.py b/tests/components/otbr/test_init.py index 49694cf5585..63229f4b2e7 100644 --- a/tests/components/otbr/test_init.py +++ b/tests/components/otbr/test_init.py @@ -7,7 +7,7 @@ import aiohttp import pytest import python_otbr_api -from homeassistant.components import otbr +from homeassistant.components import otbr, thread from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import issue_registry as ir @@ -21,6 +21,7 @@ from . import ( DATASET_CH16, DATASET_INSECURE_NW_KEY, DATASET_INSECURE_PASSPHRASE, + TEST_BORDER_AGENT_ID, ) from tests.common import MockConfigEntry @@ -36,6 +37,8 @@ DATASET_NO_CHANNEL = bytes.fromhex( async def test_import_dataset(hass: HomeAssistant) -> None: """Test the active dataset is imported at setup.""" issue_registry = ir.async_get(hass) + assert await thread.async_get_preferred_border_agent_id(hass) is None + assert await thread.async_get_preferred_dataset(hass) is None config_entry = MockConfigEntry( data=CONFIG_ENTRY_DATA_MULTIPAN, @@ -47,11 +50,15 @@ async def test_import_dataset(hass: HomeAssistant) -> None: with patch( "python_otbr_api.OTBR.get_active_dataset_tlvs", return_value=DATASET_CH16 ), patch( - "homeassistant.components.thread.dataset_store.DatasetStore.async_add" - ) as mock_add: + "python_otbr_api.OTBR.get_border_agent_id", return_value=TEST_BORDER_AGENT_ID + ): assert await hass.config_entries.async_setup(config_entry.entry_id) - mock_add.assert_called_once_with(otbr.DOMAIN, DATASET_CH16.hex()) + assert ( + await thread.async_get_preferred_border_agent_id(hass) + == TEST_BORDER_AGENT_ID.hex() + ) + assert await thread.async_get_preferred_dataset(hass) == DATASET_CH16.hex() assert not issue_registry.async_get_issue( domain=otbr.DOMAIN, issue_id=f"insecure_thread_network_{config_entry.entry_id}" )