diff --git a/homeassistant/components/thread/diagnostics.py b/homeassistant/components/thread/diagnostics.py index b945f818d00..eb1e2a5ef68 100644 --- a/homeassistant/components/thread/diagnostics.py +++ b/homeassistant/components/thread/diagnostics.py @@ -17,9 +17,8 @@ some of their thread accessories can't be pinged, but it's still a thread proble from __future__ import annotations -from typing import Any, TypedDict +from typing import TYPE_CHECKING, Any, TypedDict -from pyroute2 import NDB # pylint: disable=no-name-in-module from python_otbr_api.tlv_parser import MeshcopTLVType from homeassistant.components import zeroconf @@ -29,6 +28,9 @@ from homeassistant.core import HomeAssistant from .dataset_store import async_get_store from .discovery import async_read_zeroconf_cache +if TYPE_CHECKING: + from pyroute2 import NDB # pylint: disable=no-name-in-module + class Neighbour(TypedDict): """A neighbour cache entry (ip neigh).""" @@ -67,58 +69,69 @@ class Network(TypedDict): unexpected_routers: set[str] -def _get_possible_thread_routes() -> ( - tuple[dict[str, dict[str, Route]], dict[str, set[str]]] -): +def _get_possible_thread_routes( + ndb: NDB, +) -> tuple[dict[str, dict[str, Route]], dict[str, set[str]]]: # Build a list of possible thread routes # Right now, this is ipv6 /64's that have a gateway # We cross reference with zerconf data to confirm which via's are known border routers routes: dict[str, dict[str, Route]] = {} reverse_routes: dict[str, set[str]] = {} - with NDB() as ndb: - for record in ndb.routes: - # Limit to IPV6 routes - if record.family != 10: - continue - # Limit to /64 prefixes - if record.dst_len != 64: - continue - # Limit to routes with a via - if not record.gateway and not record.nh_gateway: - continue - gateway = record.gateway or record.nh_gateway - route = routes.setdefault(gateway, {}) - route[record.dst] = { - "metrics": record.metrics, - "priority": record.priority, - # NM creates "nexthop" routes - a single route with many via's - # Kernel creates many routes with a single via - "is_nexthop": record.nh_gateway is not None, - } - reverse_routes.setdefault(record.dst, set()).add(gateway) + for record in ndb.routes: + # Limit to IPV6 routes + if record.family != 10: + continue + # Limit to /64 prefixes + if record.dst_len != 64: + continue + # Limit to routes with a via + if not record.gateway and not record.nh_gateway: + continue + gateway = record.gateway or record.nh_gateway + route = routes.setdefault(gateway, {}) + route[record.dst] = { + "metrics": record.metrics, + "priority": record.priority, + # NM creates "nexthop" routes - a single route with many via's + # Kernel creates many routes with a single via + "is_nexthop": record.nh_gateway is not None, + } + reverse_routes.setdefault(record.dst, set()).add(gateway) return routes, reverse_routes -def _get_neighbours() -> dict[str, Neighbour]: - neighbours: dict[str, Neighbour] = {} - - with NDB() as ndb: - for record in ndb.neighbours: - neighbours[record.dst] = { - "lladdr": record.lladdr, - "state": record.state, - "probes": record.probes, - } - +def _get_neighbours(ndb: NDB) -> dict[str, Neighbour]: + # Build a list of neighbours + neighbours: dict[str, Neighbour] = { + record.dst: { + "lladdr": record.lladdr, + "state": record.state, + "probes": record.probes, + } + for record in ndb.neighbours + } return neighbours +def _get_routes_and_neighbors(): + """Get the routes and neighbours from pyroute2.""" + # Import in the executor since import NDB can take a while + from pyroute2 import ( # pylint: disable=no-name-in-module, import-outside-toplevel + NDB, + ) + + with NDB() as ndb: # pylint: disable=not-callable + routes, reverse_routes = _get_possible_thread_routes(ndb) + neighbours = _get_neighbours(ndb) + + return routes, reverse_routes, neighbours + + async def async_get_config_entry_diagnostics( hass: HomeAssistant, entry: ConfigEntry ) -> dict[str, Any]: """Return diagnostics for all known thread networks.""" - networks: dict[str, Network] = {} # Start with all networks that HA knows about @@ -140,13 +153,12 @@ async def async_get_config_entry_diagnostics( # Find all routes currently act that might be thread related, so we can match them to # border routers as we process the zeroconf data. - routes, reverse_routes = await hass.async_add_executor_job( - _get_possible_thread_routes + # + # Also find all neighbours + routes, reverse_routes, neighbours = await hass.async_add_executor_job( + _get_routes_and_neighbors ) - # Find all neighbours - neighbours = await hass.async_add_executor_job(_get_neighbours) - aiozc = await zeroconf.async_get_async_instance(hass) for data in async_read_zeroconf_cache(aiozc): if not data.extended_pan_id: diff --git a/tests/components/thread/test_diagnostics.py b/tests/components/thread/test_diagnostics.py index 1006fa374c3..a551315205b 100644 --- a/tests/components/thread/test_diagnostics.py +++ b/tests/components/thread/test_diagnostics.py @@ -133,9 +133,7 @@ class MockNeighbour: @pytest.fixture def ndb() -> Mock: """Prevent NDB poking the OS route tables.""" - with patch( - "homeassistant.components.thread.diagnostics.NDB" - ) as ndb, ndb() as instance: + with patch("pyroute2.NDB") as ndb, ndb() as instance: instance.neighbours = [] instance.routes = [] yield instance