mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 09:47:52 +00:00
Ensure dhcp can still discover new devices from device trackers (#66822)
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
c46728c2b2
commit
a18d4c51ff
@ -16,12 +16,21 @@ from homeassistant.const import (
|
||||
)
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
from homeassistant.helpers.entity import DeviceInfo, Entity, EntityCategory
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.entity_platform import EntityPlatform
|
||||
from homeassistant.helpers.typing import StateType
|
||||
|
||||
from .const import ATTR_HOST_NAME, ATTR_IP, ATTR_MAC, ATTR_SOURCE_TYPE, DOMAIN, LOGGER
|
||||
from .const import (
|
||||
ATTR_HOST_NAME,
|
||||
ATTR_IP,
|
||||
ATTR_MAC,
|
||||
ATTR_SOURCE_TYPE,
|
||||
CONNECTED_DEVICE_REGISTERED,
|
||||
DOMAIN,
|
||||
LOGGER,
|
||||
)
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
@ -64,9 +73,33 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
return await component.async_unload_entry(entry)
|
||||
|
||||
|
||||
@callback
|
||||
def _async_connected_device_registered(
|
||||
hass: HomeAssistant, mac: str, ip_address: str | None, hostname: str | None
|
||||
) -> None:
|
||||
"""Register a newly seen connected device.
|
||||
|
||||
This is currently used by the dhcp integration
|
||||
to listen for newly registered connected devices
|
||||
for discovery.
|
||||
"""
|
||||
async_dispatcher_send(
|
||||
hass,
|
||||
CONNECTED_DEVICE_REGISTERED,
|
||||
{
|
||||
ATTR_IP: ip_address,
|
||||
ATTR_MAC: mac,
|
||||
ATTR_HOST_NAME: hostname,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def _async_register_mac(
|
||||
hass: HomeAssistant, domain: str, mac: str, unique_id: str
|
||||
hass: HomeAssistant,
|
||||
domain: str,
|
||||
mac: str,
|
||||
unique_id: str,
|
||||
) -> None:
|
||||
"""Register a mac address with a unique ID."""
|
||||
data_key = "device_tracker_mac"
|
||||
@ -297,7 +330,17 @@ class ScannerEntity(BaseTrackerEntity):
|
||||
super().add_to_platform_start(hass, platform, parallel_updates)
|
||||
if self.mac_address and self.unique_id:
|
||||
_async_register_mac(
|
||||
hass, platform.platform_name, self.mac_address, self.unique_id
|
||||
hass,
|
||||
platform.platform_name,
|
||||
self.mac_address,
|
||||
self.unique_id,
|
||||
)
|
||||
if self.is_connected:
|
||||
_async_connected_device_registered(
|
||||
hass,
|
||||
self.mac_address,
|
||||
self.ip_address,
|
||||
self.hostname,
|
||||
)
|
||||
|
||||
@callback
|
||||
|
@ -37,3 +37,5 @@ ATTR_MAC: Final = "mac"
|
||||
ATTR_SOURCE_TYPE: Final = "source_type"
|
||||
ATTR_CONSIDER_HOME: Final = "consider_home"
|
||||
ATTR_IP: Final = "ip"
|
||||
|
||||
CONNECTED_DEVICE_REGISTERED: Final = "device_tracker_connected_device_registered"
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""The dhcp integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
import fnmatch
|
||||
@ -25,6 +26,7 @@ from homeassistant.components.device_tracker.const import (
|
||||
ATTR_IP,
|
||||
ATTR_MAC,
|
||||
ATTR_SOURCE_TYPE,
|
||||
CONNECTED_DEVICE_REGISTERED,
|
||||
DOMAIN as DEVICE_TRACKER_DOMAIN,
|
||||
SOURCE_TYPE_ROUTER,
|
||||
)
|
||||
@ -42,6 +44,7 @@ from homeassistant.helpers.device_registry import (
|
||||
async_get,
|
||||
format_mac,
|
||||
)
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.helpers.event import (
|
||||
async_track_state_added_domain,
|
||||
async_track_time_interval,
|
||||
@ -109,16 +112,23 @@ class DhcpServiceInfo(BaseServiceInfo):
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up the dhcp component."""
|
||||
watchers: list[WatcherBase] = []
|
||||
address_data: dict[str, dict[str, str]] = {}
|
||||
integration_matchers = await async_get_dhcp(hass)
|
||||
|
||||
# For the passive classes we need to start listening
|
||||
# for state changes and connect the dispatchers before
|
||||
# everything else starts up or we will miss events
|
||||
for passive_cls in (DeviceTrackerRegisteredWatcher, DeviceTrackerWatcher):
|
||||
passive_watcher = passive_cls(hass, address_data, integration_matchers)
|
||||
await passive_watcher.async_start()
|
||||
watchers.append(passive_watcher)
|
||||
|
||||
async def _initialize(_):
|
||||
address_data = {}
|
||||
integration_matchers = await async_get_dhcp(hass)
|
||||
watchers = []
|
||||
|
||||
for cls in (DHCPWatcher, DeviceTrackerWatcher, NetworkWatcher):
|
||||
watcher = cls(hass, address_data, integration_matchers)
|
||||
await watcher.async_start()
|
||||
watchers.append(watcher)
|
||||
for active_cls in (DHCPWatcher, NetworkWatcher):
|
||||
active_watcher = active_cls(hass, address_data, integration_matchers)
|
||||
await active_watcher.async_start()
|
||||
watchers.append(active_watcher)
|
||||
|
||||
async def _async_stop(*_):
|
||||
for watcher in watchers:
|
||||
@ -141,6 +151,14 @@ class WatcherBase:
|
||||
self._integration_matchers = integration_matchers
|
||||
self._address_data = address_data
|
||||
|
||||
@abstractmethod
|
||||
async def async_stop(self):
|
||||
"""Stop the watcher."""
|
||||
|
||||
@abstractmethod
|
||||
async def async_start(self):
|
||||
"""Start the watcher."""
|
||||
|
||||
def process_client(self, ip_address, hostname, mac_address):
|
||||
"""Process a client."""
|
||||
return run_callback_threadsafe(
|
||||
@ -320,6 +338,39 @@ class DeviceTrackerWatcher(WatcherBase):
|
||||
self.async_process_client(ip_address, hostname, _format_mac(mac_address))
|
||||
|
||||
|
||||
class DeviceTrackerRegisteredWatcher(WatcherBase):
|
||||
"""Class to watch data from device tracker registrations."""
|
||||
|
||||
def __init__(self, hass, address_data, integration_matchers):
|
||||
"""Initialize class."""
|
||||
super().__init__(hass, address_data, integration_matchers)
|
||||
self._unsub = None
|
||||
|
||||
async def async_stop(self):
|
||||
"""Stop watching for device tracker registrations."""
|
||||
if self._unsub:
|
||||
self._unsub()
|
||||
self._unsub = None
|
||||
|
||||
async def async_start(self):
|
||||
"""Stop watching for device tracker registrations."""
|
||||
self._unsub = async_dispatcher_connect(
|
||||
self.hass, CONNECTED_DEVICE_REGISTERED, self._async_process_device_state
|
||||
)
|
||||
|
||||
@callback
|
||||
def _async_process_device_state(self, data: dict[str, Any]) -> None:
|
||||
"""Process a device tracker state."""
|
||||
ip_address = data.get(ATTR_IP)
|
||||
hostname = data.get(ATTR_HOST_NAME, "")
|
||||
mac_address = data.get(ATTR_MAC)
|
||||
|
||||
if ip_address is None or mac_address is None:
|
||||
return
|
||||
|
||||
self.async_process_client(ip_address, hostname, _format_mac(mac_address))
|
||||
|
||||
|
||||
class DHCPWatcher(WatcherBase):
|
||||
"""Class to watch dhcp requests."""
|
||||
|
||||
|
@ -1,8 +1,15 @@
|
||||
"""Test Device Tracker config entry things."""
|
||||
from homeassistant.components.device_tracker import DOMAIN, config_entry as ce
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
from tests.common import (
|
||||
MockConfigEntry,
|
||||
MockEntityPlatform,
|
||||
MockPlatform,
|
||||
mock_registry,
|
||||
)
|
||||
|
||||
|
||||
def test_tracker_entity():
|
||||
@ -128,3 +135,87 @@ async def test_register_mac(hass):
|
||||
entity_entry_1 = ent_reg.async_get(entity_entry_1.entity_id)
|
||||
|
||||
assert entity_entry_1.disabled_by is None
|
||||
|
||||
|
||||
async def test_connected_device_registered(hass):
|
||||
"""Test dispatch on connected device being registered."""
|
||||
|
||||
registry = mock_registry(hass)
|
||||
dispatches = []
|
||||
|
||||
@callback
|
||||
def _save_dispatch(msg):
|
||||
dispatches.append(msg)
|
||||
|
||||
unsub = async_dispatcher_connect(
|
||||
hass, ce.CONNECTED_DEVICE_REGISTERED, _save_dispatch
|
||||
)
|
||||
|
||||
class MockScannerEntity(ce.ScannerEntity):
|
||||
"""Mock a scanner entity."""
|
||||
|
||||
@property
|
||||
def ip_address(self) -> str:
|
||||
return "5.4.3.2"
|
||||
|
||||
@property
|
||||
def unique_id(self) -> str:
|
||||
return self.mac_address
|
||||
|
||||
class MockDisconnectedScannerEntity(MockScannerEntity):
|
||||
"""Mock a disconnected scanner entity."""
|
||||
|
||||
@property
|
||||
def mac_address(self) -> str:
|
||||
return "aa:bb:cc:dd:ee:ff"
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def hostname(self) -> str:
|
||||
return "connected"
|
||||
|
||||
class MockConnectedScannerEntity(MockScannerEntity):
|
||||
"""Mock a disconnected scanner entity."""
|
||||
|
||||
@property
|
||||
def mac_address(self) -> str:
|
||||
return "aa:bb:cc:dd:ee:00"
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def hostname(self) -> str:
|
||||
return "disconnected"
|
||||
|
||||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||
"""Mock setup entry method."""
|
||||
async_add_entities(
|
||||
[MockConnectedScannerEntity(), MockDisconnectedScannerEntity()]
|
||||
)
|
||||
return True
|
||||
|
||||
platform = MockPlatform(async_setup_entry=async_setup_entry)
|
||||
config_entry = MockConfigEntry(entry_id="super-mock-id")
|
||||
entity_platform = MockEntityPlatform(
|
||||
hass, platform_name=config_entry.domain, platform=platform
|
||||
)
|
||||
|
||||
assert await entity_platform.async_setup_entry(config_entry)
|
||||
await hass.async_block_till_done()
|
||||
full_name = f"{entity_platform.domain}.{config_entry.domain}"
|
||||
assert full_name in hass.config.components
|
||||
assert len(hass.states.async_entity_ids()) == 0 # should be disabled
|
||||
assert len(registry.entities) == 2
|
||||
assert (
|
||||
registry.entities["test_domain.test_aa_bb_cc_dd_ee_ff"].config_entry_id
|
||||
== "super-mock-id"
|
||||
)
|
||||
unsub()
|
||||
assert dispatches == [
|
||||
{"ip": "5.4.3.2", "mac": "aa:bb:cc:dd:ee:ff", "host_name": "connected"}
|
||||
]
|
||||
|
@ -16,6 +16,7 @@ from homeassistant.components.device_tracker.const import (
|
||||
ATTR_IP,
|
||||
ATTR_MAC,
|
||||
ATTR_SOURCE_TYPE,
|
||||
CONNECTED_DEVICE_REGISTERED,
|
||||
SOURCE_TYPE_ROUTER,
|
||||
)
|
||||
from homeassistant.components.dhcp.const import DOMAIN
|
||||
@ -26,6 +27,7 @@ from homeassistant.const import (
|
||||
STATE_NOT_HOME,
|
||||
)
|
||||
import homeassistant.helpers.device_registry as dr
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
from homeassistant.setup import async_setup_component
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
@ -630,6 +632,37 @@ async def test_device_tracker_hostname_and_macaddress_exists_before_start(hass):
|
||||
)
|
||||
|
||||
|
||||
async def test_device_tracker_registered(hass):
|
||||
"""Test matching based on hostname and macaddress when registered."""
|
||||
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
|
||||
device_tracker_watcher = dhcp.DeviceTrackerRegisteredWatcher(
|
||||
hass,
|
||||
{},
|
||||
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
|
||||
)
|
||||
await device_tracker_watcher.async_start()
|
||||
await hass.async_block_till_done()
|
||||
async_dispatcher_send(
|
||||
hass,
|
||||
CONNECTED_DEVICE_REGISTERED,
|
||||
{"ip": "192.168.210.56", "mac": "b8b7f16db533", "host_name": "connect"},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(mock_init.mock_calls) == 1
|
||||
assert mock_init.mock_calls[0][1][0] == "mock-domain"
|
||||
assert mock_init.mock_calls[0][2]["context"] == {
|
||||
"source": config_entries.SOURCE_DHCP
|
||||
}
|
||||
assert mock_init.mock_calls[0][2]["data"] == dhcp.DhcpServiceInfo(
|
||||
ip="192.168.210.56",
|
||||
hostname="connect",
|
||||
macaddress="b8b7f16db533",
|
||||
)
|
||||
await device_tracker_watcher.async_stop()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_device_tracker_hostname_and_macaddress_after_start(hass):
|
||||
"""Test matching based on hostname and macaddress after start."""
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user