diff --git a/homeassistant/components/otbr/__init__.py b/homeassistant/components/otbr/__init__.py index 09a4499b60f..ac59bacbd97 100644 --- a/homeassistant/components/otbr/__init__.py +++ b/homeassistant/components/otbr/__init__.py @@ -7,12 +7,7 @@ import contextlib import aiohttp import python_otbr_api -from homeassistant.components.thread import ( - async_add_dataset, - async_get_preferred_border_agent_id, - async_get_preferred_dataset, - async_set_preferred_border_agent_id, -) +from homeassistant.components.thread import async_add_dataset from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError @@ -50,21 +45,20 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ) as err: raise ConfigEntryNotReady("Unable to connect") from err if dataset_tlvs: - await update_issues(hass, otbrdata, dataset_tlvs) - await async_add_dataset(hass, DOMAIN, dataset_tlvs.hex()) - # If this OTBR's dataset is the preferred one, and there is no preferred router, - # make this the preferred router - border_agent_id: bytes | None = None + border_agent_id: str | None = None with contextlib.suppress( HomeAssistantError, aiohttp.ClientError, asyncio.TimeoutError ): - border_agent_id = await otbrdata.get_border_agent_id() - if ( - await async_get_preferred_dataset(hass) == dataset_tlvs.hex() - and await async_get_preferred_border_agent_id(hass) is None - and border_agent_id - ): - await async_set_preferred_border_agent_id(hass, border_agent_id.hex()) + border_agent_bytes = await otbrdata.get_border_agent_id() + if border_agent_bytes: + border_agent_id = border_agent_bytes.hex() + await update_issues(hass, otbrdata, dataset_tlvs) + await async_add_dataset( + hass, + DOMAIN, + dataset_tlvs.hex(), + preferred_border_agent_id=border_agent_id, + ) entry.async_on_unload(entry.add_update_listener(async_reload_entry)) diff --git a/homeassistant/components/thread/__init__.py b/homeassistant/components/thread/__init__.py index 679127e5202..dd2527763ad 100644 --- a/homeassistant/components/thread/__init__.py +++ b/homeassistant/components/thread/__init__.py @@ -11,9 +11,7 @@ from .dataset_store import ( DatasetEntry, async_add_dataset, async_get_dataset, - async_get_preferred_border_agent_id, async_get_preferred_dataset, - async_set_preferred_border_agent_id, ) from .websocket_api import async_setup as async_setup_ws_api @@ -21,10 +19,8 @@ __all__ = [ "DOMAIN", "DatasetEntry", "async_add_dataset", - "async_get_preferred_border_agent_id", "async_get_dataset", "async_get_preferred_dataset", - "async_set_preferred_border_agent_id", ] CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN) diff --git a/homeassistant/components/thread/dataset_store.py b/homeassistant/components/thread/dataset_store.py index 96a9cf8e59e..22e2c1822c1 100644 --- a/homeassistant/components/thread/dataset_store.py +++ b/homeassistant/components/thread/dataset_store.py @@ -33,6 +33,7 @@ class DatasetPreferredError(HomeAssistantError): class DatasetEntry: """Dataset store entry.""" + preferred_border_agent_id: str | None source: str tlv: str @@ -73,6 +74,7 @@ class DatasetEntry: return { "created": self.created.isoformat(), "id": self.id, + "preferred_border_agent_id": self.preferred_border_agent_id, "source": self.source, "tlv": self.tlv, } @@ -97,6 +99,7 @@ class DatasetStoreStore(Store): entry = DatasetEntry( created=created, id=dataset["id"], + preferred_border_agent_id=None, source=dataset["source"], tlv=dataset["tlv"], ) @@ -160,7 +163,8 @@ class DatasetStoreStore(Store): } if old_minor_version < 3: # Add border agent ID - data.setdefault("preferred_border_agent_id", None) + for dataset in data["datasets"]: + dataset.setdefault("preferred_border_agent_id", None) return data @@ -172,7 +176,6 @@ class DatasetStore: """Initialize the dataset store.""" self.hass = hass self.datasets: dict[str, DatasetEntry] = {} - self._preferred_border_agent_id: str | None = None self._preferred_dataset: str | None = None self._store: Store[dict[str, Any]] = DatasetStoreStore( hass, @@ -183,7 +186,9 @@ class DatasetStore: ) @callback - def async_add(self, source: str, tlv: str) -> None: + def async_add( + self, source: str, tlv: str, preferred_border_agent_id: str | None + ) -> None: """Add dataset, does nothing if it already exists.""" # Make sure the tlv is valid dataset = tlv_parser.parse_tlv(tlv) @@ -245,7 +250,9 @@ class DatasetStore: self.async_schedule_save() return - entry = DatasetEntry(source=source, tlv=tlv) + entry = DatasetEntry( + preferred_border_agent_id=preferred_border_agent_id, source=source, tlv=tlv + ) self.datasets[entry.id] = entry # Set to preferred if there is no preferred dataset if self._preferred_dataset is None: @@ -266,14 +273,13 @@ class DatasetStore: return self.datasets.get(dataset_id) @callback - def async_get_preferred_border_agent_id(self) -> str | None: - """Get preferred border agent id.""" - return self._preferred_border_agent_id - - @callback - def async_set_preferred_border_agent_id(self, border_agent_id: str) -> None: - """Set preferred border agent id.""" - self._preferred_border_agent_id = border_agent_id + def async_set_preferred_border_agent_id( + self, dataset_id: str, border_agent_id: str + ) -> None: + """Set preferred border agent id of a dataset.""" + self.datasets[dataset_id] = dataclasses.replace( + self.datasets[dataset_id], preferred_border_agent_id=border_agent_id + ) self.async_schedule_save() @property @@ -296,7 +302,6 @@ class DatasetStore: data = await self._store.async_load() datasets: dict[str, DatasetEntry] = {} - preferred_border_agent_id: str | None = None preferred_dataset: str | None = None if data is not None: @@ -305,14 +310,13 @@ class DatasetStore: datasets[dataset["id"]] = DatasetEntry( created=created, id=dataset["id"], + preferred_border_agent_id=dataset["preferred_border_agent_id"], source=dataset["source"], tlv=dataset["tlv"], ) - preferred_border_agent_id = data["preferred_border_agent_id"] preferred_dataset = data["preferred_dataset"] self.datasets = datasets - self._preferred_border_agent_id = preferred_border_agent_id self._preferred_dataset = preferred_dataset @callback @@ -325,7 +329,6 @@ class DatasetStore: """Return data of datasets to store in a file.""" data: dict[str, Any] = {} data["datasets"] = [dataset.to_json() for dataset in self.datasets.values()] - data["preferred_border_agent_id"] = self._preferred_border_agent_id data["preferred_dataset"] = self._preferred_dataset return data @@ -338,10 +341,16 @@ async def async_get_store(hass: HomeAssistant) -> DatasetStore: return store -async def async_add_dataset(hass: HomeAssistant, source: str, tlv: str) -> None: +async def async_add_dataset( + hass: HomeAssistant, + source: str, + tlv: str, + *, + preferred_border_agent_id: str | None = None, +) -> None: """Add a dataset.""" store = await async_get_store(hass) - store.async_add(source, tlv) + store.async_add(source, tlv, preferred_border_agent_id) async def async_get_dataset(hass: HomeAssistant, dataset_id: str) -> str | None: @@ -352,20 +361,6 @@ async def async_get_dataset(hass: HomeAssistant, dataset_id: str) -> str | None: return entry.tlv -async def async_get_preferred_border_agent_id(hass: HomeAssistant) -> str | None: - """Get the preferred border agent ID.""" - store = await async_get_store(hass) - return store.async_get_preferred_border_agent_id() - - -async def async_set_preferred_border_agent_id( - hass: HomeAssistant, border_agent_id: str -) -> None: - """Get the preferred border agent ID.""" - store = await async_get_store(hass) - store.async_set_preferred_border_agent_id(border_agent_id) - - async def async_get_preferred_dataset(hass: HomeAssistant) -> str | None: """Get the preferred dataset.""" store = await async_get_store(hass) diff --git a/homeassistant/components/thread/websocket_api.py b/homeassistant/components/thread/websocket_api.py index 853d8c3c893..5b289cf1694 100644 --- a/homeassistant/components/thread/websocket_api.py +++ b/homeassistant/components/thread/websocket_api.py @@ -20,7 +20,6 @@ def async_setup(hass: HomeAssistant) -> None: websocket_api.async_register_command(hass, ws_discover_routers) websocket_api.async_register_command(hass, ws_get_dataset) websocket_api.async_register_command(hass, ws_list_datasets) - websocket_api.async_register_command(hass, ws_get_preferred_border_agent_id) websocket_api.async_register_command(hass, ws_set_preferred_border_agent_id) websocket_api.async_register_command(hass, ws_set_preferred_dataset) @@ -52,25 +51,11 @@ async def ws_add_dataset( connection.send_result(msg["id"]) -@websocket_api.require_admin -@websocket_api.websocket_command( - { - vol.Required("type"): "thread/get_preferred_border_agent_id", - } -) -@websocket_api.async_response -async def ws_get_preferred_border_agent_id( - hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any] -) -> None: - """Get the preferred border agent ID.""" - border_agent_id = await dataset_store.async_get_preferred_border_agent_id(hass) - connection.send_result(msg["id"], {"border_agent_id": border_agent_id}) - - @websocket_api.require_admin @websocket_api.websocket_command( { vol.Required("type"): "thread/set_preferred_border_agent_id", + vol.Required("dataset_id"): str, vol.Required("border_agent_id"): str, } ) @@ -79,8 +64,10 @@ async def ws_set_preferred_border_agent_id( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any] ) -> None: """Set the preferred border agent ID.""" + dataset_id = msg["dataset_id"] border_agent_id = msg["border_agent_id"] - await dataset_store.async_set_preferred_border_agent_id(hass, border_agent_id) + store = await dataset_store.async_get_store(hass) + store.async_set_preferred_border_agent_id(dataset_id, border_agent_id) connection.send_result(msg["id"]) @@ -186,6 +173,7 @@ async def ws_list_datasets( "network_name": dataset.network_name, "pan_id": dataset.pan_id, "preferred": dataset.id == preferred_dataset, + "preferred_border_agent_id": dataset.preferred_border_agent_id, "source": dataset.source, } ) diff --git a/tests/components/otbr/test_init.py b/tests/components/otbr/test_init.py index 63229f4b2e7..18a60cfa196 100644 --- a/tests/components/otbr/test_init.py +++ b/tests/components/otbr/test_init.py @@ -37,7 +37,6 @@ DATASET_NO_CHANNEL = bytes.fromhex( async def test_import_dataset(hass: HomeAssistant) -> None: """Test the active dataset is imported at setup.""" issue_registry = ir.async_get(hass) - assert await thread.async_get_preferred_border_agent_id(hass) is None assert await thread.async_get_preferred_dataset(hass) is None config_entry = MockConfigEntry( @@ -54,8 +53,9 @@ async def test_import_dataset(hass: HomeAssistant) -> None: ): assert await hass.config_entries.async_setup(config_entry.entry_id) + dataset_store = await thread.dataset_store.async_get_store(hass) assert ( - await thread.async_get_preferred_border_agent_id(hass) + list(dataset_store.datasets.values())[0].preferred_border_agent_id == TEST_BORDER_AGENT_ID.hex() ) assert await thread.async_get_preferred_dataset(hass) == DATASET_CH16.hex() @@ -94,7 +94,7 @@ async def test_import_share_radio_channel_collision( ) as mock_add: assert await hass.config_entries.async_setup(config_entry.entry_id) - mock_add.assert_called_once_with(otbr.DOMAIN, DATASET_CH16.hex()) + mock_add.assert_called_once_with(otbr.DOMAIN, DATASET_CH16.hex(), None) assert issue_registry.async_get_issue( domain=otbr.DOMAIN, issue_id=f"otbr_zha_channel_collision_{config_entry.entry_id}", @@ -127,7 +127,7 @@ async def test_import_share_radio_no_channel_collision( ) as mock_add: assert await hass.config_entries.async_setup(config_entry.entry_id) - mock_add.assert_called_once_with(otbr.DOMAIN, dataset.hex()) + mock_add.assert_called_once_with(otbr.DOMAIN, dataset.hex(), None) assert not issue_registry.async_get_issue( domain=otbr.DOMAIN, issue_id=f"otbr_zha_channel_collision_{config_entry.entry_id}", @@ -158,7 +158,7 @@ async def test_import_insecure_dataset(hass: HomeAssistant, dataset: bytes) -> N ) as mock_add: assert await hass.config_entries.async_setup(config_entry.entry_id) - mock_add.assert_called_once_with(otbr.DOMAIN, dataset.hex()) + mock_add.assert_called_once_with(otbr.DOMAIN, dataset.hex(), None) assert issue_registry.async_get_issue( domain=otbr.DOMAIN, issue_id=f"insecure_thread_network_{config_entry.entry_id}" ) diff --git a/tests/components/otbr/test_websocket_api.py b/tests/components/otbr/test_websocket_api.py index d62213ce78b..f149e89cc45 100644 --- a/tests/components/otbr/test_websocket_api.py +++ b/tests/components/otbr/test_websocket_api.py @@ -109,7 +109,7 @@ async def test_create_network( assert set_enabled_mock.mock_calls[0][1][0] is False assert set_enabled_mock.mock_calls[1][1][0] is True get_active_dataset_tlvs_mock.assert_called_once() - mock_add.assert_called_once_with(otbr.DOMAIN, DATASET_CH16.hex()) + mock_add.assert_called_once_with(otbr.DOMAIN, DATASET_CH16.hex(), None) async def test_create_network_no_entry( diff --git a/tests/components/thread/test_dataset_store.py b/tests/components/thread/test_dataset_store.py index 1171c597e99..77102f92019 100644 --- a/tests/components/thread/test_dataset_store.py +++ b/tests/components/thread/test_dataset_store.py @@ -254,7 +254,7 @@ async def test_load_datasets(hass: HomeAssistant) -> None: store1 = await dataset_store.async_get_store(hass) for dataset in datasets: - store1.async_add(dataset["source"], dataset["tlv"]) + store1.async_add(dataset["source"], dataset["tlv"], None) assert len(store1.datasets) == 3 for dataset in store1.datasets.values(): @@ -303,33 +303,31 @@ async def test_loading_datasets_from_storage( { "created": "2023-02-02T09:41:13.746514+00:00", "id": "id1", + "preferred_border_agent_id": "230C6A1AC57F6F4BE262ACF32E5EF52C", "source": "source_1", "tlv": DATASET_1, }, { "created": "2023-02-02T09:41:13.746514+00:00", "id": "id2", + "preferred_border_agent_id": None, "source": "source_2", "tlv": DATASET_2, }, { "created": "2023-02-02T09:41:13.746514+00:00", "id": "id3", + "preferred_border_agent_id": None, "source": "source_3", "tlv": DATASET_3, }, ], - "preferred_border_agent_id": "230C6A1AC57F6F4BE262ACF32E5EF52C", "preferred_dataset": "id1", }, } store = await dataset_store.async_get_store(hass) assert len(store.datasets) == 3 - assert ( - store.async_get_preferred_border_agent_id() - == "230C6A1AC57F6F4BE262ACF32E5EF52C" - ) assert store.preferred_dataset == "id1" @@ -540,11 +538,17 @@ async def test_migrate_set_default_border_agent_id( } store = await dataset_store.async_get_store(hass) - assert store.async_get_preferred_border_agent_id() is None + assert store.datasets[store._preferred_dataset].preferred_border_agent_id is None -async def test_preferred_border_agent_id(hass: HomeAssistant) -> None: - """Test get and set the preferred border agent ID.""" - assert await dataset_store.async_get_preferred_border_agent_id(hass) is None - await dataset_store.async_set_preferred_border_agent_id(hass, "blah") - assert await dataset_store.async_get_preferred_border_agent_id(hass) == "blah" +async def test_set_preferred_border_agent_id(hass: HomeAssistant) -> None: + """Test set the preferred border agent ID of a dataset.""" + assert await dataset_store.async_get_preferred_dataset(hass) is None + + await dataset_store.async_add_dataset( + hass, "source", DATASET_1, preferred_border_agent_id="blah" + ) + + store = await dataset_store.async_get_store(hass) + assert len(store.datasets) == 1 + assert list(store.datasets.values())[0].preferred_border_agent_id == "blah" diff --git a/tests/components/thread/test_websocket_api.py b/tests/components/thread/test_websocket_api.py index 82450474e92..bfe71b8b21c 100644 --- a/tests/components/thread/test_websocket_api.py +++ b/tests/components/thread/test_websocket_api.py @@ -160,6 +160,7 @@ async def test_list_get_dataset( "network_name": "OpenThreadDemo", "pan_id": "1234", "preferred": True, + "preferred_border_agent_id": None, "source": "Google", }, { @@ -170,6 +171,7 @@ async def test_list_get_dataset( "network_name": "HomeAssistant!", "pan_id": "1234", "preferred": False, + "preferred_border_agent_id": None, "source": "Multipan", }, { @@ -180,6 +182,7 @@ async def test_list_get_dataset( "network_name": "~🐣🐥🐤~", "pan_id": "1234", "preferred": False, + "preferred_border_agent_id": None, "source": "🎅", }, ] @@ -200,33 +203,45 @@ async def test_list_get_dataset( assert msg["error"] == {"code": "not_found", "message": "unknown dataset"} -async def test_preferred_border_agent_id( +async def test_set_preferred_border_agent_id( hass: HomeAssistant, hass_ws_client: WebSocketGenerator ) -> None: - """Test setting and getting the preferred border agent ID.""" + """Test setting the preferred border agent ID.""" assert await async_setup_component(hass, DOMAIN, {}) await hass.async_block_till_done() client = await hass_ws_client(hass) - await client.send_json_auto_id({"type": "thread/get_preferred_border_agent_id"}) - msg = await client.receive_json() - assert msg["success"] - assert msg["result"] == {"border_agent_id": None} - await client.send_json_auto_id( - {"type": "thread/set_preferred_border_agent_id", "border_agent_id": "blah"} + {"type": "thread/add_dataset_tlv", "source": "test", "tlv": DATASET_1} ) msg = await client.receive_json() assert msg["success"] assert msg["result"] is None - await client.send_json_auto_id({"type": "thread/get_preferred_border_agent_id"}) + await client.send_json_auto_id({"type": "thread/list_datasets"}) msg = await client.receive_json() assert msg["success"] - assert msg["result"] == {"border_agent_id": "blah"} + datasets = msg["result"]["datasets"] + dataset_id = datasets[0]["dataset_id"] + assert datasets[0]["preferred_border_agent_id"] is None - assert await dataset_store.async_get_preferred_border_agent_id(hass) == "blah" + await client.send_json_auto_id( + { + "type": "thread/set_preferred_border_agent_id", + "dataset_id": dataset_id, + "border_agent_id": "blah", + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] is None + + await client.send_json_auto_id({"type": "thread/list_datasets"}) + msg = await client.receive_json() + assert msg["success"] + datasets = msg["result"]["datasets"] + assert datasets[0]["preferred_border_agent_id"] == "blah" async def test_set_preferred_dataset(