Ensure dhcp can still discover new devices from device trackers (#66822)

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
J. Nick Koston 2022-02-19 09:01:34 -06:00 committed by GitHub
parent c46728c2b2
commit a18d4c51ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 232 additions and 12 deletions

View File

@ -16,12 +16,21 @@ from homeassistant.const import (
)
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.entity import DeviceInfo, Entity, EntityCategory
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.entity_platform import EntityPlatform
from homeassistant.helpers.typing import StateType
from .const import ATTR_HOST_NAME, ATTR_IP, ATTR_MAC, ATTR_SOURCE_TYPE, DOMAIN, LOGGER
from .const import (
ATTR_HOST_NAME,
ATTR_IP,
ATTR_MAC,
ATTR_SOURCE_TYPE,
CONNECTED_DEVICE_REGISTERED,
DOMAIN,
LOGGER,
)
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
@ -64,9 +73,33 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return await component.async_unload_entry(entry)
@callback
def _async_connected_device_registered(
hass: HomeAssistant, mac: str, ip_address: str | None, hostname: str | None
) -> None:
"""Register a newly seen connected device.
This is currently used by the dhcp integration
to listen for newly registered connected devices
for discovery.
"""
async_dispatcher_send(
hass,
CONNECTED_DEVICE_REGISTERED,
{
ATTR_IP: ip_address,
ATTR_MAC: mac,
ATTR_HOST_NAME: hostname,
},
)
@callback
def _async_register_mac(
hass: HomeAssistant, domain: str, mac: str, unique_id: str
hass: HomeAssistant,
domain: str,
mac: str,
unique_id: str,
) -> None:
"""Register a mac address with a unique ID."""
data_key = "device_tracker_mac"
@ -297,8 +330,18 @@ class ScannerEntity(BaseTrackerEntity):
super().add_to_platform_start(hass, platform, parallel_updates)
if self.mac_address and self.unique_id:
_async_register_mac(
hass, platform.platform_name, self.mac_address, self.unique_id
hass,
platform.platform_name,
self.mac_address,
self.unique_id,
)
if self.is_connected:
_async_connected_device_registered(
hass,
self.mac_address,
self.ip_address,
self.hostname,
)
@callback
def find_device_entry(self) -> dr.DeviceEntry | None:

View File

@ -37,3 +37,5 @@ ATTR_MAC: Final = "mac"
ATTR_SOURCE_TYPE: Final = "source_type"
ATTR_CONSIDER_HOME: Final = "consider_home"
ATTR_IP: Final = "ip"
CONNECTED_DEVICE_REGISTERED: Final = "device_tracker_connected_device_registered"

View File

@ -1,6 +1,7 @@
"""The dhcp integration."""
from __future__ import annotations
from abc import abstractmethod
from dataclasses import dataclass
from datetime import timedelta
import fnmatch
@ -25,6 +26,7 @@ from homeassistant.components.device_tracker.const import (
ATTR_IP,
ATTR_MAC,
ATTR_SOURCE_TYPE,
CONNECTED_DEVICE_REGISTERED,
DOMAIN as DEVICE_TRACKER_DOMAIN,
SOURCE_TYPE_ROUTER,
)
@ -42,6 +44,7 @@ from homeassistant.helpers.device_registry import (
async_get,
format_mac,
)
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.event import (
async_track_state_added_domain,
async_track_time_interval,
@ -109,16 +112,23 @@ class DhcpServiceInfo(BaseServiceInfo):
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the dhcp component."""
watchers: list[WatcherBase] = []
address_data: dict[str, dict[str, str]] = {}
integration_matchers = await async_get_dhcp(hass)
# For the passive classes we need to start listening
# for state changes and connect the dispatchers before
# everything else starts up or we will miss events
for passive_cls in (DeviceTrackerRegisteredWatcher, DeviceTrackerWatcher):
passive_watcher = passive_cls(hass, address_data, integration_matchers)
await passive_watcher.async_start()
watchers.append(passive_watcher)
async def _initialize(_):
address_data = {}
integration_matchers = await async_get_dhcp(hass)
watchers = []
for cls in (DHCPWatcher, DeviceTrackerWatcher, NetworkWatcher):
watcher = cls(hass, address_data, integration_matchers)
await watcher.async_start()
watchers.append(watcher)
for active_cls in (DHCPWatcher, NetworkWatcher):
active_watcher = active_cls(hass, address_data, integration_matchers)
await active_watcher.async_start()
watchers.append(active_watcher)
async def _async_stop(*_):
for watcher in watchers:
@ -141,6 +151,14 @@ class WatcherBase:
self._integration_matchers = integration_matchers
self._address_data = address_data
@abstractmethod
async def async_stop(self):
"""Stop the watcher."""
@abstractmethod
async def async_start(self):
"""Start the watcher."""
def process_client(self, ip_address, hostname, mac_address):
"""Process a client."""
return run_callback_threadsafe(
@ -320,6 +338,39 @@ class DeviceTrackerWatcher(WatcherBase):
self.async_process_client(ip_address, hostname, _format_mac(mac_address))
class DeviceTrackerRegisteredWatcher(WatcherBase):
"""Class to watch data from device tracker registrations."""
def __init__(self, hass, address_data, integration_matchers):
"""Initialize class."""
super().__init__(hass, address_data, integration_matchers)
self._unsub = None
async def async_stop(self):
"""Stop watching for device tracker registrations."""
if self._unsub:
self._unsub()
self._unsub = None
async def async_start(self):
"""Stop watching for device tracker registrations."""
self._unsub = async_dispatcher_connect(
self.hass, CONNECTED_DEVICE_REGISTERED, self._async_process_device_state
)
@callback
def _async_process_device_state(self, data: dict[str, Any]) -> None:
"""Process a device tracker state."""
ip_address = data.get(ATTR_IP)
hostname = data.get(ATTR_HOST_NAME, "")
mac_address = data.get(ATTR_MAC)
if ip_address is None or mac_address is None:
return
self.async_process_client(ip_address, hostname, _format_mac(mac_address))
class DHCPWatcher(WatcherBase):
"""Class to watch dhcp requests."""

View File

@ -1,8 +1,15 @@
"""Test Device Tracker config entry things."""
from homeassistant.components.device_tracker import DOMAIN, config_entry as ce
from homeassistant.core import callback
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from tests.common import MockConfigEntry
from tests.common import (
MockConfigEntry,
MockEntityPlatform,
MockPlatform,
mock_registry,
)
def test_tracker_entity():
@ -128,3 +135,87 @@ async def test_register_mac(hass):
entity_entry_1 = ent_reg.async_get(entity_entry_1.entity_id)
assert entity_entry_1.disabled_by is None
async def test_connected_device_registered(hass):
"""Test dispatch on connected device being registered."""
registry = mock_registry(hass)
dispatches = []
@callback
def _save_dispatch(msg):
dispatches.append(msg)
unsub = async_dispatcher_connect(
hass, ce.CONNECTED_DEVICE_REGISTERED, _save_dispatch
)
class MockScannerEntity(ce.ScannerEntity):
"""Mock a scanner entity."""
@property
def ip_address(self) -> str:
return "5.4.3.2"
@property
def unique_id(self) -> str:
return self.mac_address
class MockDisconnectedScannerEntity(MockScannerEntity):
"""Mock a disconnected scanner entity."""
@property
def mac_address(self) -> str:
return "aa:bb:cc:dd:ee:ff"
@property
def is_connected(self) -> bool:
return True
@property
def hostname(self) -> str:
return "connected"
class MockConnectedScannerEntity(MockScannerEntity):
"""Mock a disconnected scanner entity."""
@property
def mac_address(self) -> str:
return "aa:bb:cc:dd:ee:00"
@property
def is_connected(self) -> bool:
return False
@property
def hostname(self) -> str:
return "disconnected"
async def async_setup_entry(hass, config_entry, async_add_entities):
"""Mock setup entry method."""
async_add_entities(
[MockConnectedScannerEntity(), MockDisconnectedScannerEntity()]
)
return True
platform = MockPlatform(async_setup_entry=async_setup_entry)
config_entry = MockConfigEntry(entry_id="super-mock-id")
entity_platform = MockEntityPlatform(
hass, platform_name=config_entry.domain, platform=platform
)
assert await entity_platform.async_setup_entry(config_entry)
await hass.async_block_till_done()
full_name = f"{entity_platform.domain}.{config_entry.domain}"
assert full_name in hass.config.components
assert len(hass.states.async_entity_ids()) == 0 # should be disabled
assert len(registry.entities) == 2
assert (
registry.entities["test_domain.test_aa_bb_cc_dd_ee_ff"].config_entry_id
== "super-mock-id"
)
unsub()
assert dispatches == [
{"ip": "5.4.3.2", "mac": "aa:bb:cc:dd:ee:ff", "host_name": "connected"}
]

View File

@ -16,6 +16,7 @@ from homeassistant.components.device_tracker.const import (
ATTR_IP,
ATTR_MAC,
ATTR_SOURCE_TYPE,
CONNECTED_DEVICE_REGISTERED,
SOURCE_TYPE_ROUTER,
)
from homeassistant.components.dhcp.const import DOMAIN
@ -26,6 +27,7 @@ from homeassistant.const import (
STATE_NOT_HOME,
)
import homeassistant.helpers.device_registry as dr
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util
@ -630,6 +632,37 @@ async def test_device_tracker_hostname_and_macaddress_exists_before_start(hass):
)
async def test_device_tracker_registered(hass):
"""Test matching based on hostname and macaddress when registered."""
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
device_tracker_watcher = dhcp.DeviceTrackerRegisteredWatcher(
hass,
{},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
)
await device_tracker_watcher.async_start()
await hass.async_block_till_done()
async_dispatcher_send(
hass,
CONNECTED_DEVICE_REGISTERED,
{"ip": "192.168.210.56", "mac": "b8b7f16db533", "host_name": "connect"},
)
await hass.async_block_till_done()
assert len(mock_init.mock_calls) == 1
assert mock_init.mock_calls[0][1][0] == "mock-domain"
assert mock_init.mock_calls[0][2]["context"] == {
"source": config_entries.SOURCE_DHCP
}
assert mock_init.mock_calls[0][2]["data"] == dhcp.DhcpServiceInfo(
ip="192.168.210.56",
hostname="connect",
macaddress="b8b7f16db533",
)
await device_tracker_watcher.async_stop()
await hass.async_block_till_done()
async def test_device_tracker_hostname_and_macaddress_after_start(hass):
"""Test matching based on hostname and macaddress after start."""