Make device tracker entities work better (#63328)

Co-authored-by: Franck Nijhof <git@frenck.dev>
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Paulus Schoutsen 2022-01-04 23:16:43 -08:00 committed by GitHub
parent d4310f0d70
commit 2b4bb49eb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 546 additions and 357 deletions

View File

@ -4,11 +4,8 @@ from __future__ import annotations
from homeassistant.components.device_tracker import SOURCE_TYPE_ROUTER from homeassistant.components.device_tracker import SOURCE_TYPE_ROUTER
from homeassistant.components.device_tracker.config_entry import ScannerEntity from homeassistant.components.device_tracker.config_entry import ScannerEntity
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_DEFAULT_NAME
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import DeviceInfo
from .const import DATA_ASUSWRT, DOMAIN from .const import DATA_ASUSWRT, DOMAIN
from .router import AsusWrtRouter from .router import AsusWrtRouter
@ -62,12 +59,6 @@ class AsusWrtDevice(ScannerEntity):
self._device = device self._device = device
self._attr_unique_id = device.mac self._attr_unique_id = device.mac
self._attr_name = device.name or DEFAULT_DEVICE_NAME self._attr_name = device.name or DEFAULT_DEVICE_NAME
self._attr_device_info = DeviceInfo(
connections={(CONNECTION_NETWORK_MAC, device.mac)},
default_model="ASUSWRT Tracked device",
)
if device.name:
self._attr_device_info[ATTR_DEFAULT_NAME] = device.name
@property @property
def is_connected(self): def is_connected(self):

View File

@ -1,4 +1,6 @@
"""Provide functionality to keep track of devices.""" """Provide functionality to keep track of devices."""
from __future__ import annotations
from homeassistant.const import ATTR_GPS_ACCURACY, STATE_HOME # noqa: F401 from homeassistant.const import ATTR_GPS_ACCURACY, STATE_HOME # noqa: F401
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType

View File

@ -1,6 +1,7 @@
"""Code to set up a device tracker platform using a config entry.""" """Code to set up a device tracker platform using a config entry."""
from __future__ import annotations from __future__ import annotations
import asyncio
from typing import final from typing import final
from homeassistant.components import zone from homeassistant.components import zone
@ -13,9 +14,11 @@ from homeassistant.const import (
STATE_HOME, STATE_HOME,
STATE_NOT_HOME, STATE_NOT_HOME,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.helpers.entity import Entity from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.entity import DeviceInfo, Entity, EntityCategory
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.entity_platform import EntityPlatform
from homeassistant.helpers.typing import StateType from homeassistant.helpers.typing import StateType
from .const import ATTR_HOST_NAME, ATTR_IP, ATTR_MAC, ATTR_SOURCE_TYPE, DOMAIN, LOGGER from .const import ATTR_HOST_NAME, ATTR_IP, ATTR_MAC, ATTR_SOURCE_TYPE, DOMAIN, LOGGER
@ -25,8 +28,32 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up an entry.""" """Set up an entry."""
component: EntityComponent | None = hass.data.get(DOMAIN) component: EntityComponent | None = hass.data.get(DOMAIN)
if component is None: if component is not None:
component = hass.data[DOMAIN] = EntityComponent(LOGGER, DOMAIN, hass) return await component.async_setup_entry(entry)
component = hass.data[DOMAIN] = EntityComponent(LOGGER, DOMAIN, hass)
# Clean up old devices created by device tracker entities in the past.
# Can be removed after 2022.6
ent_reg = er.async_get(hass)
dev_reg = dr.async_get(hass)
devices_with_trackers = set()
devices_with_non_trackers = set()
for entity in ent_reg.entities.values():
if entity.device_id is None:
continue
if entity.domain == DOMAIN:
devices_with_trackers.add(entity.device_id)
else:
devices_with_non_trackers.add(entity.device_id)
for device_id in devices_with_trackers - devices_with_non_trackers:
for entity in er.async_entries_for_device(ent_reg, device_id, True):
ent_reg.async_update_entity(entity.entity_id, device_id=None)
dev_reg.async_remove_device(device_id)
return await component.async_setup_entry(entry) return await component.async_setup_entry(entry)
@ -37,9 +64,80 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return await component.async_unload_entry(entry) return await component.async_unload_entry(entry)
@callback
def _async_register_mac(
hass: HomeAssistant, domain: str, mac: str, unique_id: str
) -> None:
"""Register a mac address with a unique ID."""
data_key = "device_tracker_mac"
mac = dr.format_mac(mac)
if data_key in hass.data:
hass.data[data_key][mac] = (domain, unique_id)
return
# Setup listening.
# dict mapping mac -> partial unique ID
data = hass.data[data_key] = {mac: (domain, unique_id)}
@callback
def handle_device_event(ev: Event) -> None:
"""Enable the online status entity for the mac of a newly created device."""
# Only for new devices
if ev.data["action"] != "create":
return
dev_reg = dr.async_get(hass)
device_entry = dev_reg.async_get(ev.data["device_id"])
if device_entry is None:
return
# Check if device has a mac
mac = None
for conn in device_entry.connections:
if conn[0] == dr.CONNECTION_NETWORK_MAC:
mac = conn[1]
break
if mac is None:
return
# Check if we have an entity for this mac
if (unique_id := data.get(mac)) is None:
return
ent_reg = er.async_get(hass)
entity_id = ent_reg.async_get_entity_id(DOMAIN, *unique_id)
if entity_id is None:
return
entity_entry = ent_reg.async_get(entity_id)
if entity_entry is None:
return
# Make sure entity has a config entry and was disabled by the
# default disable logic in the integration.
if (
entity_entry.config_entry_id is None
or entity_entry.disabled_by != er.RegistryEntryDisabler.INTEGRATION
):
return
# Enable entity
ent_reg.async_update_entity(entity_id, disabled_by=None)
hass.bus.async_listen(dr.EVENT_DEVICE_REGISTRY_UPDATED, handle_device_event)
class BaseTrackerEntity(Entity): class BaseTrackerEntity(Entity):
"""Represent a tracked device.""" """Represent a tracked device."""
_attr_device_info: None = None
_attr_entity_category = EntityCategory.DIAGNOSTIC
@property @property
def battery_level(self) -> int | None: def battery_level(self) -> int | None:
"""Return the battery level of the device. """Return the battery level of the device.
@ -164,6 +262,86 @@ class ScannerEntity(BaseTrackerEntity):
"""Return true if the device is connected to the network.""" """Return true if the device is connected to the network."""
raise NotImplementedError raise NotImplementedError
@property
def unique_id(self) -> str | None:
"""Return unique ID of the entity."""
return self.mac_address
@final
@property
def device_info(self) -> DeviceInfo | None:
"""Device tracker entities should not create device registry entries."""
return None
@property
def entity_registry_enabled_default(self) -> bool:
"""Return if entity is enabled by default."""
# If mac_address is None, we can never find a device entry.
return (
# Do not disable if we won't activate our attach to device logic
self.mac_address is None
or self.device_info is not None
# Disable if we automatically attach but there is no device
or self.find_device_entry() is not None
)
@callback
def add_to_platform_start(
self,
hass: HomeAssistant,
platform: EntityPlatform,
parallel_updates: asyncio.Semaphore | None,
) -> None:
"""Start adding an entity to a platform."""
super().add_to_platform_start(hass, platform, parallel_updates)
if self.mac_address and self.unique_id:
_async_register_mac(
hass, platform.platform_name, self.mac_address, self.unique_id
)
@callback
def find_device_entry(self) -> dr.DeviceEntry | None:
"""Return device entry."""
assert self.mac_address is not None
return dr.async_get(self.hass).async_get_device(
set(), {(dr.CONNECTION_NETWORK_MAC, self.mac_address)}
)
async def async_internal_added_to_hass(self) -> None:
"""Handle added to Home Assistant."""
# Entities without a unique ID don't have a device
if (
not self.registry_entry
or not self.platform
or not self.platform.config_entry
or not self.mac_address
or (device_entry := self.find_device_entry()) is None
# Entities should not have a device info. We opt them out
# of this logic if they do.
or self.device_info
):
if self.device_info:
LOGGER.debug("Entity %s unexpectedly has a device info", self.entity_id)
await super().async_internal_added_to_hass()
return
# Attach entry to device
if self.registry_entry.device_id != device_entry.id:
self.registry_entry = er.async_get(self.hass).async_update_entity(
self.entity_id, device_id=device_entry.id
)
# Attach device to config entry
if self.platform.config_entry.entry_id not in device_entry.config_entries:
dr.async_get(self.hass).async_update_device(
device_entry.id,
add_config_entry_id=self.platform.config_entry.entry_id,
)
# Do this last or else the entity registry update listener has been installed
await super().async_internal_added_to_hass()
@final @final
@property @property
def state_attributes(self) -> dict[str, StateType]: def state_attributes(self) -> dict[str, StateType]:

View File

@ -8,9 +8,7 @@ from homeassistant.components.device_tracker import SOURCE_TYPE_ROUTER
from homeassistant.components.device_tracker.config_entry import ScannerEntity from homeassistant.components.device_tracker.config_entry import ScannerEntity
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import DeviceInfo
from .const import DEFAULT_DEVICE_NAME, DEVICE_ICONS, DOMAIN from .const import DEFAULT_DEVICE_NAME, DEVICE_ICONS, DOMAIN
from .router import FreeboxRouter from .router import FreeboxRouter
@ -82,7 +80,7 @@ class FreeboxDevice(ScannerEntity):
self._attrs = device["attrs"] self._attrs = device["attrs"]
@property @property
def unique_id(self) -> str: def mac_address(self) -> str:
"""Return a unique ID.""" """Return a unique ID."""
return self._mac return self._mac
@ -111,16 +109,6 @@ class FreeboxDevice(ScannerEntity):
"""Return the attributes.""" """Return the attributes."""
return self._attrs return self._attrs
@property
def device_info(self) -> DeviceInfo:
"""Return the device information."""
return DeviceInfo(
connections={(CONNECTION_NETWORK_MAC, self._mac)},
identifiers={(DOMAIN, self.unique_id)},
manufacturer=self._manufacturer,
name=self.name,
)
@property @property
def should_poll(self) -> bool: def should_poll(self) -> bool:
"""No polling needed.""" """No polling needed."""

View File

@ -519,21 +519,6 @@ class FritzDeviceBase(update_coordinator.CoordinatorEntity):
return self._router.devices[self._mac].hostname return self._router.devices[self._mac].hostname
return None return None
@property
def device_info(self) -> DeviceInfo:
"""Return the device information."""
return DeviceInfo(
connections={(CONNECTION_NETWORK_MAC, self._mac)},
default_manufacturer="AVM",
default_model="FRITZ!Box Tracked device",
default_name=self.name,
identifiers={(DOMAIN, self._mac)},
via_device=(
DOMAIN,
self._router.unique_id,
),
)
@property @property
def should_poll(self) -> bool: def should_poll(self) -> bool:
"""No polling needed.""" """No polling needed."""

View File

@ -130,6 +130,11 @@ class FritzBoxTracker(FritzDeviceBase, ScannerEntity):
"""Return device unique id.""" """Return device unique id."""
return f"{self._mac}_tracker" return f"{self._mac}_tracker"
@property
def mac_address(self) -> str:
"""Return mac_address."""
return self._mac
@property @property
def icon(self) -> str: def icon(self) -> str:
"""Return device icon.""" """Return device icon."""

View File

@ -19,8 +19,9 @@ from homeassistant.components.network import async_get_source_ip
from homeassistant.components.switch import SwitchEntity from homeassistant.components.switch import SwitchEntity
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import Entity, EntityCategory from homeassistant.helpers.entity import DeviceInfo, Entity, EntityCategory
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import slugify from homeassistant.util import slugify
@ -605,6 +606,21 @@ class FritzBoxProfileSwitch(FritzDeviceBase, SwitchEntity):
"""Switch status.""" """Switch status."""
return self._router.devices[self._mac].wan_access return self._router.devices[self._mac].wan_access
@property
def device_info(self) -> DeviceInfo:
"""Return the device information."""
return DeviceInfo(
connections={(CONNECTION_NETWORK_MAC, self._mac)},
default_manufacturer="AVM",
default_model="FRITZ!Box Tracked device",
default_name=self.name,
identifiers={(DOMAIN, self._mac)},
via_device=(
DOMAIN,
self._router.unique_id,
),
)
async def async_turn_on(self, **kwargs: Any) -> None: async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn on switch.""" """Turn on switch."""
await self._async_handle_turn_on_off(turn_on=True) await self._async_handle_turn_on_off(turn_on=True)

View File

@ -653,14 +653,6 @@ class HuaweiLteBaseEntity(Entity):
"""Huawei LTE entities report their state without polling.""" """Huawei LTE entities report their state without polling."""
return False return False
@property
def device_info(self) -> DeviceInfo:
"""Get info for matching with parent router."""
return DeviceInfo(
connections=self.router.device_connections,
identifiers=self.router.device_identifiers,
)
async def async_update(self) -> None: async def async_update(self) -> None:
"""Update state.""" """Update state."""
raise NotImplementedError raise NotImplementedError
@ -681,3 +673,15 @@ class HuaweiLteBaseEntity(Entity):
for unsub in self._unsub_handlers: for unsub in self._unsub_handlers:
unsub() unsub()
self._unsub_handlers.clear() self._unsub_handlers.clear()
class HuaweiLteBaseEntityWithDevice(HuaweiLteBaseEntity):
"""Base entity with device info."""
@property
def device_info(self) -> DeviceInfo:
"""Get info for matching with parent router."""
return DeviceInfo(
connections=self.router.device_connections,
identifiers=self.router.device_identifiers,
)

View File

@ -16,7 +16,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from . import HuaweiLteBaseEntity from . import HuaweiLteBaseEntityWithDevice
from .const import ( from .const import (
DOMAIN, DOMAIN,
KEY_MONITORING_CHECK_NOTIFICATIONS, KEY_MONITORING_CHECK_NOTIFICATIONS,
@ -49,7 +49,7 @@ async def async_setup_entry(
@dataclass @dataclass
class HuaweiLteBaseBinarySensor(HuaweiLteBaseEntity, BinarySensorEntity): class HuaweiLteBaseBinarySensor(HuaweiLteBaseEntityWithDevice, BinarySensorEntity):
"""Huawei LTE binary sensor device base class.""" """Huawei LTE binary sensor device base class."""
key: str = field(init=False) key: str = field(init=False)

View File

@ -28,7 +28,7 @@ from homeassistant.helpers.entity import Entity, EntityCategory
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import StateType from homeassistant.helpers.typing import StateType
from . import HuaweiLteBaseEntity from . import HuaweiLteBaseEntityWithDevice
from .const import ( from .const import (
DOMAIN, DOMAIN,
KEY_DEVICE_INFORMATION, KEY_DEVICE_INFORMATION,
@ -523,7 +523,7 @@ def format_default(value: StateType) -> tuple[StateType, str | None]:
@dataclass @dataclass
class HuaweiLteSensor(HuaweiLteBaseEntity, SensorEntity): class HuaweiLteSensor(HuaweiLteBaseEntityWithDevice, SensorEntity):
"""Huawei LTE sensor entity.""" """Huawei LTE sensor entity."""
key: str key: str

View File

@ -15,7 +15,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from . import HuaweiLteBaseEntity from . import HuaweiLteBaseEntityWithDevice
from .const import DOMAIN, KEY_DIALUP_MOBILE_DATASWITCH from .const import DOMAIN, KEY_DIALUP_MOBILE_DATASWITCH
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -37,7 +37,7 @@ async def async_setup_entry(
@dataclass @dataclass
class HuaweiLteBaseSwitch(HuaweiLteBaseEntity, SwitchEntity): class HuaweiLteBaseSwitch(HuaweiLteBaseEntityWithDevice, SwitchEntity):
"""Huawei LTE switch device base class.""" """Huawei LTE switch device base class."""
key: str = field(init=False) key: str = field(init=False)

View File

@ -24,9 +24,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import entity_registry from homeassistant.helpers import entity_registry
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import DeviceInfo
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .const import ( from .const import (
@ -217,15 +215,6 @@ class KeeneticTracker(ScannerEntity):
} }
return None return None
@property
def device_info(self) -> DeviceInfo:
"""Return a client description for device registry."""
return DeviceInfo(
connections={(CONNECTION_NETWORK_MAC, self._device.mac)},
identifiers={(DOMAIN, self._device.mac)},
name=self._device.name if self._device.name else None,
)
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Client entity created.""" """Client entity created."""
_LOGGER.debug("New network device tracker %s (%s)", self.name, self.unique_id) _LOGGER.debug("New network device tracker %s (%s)", self.name, self.unique_id)

View File

@ -8,9 +8,7 @@ from homeassistant.components.device_tracker.const import (
) )
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers import entity_registry from homeassistant.helpers import entity_registry
from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import DeviceInfo
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .const import DOMAIN from .const import DOMAIN
@ -130,17 +128,6 @@ class MikrotikHubTracker(ScannerEntity):
return {k: v for k, v in self.device.attrs.items() if k not in FILTER_ATTRS} return {k: v for k, v in self.device.attrs.items() if k not in FILTER_ATTRS}
return None return None
@property
def device_info(self) -> DeviceInfo:
"""Return a client description for device registry."""
# We only get generic info from device discovery and so don't want
# to override API specific info that integrations can provide
return DeviceInfo(
connections={(CONNECTION_NETWORK_MAC, self.device.mac)},
default_name=self.name,
identifiers={(DOMAIN, self.device.mac)},
)
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Client entity created.""" """Client entity created."""
_LOGGER.debug("New network device tracker %s (%s)", self.name, self.unique_id) _LOGGER.debug("New network device tracker %s (%s)", self.name, self.unique_id)

View File

@ -22,9 +22,7 @@ from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry
from homeassistant.const import CONF_EXCLUDE, CONF_HOSTS from homeassistant.const import CONF_EXCLUDE, CONF_HOSTS
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import DeviceInfo
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from . import NmapDevice, NmapDeviceScanner, short_hostname, signal_device_update from . import NmapDevice, NmapDeviceScanner, short_hostname, signal_device_update
@ -169,15 +167,6 @@ class NmapTrackerEntity(ScannerEntity):
"""Return tracker source type.""" """Return tracker source type."""
return SOURCE_TYPE_ROUTER return SOURCE_TYPE_ROUTER
@property
def device_info(self) -> DeviceInfo:
"""Return the device information."""
return DeviceInfo(
connections={(CONNECTION_NETWORK_MAC, self._mac_address)},
default_manufacturer=self._device.manufacturer,
default_name=self.name,
)
@property @property
def should_poll(self) -> bool: def should_poll(self) -> bool:
"""No polling needed.""" """No polling needed."""

View File

@ -6,12 +6,9 @@ from homeassistant.components.device_tracker.config_entry import ScannerEntity
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import entity_registry from homeassistant.helpers import entity_registry
from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC
from homeassistant.helpers.entity import DeviceInfo
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from .const import ( from .const import (
API_ACCESS_POINT,
API_CLIENTS, API_CLIENTS,
API_NAME, API_NAME,
COORDINATOR, COORDINATOR,
@ -93,8 +90,8 @@ class RuckusUnleashedDevice(CoordinatorEntity, ScannerEntity):
self._name = name self._name = name
@property @property
def unique_id(self) -> str: def mac_address(self) -> str:
"""Return a unique ID.""" """Return a mac address."""
return self._mac return self._mac
@property @property
@ -116,17 +113,3 @@ class RuckusUnleashedDevice(CoordinatorEntity, ScannerEntity):
def source_type(self) -> str: def source_type(self) -> str:
"""Return the source type.""" """Return the source type."""
return SOURCE_TYPE_ROUTER return SOURCE_TYPE_ROUTER
@property
def device_info(self) -> DeviceInfo | None:
"""Return the device information."""
if self.is_connected:
return DeviceInfo(
name=self.name,
connections={(CONNECTION_NETWORK_MAC, self._mac)},
via_device=(
CONNECTION_NETWORK_MAC,
self.coordinator.data[API_CLIENTS][self._mac][API_ACCESS_POINT],
),
)
return None

View File

@ -19,16 +19,12 @@ from homeassistant.components.device_tracker import DOMAIN
from homeassistant.components.device_tracker.config_entry import ScannerEntity from homeassistant.components.device_tracker.config_entry import ScannerEntity
from homeassistant.components.device_tracker.const import SOURCE_TYPE_ROUTER from homeassistant.components.device_tracker.const import SOURCE_TYPE_ROUTER
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_NAME
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import DeviceInfo
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .const import ATTR_MANUFACTURER, DOMAIN as UNIFI_DOMAIN from .const import DOMAIN as UNIFI_DOMAIN
from .unifi_client import UniFiClient from .unifi_client import UniFiClient
from .unifi_entity_base import UniFiBase from .unifi_entity_base import UniFiBase
@ -242,6 +238,11 @@ class UniFiClientTracker(UniFiClient, ScannerEntity):
self._is_connected = False self._is_connected = False
self.async_write_ha_state() self.async_write_ha_state()
@property
def device_info(self) -> None:
"""Return no device info."""
return None
@property @property
def is_connected(self): def is_connected(self):
"""Return true if the client is connected to the network.""" """Return true if the client is connected to the network."""
@ -365,13 +366,6 @@ class UniFiDeviceTracker(UniFiBase, ScannerEntity):
self._is_connected = True self._is_connected = True
self.schedule_update = True self.schedule_update = True
elif (
self.device.last_updated == SOURCE_EVENT
and self.device.event.event in DEVICE_UPGRADED
):
self.hass.async_create_task(self.async_update_device_registry())
return
if self.schedule_update: if self.schedule_update:
self.schedule_update = False self.schedule_update = False
self.controller.async_heartbeat( self.controller.async_heartbeat(
@ -412,28 +406,6 @@ class UniFiDeviceTracker(UniFiBase, ScannerEntity):
"""Return if controller is available.""" """Return if controller is available."""
return not self.device.disabled and self.controller.available return not self.device.disabled and self.controller.available
@property
def device_info(self) -> DeviceInfo:
"""Return a device description for device registry."""
info = DeviceInfo(
connections={(CONNECTION_NETWORK_MAC, self.device.mac)},
manufacturer=ATTR_MANUFACTURER,
model=self.device.model,
sw_version=self.device.version,
)
if self.device.name:
info[ATTR_NAME] = self.device.name
return info
async def async_update_device_registry(self) -> None:
"""Update device registry."""
device_registry = dr.async_get(self.hass)
device_registry.async_get_or_create(
config_entry_id=self.controller.config_entry.entry_id, **self.device_info
)
@property @property
def extra_state_attributes(self): def extra_state_attributes(self):
"""Return the device state attributes.""" """Return the device state attributes."""

View File

@ -6,7 +6,6 @@ from homeassistant.core import callback
from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_registry import async_entries_for_device
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -102,35 +101,10 @@ class UniFiBase(Entity):
entity_registry.async_remove(self.entity_id) entity_registry.async_remove(self.entity_id)
return return
if ( device_registry.async_update_device(
len( entity_entry.device_id,
entries_for_device := async_entries_for_device( remove_config_entry_id=self.controller.config_entry.entry_id,
entity_registry, )
entity_entry.device_id,
include_disabled_entities=True,
)
)
) == 1:
device_registry.async_remove_device(device_entry.id)
return
if (
len(
entries_for_device_from_this_config_entry := [
entry_for_device
for entry_for_device in entries_for_device
if entry_for_device.config_entry_id
== self.controller.config_entry.entry_id
]
)
!= len(entries_for_device)
and len(entries_for_device_from_this_config_entry) == 1
):
device_registry.async_update_device(
entity_entry.device_id,
remove_config_entry_id=self.controller.config_entry.entry_id,
)
entity_registry.async_remove(self.entity_id) entity_registry.async_remove(self.entity_id)
@property @property

View File

@ -1,4 +1,6 @@
"""Support for the ZHA platform.""" """Support for the ZHA platform."""
from __future__ import annotations
import functools import functools
import time import time
@ -8,6 +10,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import DeviceInfo
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .core import discovery from .core import discovery
@ -103,3 +106,19 @@ class ZHADeviceScannerEntity(ScannerEntity, ZhaEntity):
Percentage from 0-100. Percentage from 0-100.
""" """
return self._battery_level return self._battery_level
@property
def device_info( # pylint: disable=overridden-final-method
self,
) -> DeviceInfo | None:
"""Return device info."""
# We opt ZHA device tracker back into overriding this method because
# it doesn't track IP-based devices.
# Call Super because ScannerEntity overrode it.
return super(ZhaEntity, self).device_info
@property
def unique_id(self) -> str | None:
"""Return unique ID."""
# Call Super because ScannerEntity overrode it.
return super(ZhaEntity, self).unique_id

View File

@ -372,7 +372,7 @@ class DeviceRegistry:
) )
entry_type = DeviceEntryType(entry_type) entry_type = DeviceEntryType(entry_type)
device = self._async_update_device( device = self.async_update_device(
device.id, device.id,
add_config_entry_id=config_entry_id, add_config_entry_id=config_entry_id,
configuration_url=configuration_url, configuration_url=configuration_url,
@ -396,45 +396,6 @@ class DeviceRegistry:
@callback @callback
def async_update_device( def async_update_device(
self,
device_id: str,
*,
add_config_entry_id: str | UndefinedType = UNDEFINED,
area_id: str | None | UndefinedType = UNDEFINED,
configuration_url: str | None | UndefinedType = UNDEFINED,
disabled_by: DeviceEntryDisabler | None | UndefinedType = UNDEFINED,
manufacturer: str | None | UndefinedType = UNDEFINED,
model: str | None | UndefinedType = UNDEFINED,
name_by_user: str | None | UndefinedType = UNDEFINED,
name: str | None | UndefinedType = UNDEFINED,
new_identifiers: set[tuple[str, str]] | UndefinedType = UNDEFINED,
remove_config_entry_id: str | UndefinedType = UNDEFINED,
suggested_area: str | None | UndefinedType = UNDEFINED,
sw_version: str | None | UndefinedType = UNDEFINED,
hw_version: str | None | UndefinedType = UNDEFINED,
via_device_id: str | None | UndefinedType = UNDEFINED,
) -> DeviceEntry | None:
"""Update properties of a device."""
return self._async_update_device(
device_id,
add_config_entry_id=add_config_entry_id,
area_id=area_id,
configuration_url=configuration_url,
disabled_by=disabled_by,
manufacturer=manufacturer,
model=model,
name_by_user=name_by_user,
name=name,
new_identifiers=new_identifiers,
remove_config_entry_id=remove_config_entry_id,
suggested_area=suggested_area,
sw_version=sw_version,
hw_version=hw_version,
via_device_id=via_device_id,
)
@callback
def _async_update_device(
self, self,
device_id: str, device_id: str,
*, *,
@ -568,7 +529,7 @@ class DeviceRegistry:
) )
for other_device in list(self.devices.values()): for other_device in list(self.devices.values()):
if other_device.via_device_id == device_id: if other_device.via_device_id == device_id:
self._async_update_device(other_device.id, via_device_id=None) self.async_update_device(other_device.id, via_device_id=None)
self.hass.bus.async_fire( self.hass.bus.async_fire(
EVENT_DEVICE_REGISTRY_UPDATED, {"action": "remove", "device_id": device_id} EVENT_DEVICE_REGISTRY_UPDATED, {"action": "remove", "device_id": device_id}
) )
@ -669,7 +630,7 @@ class DeviceRegistry:
"""Clear config entry from registry entries.""" """Clear config entry from registry entries."""
now_time = time.time() now_time = time.time()
for device in list(self.devices.values()): for device in list(self.devices.values()):
self._async_update_device(device.id, remove_config_entry_id=config_entry_id) self.async_update_device(device.id, remove_config_entry_id=config_entry_id)
for deleted_device in list(self.deleted_devices.values()): for deleted_device in list(self.deleted_devices.values()):
config_entries = deleted_device.config_entries config_entries = deleted_device.config_entries
if config_entry_id not in config_entries: if config_entry_id not in config_entries:
@ -711,7 +672,7 @@ class DeviceRegistry:
"""Clear area id from registry entries.""" """Clear area id from registry entries."""
for dev_id, device in self.devices.items(): for dev_id, device in self.devices.items():
if area_id == device.area_id: if area_id == device.area_id:
self._async_update_device(dev_id, area_id=None) self.async_update_device(dev_id, area_id=None)
@callback @callback

View File

@ -335,7 +335,7 @@ class EntityRegistry:
entity_id = self.async_get_entity_id(domain, platform, unique_id) entity_id = self.async_get_entity_id(domain, platform, unique_id)
if entity_id: if entity_id:
return self._async_update_entity( return self.async_update_entity(
entity_id, entity_id,
area_id=area_id or UNDEFINED, area_id=area_id or UNDEFINED,
capabilities=capabilities or UNDEFINED, capabilities=capabilities or UNDEFINED,
@ -460,43 +460,6 @@ class EntityRegistry:
@callback @callback
def async_update_entity( def async_update_entity(
self,
entity_id: str,
*,
area_id: str | None | UndefinedType = UNDEFINED,
config_entry_id: str | None | UndefinedType = UNDEFINED,
device_class: str | None | UndefinedType = UNDEFINED,
disabled_by: RegistryEntryDisabler | None | UndefinedType = UNDEFINED,
entity_category: str | None | UndefinedType = UNDEFINED,
icon: str | None | UndefinedType = UNDEFINED,
name: str | None | UndefinedType = UNDEFINED,
new_entity_id: str | UndefinedType = UNDEFINED,
new_unique_id: str | UndefinedType = UNDEFINED,
original_device_class: str | None | UndefinedType = UNDEFINED,
original_icon: str | None | UndefinedType = UNDEFINED,
original_name: str | None | UndefinedType = UNDEFINED,
unit_of_measurement: str | None | UndefinedType = UNDEFINED,
) -> RegistryEntry:
"""Update properties of an entity."""
return self._async_update_entity(
entity_id,
area_id=area_id,
config_entry_id=config_entry_id,
device_class=device_class,
disabled_by=disabled_by,
entity_category=entity_category,
icon=icon,
name=name,
new_entity_id=new_entity_id,
new_unique_id=new_unique_id,
original_device_class=original_device_class,
original_icon=original_icon,
original_name=original_name,
unit_of_measurement=unit_of_measurement,
)
@callback
def _async_update_entity(
self, self,
entity_id: str, entity_id: str,
*, *,
@ -693,7 +656,7 @@ class EntityRegistry:
"""Clear area id from registry entries.""" """Clear area id from registry entries."""
for entity_id, entry in self.entities.items(): for entity_id, entry in self.entities.items():
if area_id == entry.area_id: if area_id == entry.area_id:
self._async_update_entity(entity_id, area_id=None) self.async_update_entity(entity_id, area_id=None)
@callback @callback

View File

@ -19,7 +19,7 @@ from homeassistant.const import (
STATE_HOME, STATE_HOME,
STATE_NOT_HOME, STATE_NOT_HOME,
) )
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.util import slugify from homeassistant.util import slugify
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
@ -41,6 +41,9 @@ MOCK_BYTES_TOTAL = [60000000000, 50000000000]
MOCK_CURRENT_TRANSFER_RATES = [20000000, 10000000] MOCK_CURRENT_TRANSFER_RATES = [20000000, 10000000]
MOCK_LOAD_AVG = [1.1, 1.2, 1.3] MOCK_LOAD_AVG = [1.1, 1.2, 1.3]
MOCK_TEMPERATURES = {"2.4GHz": 40, "5.0GHz": 0, "CPU": 71.2} MOCK_TEMPERATURES = {"2.4GHz": 40, "5.0GHz": 0, "CPU": 71.2}
MOCK_MAC_1 = "a1:b1:c1:d1:e1:f1"
MOCK_MAC_2 = "a2:b2:c2:d2:e2:f2"
MOCK_MAC_3 = "a3:b3:c3:d3:e3:f3"
SENSOR_NAMES = [ SENSOR_NAMES = [
"Devices Connected", "Devices Connected",
@ -61,8 +64,8 @@ SENSOR_NAMES = [
def mock_devices_fixture(): def mock_devices_fixture():
"""Mock a list of devices.""" """Mock a list of devices."""
return { return {
"a1:b1:c1:d1:e1:f1": Device("a1:b1:c1:d1:e1:f1", "192.168.1.2", "Test"), MOCK_MAC_1: Device(MOCK_MAC_1, "192.168.1.2", "Test"),
"a2:b2:c2:d2:e2:f2": Device("a2:b2:c2:d2:e2:f2", "192.168.1.3", "TestTwo"), MOCK_MAC_2: Device(MOCK_MAC_2, "192.168.1.3", "TestTwo"),
} }
@ -74,6 +77,26 @@ def mock_available_temps_list():
return [True, False] return [True, False]
@pytest.fixture(name="create_device_registry_devices")
def create_device_registry_devices_fixture(hass):
"""Create device registry devices so the device tracker entities are enabled."""
dev_reg = dr.async_get(hass)
config_entry = MockConfigEntry(domain="something_else")
for idx, device in enumerate(
(
MOCK_MAC_1,
MOCK_MAC_2,
MOCK_MAC_3,
)
):
dev_reg.async_get_or_create(
name=f"Device {idx}",
config_entry_id=config_entry.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, device)},
)
@pytest.fixture(name="connect") @pytest.fixture(name="connect")
def mock_controller_connect(mock_devices, mock_available_temps): def mock_controller_connect(mock_devices, mock_available_temps):
"""Mock a successful connection.""" """Mock a successful connection."""
@ -109,7 +132,13 @@ def mock_controller_connect(mock_devices, mock_available_temps):
yield service_mock yield service_mock
async def test_sensors(hass, connect, mock_devices, mock_available_temps): async def test_sensors(
hass,
connect,
mock_devices,
mock_available_temps,
create_device_registry_devices,
):
"""Test creating an AsusWRT sensor.""" """Test creating an AsusWRT sensor."""
entity_reg = er.async_get(hass) entity_reg = er.async_get(hass)
@ -161,10 +190,8 @@ async def test_sensors(hass, connect, mock_devices, mock_available_temps):
assert not hass.states.get(f"{sensor_prefix}_cpu_temperature") assert not hass.states.get(f"{sensor_prefix}_cpu_temperature")
# add one device and remove another # add one device and remove another
mock_devices.pop("a1:b1:c1:d1:e1:f1") mock_devices.pop(MOCK_MAC_1)
mock_devices["a3:b3:c3:d3:e3:f3"] = Device( mock_devices[MOCK_MAC_3] = Device(MOCK_MAC_3, "192.168.1.4", "TestThree")
"a3:b3:c3:d3:e3:f3", "192.168.1.4", "TestThree"
)
async_fire_time_changed(hass, utcnow() + timedelta(seconds=30)) async_fire_time_changed(hass, utcnow() + timedelta(seconds=30))
await hass.async_block_till_done() await hass.async_block_till_done()

View File

@ -1,11 +1,14 @@
"""Test Device Tracker config entry things.""" """Test Device Tracker config entry things."""
from homeassistant.components.device_tracker import config_entry from homeassistant.components.device_tracker import DOMAIN, config_entry as ce
from homeassistant.helpers import device_registry as dr, entity_registry as er
from tests.common import MockConfigEntry
def test_tracker_entity(): def test_tracker_entity():
"""Test tracker entity.""" """Test tracker entity."""
class TestEntry(config_entry.TrackerEntity): class TestEntry(ce.TrackerEntity):
"""Mock tracker class.""" """Mock tracker class."""
should_poll = False should_poll = False
@ -17,3 +20,111 @@ def test_tracker_entity():
instance.should_poll = True instance.should_poll = True
assert not instance.force_update assert not instance.force_update
async def test_cleanup_legacy(hass, enable_custom_integrations):
"""Test we clean up devices created by old device tracker."""
dev_reg = dr.async_get(hass)
ent_reg = er.async_get(hass)
config_entry = MockConfigEntry(domain="test")
config_entry.add_to_hass(hass)
device1 = dev_reg.async_get_or_create(
config_entry_id=config_entry.entry_id, identifiers={(DOMAIN, "device1")}
)
device2 = dev_reg.async_get_or_create(
config_entry_id=config_entry.entry_id, identifiers={(DOMAIN, "device2")}
)
device3 = dev_reg.async_get_or_create(
config_entry_id=config_entry.entry_id, identifiers={(DOMAIN, "device3")}
)
# Device with light + device tracker entity
entity1a = ent_reg.async_get_or_create(
DOMAIN,
"test",
"entity1a-unique",
config_entry=config_entry,
device_id=device1.id,
)
entity1b = ent_reg.async_get_or_create(
"light",
"test",
"entity1b-unique",
config_entry=config_entry,
device_id=device1.id,
)
# Just device tracker entity
entity2a = ent_reg.async_get_or_create(
DOMAIN,
"test",
"entity2a-unique",
config_entry=config_entry,
device_id=device2.id,
)
# Device with no device tracker entities
entity3a = ent_reg.async_get_or_create(
"light",
"test",
"entity3a-unique",
config_entry=config_entry,
device_id=device3.id,
)
# Device tracker but no device
entity4a = ent_reg.async_get_or_create(
DOMAIN,
"test",
"entity4a-unique",
config_entry=config_entry,
)
# Completely different entity
entity5a = ent_reg.async_get_or_create(
"light",
"test",
"entity4a-unique",
config_entry=config_entry,
)
await hass.config_entries.async_forward_entry_setup(config_entry, DOMAIN)
await hass.async_block_till_done()
for entity in (entity1a, entity1b, entity3a, entity4a, entity5a):
assert ent_reg.async_get(entity.entity_id) is not None
# We've removed device so device ID cleared
assert ent_reg.async_get(entity2a.entity_id).device_id is None
# Removed because only had device tracker entity
assert dev_reg.async_get(device2.id) is None
async def test_register_mac(hass):
"""Test registering a mac."""
dev_reg = dr.async_get(hass)
ent_reg = er.async_get(hass)
config_entry = MockConfigEntry(domain="test")
config_entry.add_to_hass(hass)
mac1 = "12:34:56:AB:CD:EF"
entity_entry_1 = ent_reg.async_get_or_create(
"device_tracker",
"test",
mac1 + "yo1",
original_name="name 1",
config_entry=config_entry,
disabled_by=er.RegistryEntryDisabler.INTEGRATION,
)
ce._async_register_mac(hass, "test", mac1, mac1 + "yo1")
dev_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, mac1)},
)
await hass.async_block_till_done()
entity_entry_1 = ent_reg.async_get(entity_entry_1.entity_id)
assert entity_entry_1.disabled_by is None

View File

@ -14,25 +14,33 @@ from homeassistant.components.device_tracker.const import (
SOURCE_TYPE_ROUTER, SOURCE_TYPE_ROUTER,
) )
from homeassistant.const import ATTR_BATTERY_LEVEL, STATE_HOME, STATE_NOT_HOME from homeassistant.const import ATTR_BATTERY_LEVEL, STATE_HOME, STATE_NOT_HOME
from homeassistant.helpers import device_registry as dr
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
async def test_scanner_entity_device_tracker(hass, enable_custom_integrations): async def test_scanner_entity_device_tracker(hass, enable_custom_integrations):
"""Test ScannerEntity based device tracker.""" """Test ScannerEntity based device tracker."""
# Make device tied to other integration so device tracker entities get enabled
dr.async_get(hass).async_get_or_create(
name="Device from other integration",
config_entry_id=MockConfigEntry().entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, "ad:de:ef:be:ed:fe")},
)
config_entry = MockConfigEntry(domain="test") config_entry = MockConfigEntry(domain="test")
config_entry.add_to_hass(hass) config_entry.add_to_hass(hass)
await hass.config_entries.async_forward_entry_setup(config_entry, DOMAIN) await hass.config_entries.async_forward_entry_setup(config_entry, DOMAIN)
await hass.async_block_till_done() await hass.async_block_till_done()
entity_id = "device_tracker.unnamed_device" entity_id = "device_tracker.test_ad_de_ef_be_ed_fe"
entity_state = hass.states.get(entity_id) entity_state = hass.states.get(entity_id)
assert entity_state.attributes == { assert entity_state.attributes == {
ATTR_SOURCE_TYPE: SOURCE_TYPE_ROUTER, ATTR_SOURCE_TYPE: SOURCE_TYPE_ROUTER,
ATTR_BATTERY_LEVEL: 100, ATTR_BATTERY_LEVEL: 100,
ATTR_IP: "0.0.0.0", ATTR_IP: "0.0.0.0",
ATTR_MAC: "ad:de:ef:be:ed:fe:", ATTR_MAC: "ad:de:ef:be:ed:fe",
ATTR_HOST_NAME: "test.hostname.org", ATTR_HOST_NAME: "test.hostname.org",
} }
assert entity_state.state == STATE_NOT_HOME assert entity_state.state == STATE_NOT_HOME

View File

@ -3,6 +3,8 @@ from unittest.mock import AsyncMock, patch
import pytest import pytest
from homeassistant.helpers import device_registry as dr
from .const import ( from .const import (
DATA_CALL_GET_CALLS_LOG, DATA_CALL_GET_CALLS_LOG,
DATA_CONNECTION_GET_STATUS, DATA_CONNECTION_GET_STATUS,
@ -12,6 +14,8 @@ from .const import (
WIFI_GET_GLOBAL_CONFIG, WIFI_GET_GLOBAL_CONFIG,
) )
from tests.common import MockConfigEntry
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_path(): def mock_path():
@ -20,8 +24,30 @@ def mock_path():
yield yield
@pytest.fixture
def mock_device_registry_devices(hass):
"""Create device registry devices so the device tracker entities are enabled."""
dev_reg = dr.async_get(hass)
config_entry = MockConfigEntry(domain="something_else")
for idx, device in enumerate(
(
"68:A3:78:00:00:00",
"8C:97:EA:00:00:00",
"DE:00:B0:00:00:00",
"DC:00:B0:00:00:00",
"5E:65:55:00:00:00",
)
):
dev_reg.async_get_or_create(
name=f"Device {idx}",
config_entry_id=config_entry.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, device)},
)
@pytest.fixture(name="router") @pytest.fixture(name="router")
def mock_router(): def mock_router(mock_device_registry_devices):
"""Mock a successful connection.""" """Mock a successful connection."""
with patch("homeassistant.components.freebox.router.Freepybox") as service_mock: with patch("homeassistant.components.freebox.router.Freepybox") as service_mock:
instance = service_mock.return_value instance = service_mock.return_value

View File

@ -1,9 +1,11 @@
"""The tests for the Mikrotik device tracker platform.""" """The tests for the Mikrotik device tracker platform."""
from datetime import timedelta from datetime import timedelta
import pytest
from homeassistant.components import mikrotik from homeassistant.components import mikrotik
import homeassistant.components.device_tracker as device_tracker import homeassistant.components.device_tracker as device_tracker
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -15,6 +17,25 @@ from tests.common import MockConfigEntry, patch
DEFAULT_DETECTION_TIME = timedelta(seconds=300) DEFAULT_DETECTION_TIME = timedelta(seconds=300)
@pytest.fixture
def mock_device_registry_devices(hass):
"""Create device registry devices so the device tracker entities are enabled."""
dev_reg = dr.async_get(hass)
config_entry = MockConfigEntry(domain="something_else")
for idx, device in enumerate(
(
"00:00:00:00:00:01",
"00:00:00:00:00:02",
)
):
dev_reg.async_get_or_create(
name=f"Device {idx}",
config_entry_id=config_entry.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, device)},
)
def mock_command(self, cmd, params=None): def mock_command(self, cmd, params=None):
"""Mock the Mikrotik command method.""" """Mock the Mikrotik command method."""
if cmd == mikrotik.const.MIKROTIK_SERVICES[mikrotik.const.IS_WIRELESS]: if cmd == mikrotik.const.MIKROTIK_SERVICES[mikrotik.const.IS_WIRELESS]:
@ -39,7 +60,9 @@ async def test_platform_manually_configured(hass):
assert mikrotik.DOMAIN not in hass.data assert mikrotik.DOMAIN not in hass.data
async def test_device_trackers(hass, legacy_patchable_time): async def test_device_trackers(
hass, legacy_patchable_time, mock_device_registry_devices
):
"""Test device_trackers created by mikrotik.""" """Test device_trackers created by mikrotik."""
# test devices are added from wireless list only # test devices are added from wireless list only

View File

@ -16,6 +16,7 @@ from homeassistant.components.ruckus_unleashed.const import (
API_VERSION, API_VERSION,
) )
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.helpers import device_registry as dr
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -68,6 +69,13 @@ def mock_config_entry() -> MockConfigEntry:
async def init_integration(hass) -> MockConfigEntry: async def init_integration(hass) -> MockConfigEntry:
"""Set up the Ruckus Unleashed integration in Home Assistant.""" """Set up the Ruckus Unleashed integration in Home Assistant."""
entry = mock_config_entry() entry = mock_config_entry()
entry.add_to_hass(hass)
# Make device tied to other integration so device tracker entities get enabled
dr.async_get(hass).async_get_or_create(
name="Device from other integration",
config_entry_id=MockConfigEntry().entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, TEST_CLIENT[API_MAC])},
)
with patch( with patch(
"homeassistant.components.ruckus_unleashed.Ruckus.connect", "homeassistant.components.ruckus_unleashed.Ruckus.connect",
return_value=None, return_value=None,
@ -86,7 +94,6 @@ async def init_integration(hass) -> MockConfigEntry:
TEST_CLIENT[API_MAC]: TEST_CLIENT, TEST_CLIENT[API_MAC]: TEST_CLIENT,
}, },
): ):
entry.add_to_hass(hass)
await hass.config_entries.async_setup(entry.entry_id) await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()

View File

@ -3,10 +3,8 @@ from datetime import timedelta
from unittest.mock import patch from unittest.mock import patch
from homeassistant.components.ruckus_unleashed import API_MAC, DOMAIN from homeassistant.components.ruckus_unleashed import API_MAC, DOMAIN
from homeassistant.components.ruckus_unleashed.const import API_AP, API_ID, API_NAME
from homeassistant.const import STATE_HOME, STATE_NOT_HOME, STATE_UNAVAILABLE from homeassistant.const import STATE_HOME, STATE_NOT_HOME, STATE_UNAVAILABLE
from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC
from homeassistant.util import utcnow from homeassistant.util import utcnow
from tests.common import async_fire_time_changed from tests.common import async_fire_time_changed
@ -112,24 +110,3 @@ async def test_restoring_clients(hass):
device = hass.states.get(TEST_CLIENT_ENTITY_ID) device = hass.states.get(TEST_CLIENT_ENTITY_ID)
assert device is not None assert device is not None
assert device.state == STATE_NOT_HOME assert device.state == STATE_NOT_HOME
async def test_client_device_setup(hass):
"""Test a client device is created."""
await init_integration(hass)
router_info = DEFAULT_AP_INFO[API_AP][API_ID]["1"]
device_registry = dr.async_get(hass)
client_device = device_registry.async_get_device(
identifiers={},
connections={(CONNECTION_NETWORK_MAC, TEST_CLIENT[API_MAC])},
)
router_device = device_registry.async_get_device(
identifiers={(CONNECTION_NETWORK_MAC, router_info[API_MAC])},
connections={(CONNECTION_NETWORK_MAC, router_info[API_MAC])},
)
assert client_device
assert client_device.name == TEST_CLIENT[API_NAME]
assert client_device.via_device_id == router_device.id

View File

@ -6,6 +6,10 @@ from unittest.mock import patch
from aiounifi.websocket import SIGNAL_CONNECTION_STATE, SIGNAL_DATA from aiounifi.websocket import SIGNAL_CONNECTION_STATE, SIGNAL_DATA
import pytest import pytest
from homeassistant.helpers import device_registry as dr
from tests.common import MockConfigEntry
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_unifi_websocket(): def mock_unifi_websocket():
@ -34,3 +38,27 @@ def mock_discovery():
return_value=None, return_value=None,
) as mock: ) as mock:
yield mock yield mock
@pytest.fixture
def mock_device_registry(hass):
"""Mock device registry."""
dev_reg = dr.async_get(hass)
config_entry = MockConfigEntry(domain="something_else")
for idx, device in enumerate(
(
"00:00:00:00:00:01",
"00:00:00:00:00:02",
"00:00:00:00:00:03",
"00:00:00:00:00:04",
"00:00:00:00:00:05",
"00:00:00:00:01:01",
"00:00:00:00:02:02",
)
):
dev_reg.async_get_or_create(
name=f"Device {idx}",
config_entry_id=config_entry.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, device)},
)

View File

@ -346,7 +346,9 @@ async def test_reset_fails(hass, aioclient_mock):
assert result is False assert result is False
async def test_connection_state_signalling(hass, aioclient_mock, mock_unifi_websocket): async def test_connection_state_signalling(
hass, aioclient_mock, mock_unifi_websocket, mock_device_registry
):
"""Verify connection statesignalling and connection state are working.""" """Verify connection statesignalling and connection state are working."""
client = { client = {
"hostname": "client", "hostname": "client",

View File

@ -38,7 +38,9 @@ async def test_no_entities(hass, aioclient_mock):
assert len(hass.states.async_entity_ids(TRACKER_DOMAIN)) == 0 assert len(hass.states.async_entity_ids(TRACKER_DOMAIN)) == 0
async def test_tracked_wireless_clients(hass, aioclient_mock, mock_unifi_websocket): async def test_tracked_wireless_clients(
hass, aioclient_mock, mock_unifi_websocket, mock_device_registry
):
"""Verify tracking of wireless clients.""" """Verify tracking of wireless clients."""
client = { client = {
"ap_mac": "00:00:00:00:02:01", "ap_mac": "00:00:00:00:02:01",
@ -157,7 +159,9 @@ async def test_tracked_wireless_clients(hass, aioclient_mock, mock_unifi_websock
assert hass.states.get("device_tracker.client").state == STATE_HOME assert hass.states.get("device_tracker.client").state == STATE_HOME
async def test_tracked_clients(hass, aioclient_mock, mock_unifi_websocket): async def test_tracked_clients(
hass, aioclient_mock, mock_unifi_websocket, mock_device_registry
):
"""Test the update_items function with some clients.""" """Test the update_items function with some clients."""
client_1 = { client_1 = {
"ap_mac": "00:00:00:00:02:01", "ap_mac": "00:00:00:00:02:01",
@ -234,7 +238,9 @@ async def test_tracked_clients(hass, aioclient_mock, mock_unifi_websocket):
assert hass.states.get("device_tracker.client_1").state == STATE_HOME assert hass.states.get("device_tracker.client_1").state == STATE_HOME
async def test_tracked_devices(hass, aioclient_mock, mock_unifi_websocket): async def test_tracked_devices(
hass, aioclient_mock, mock_unifi_websocket, mock_device_registry
):
"""Test the update_items function with some devices.""" """Test the update_items function with some devices."""
device_1 = { device_1 = {
"board_rev": 3, "board_rev": 3,
@ -321,45 +327,10 @@ async def test_tracked_devices(hass, aioclient_mock, mock_unifi_websocket):
assert hass.states.get("device_tracker.device_1").state == STATE_UNAVAILABLE assert hass.states.get("device_tracker.device_1").state == STATE_UNAVAILABLE
assert hass.states.get("device_tracker.device_2").state == STATE_HOME assert hass.states.get("device_tracker.device_2").state == STATE_HOME
# Update device registry when device is upgraded
event = { async def test_remove_clients(
"_id": "5eae7fe02ab79c00f9d38960", hass, aioclient_mock, mock_unifi_websocket, mock_device_registry
"datetime": "2020-05-09T20:06:37Z", ):
"key": "EVT_SW_Upgraded",
"msg": f'Switch[{device_2["mac"]}] was upgraded from "{device_2["version"]}" to "4.3.13.11253"',
"subsystem": "lan",
"sw": device_2["mac"],
"sw_name": device_2["name"],
"time": 1589054797635,
"version_from": {device_2["version"]},
"version_to": "4.3.13.11253",
}
device_2["version"] = event["version_to"]
mock_unifi_websocket(
data={
"meta": {"message": MESSAGE_DEVICE},
"data": [device_2],
}
)
mock_unifi_websocket(
data={
"meta": {"message": MESSAGE_EVENT},
"data": [event],
}
)
await hass.async_block_till_done()
# Verify device registry has been updated
entity_registry = er.async_get(hass)
entry = entity_registry.async_get("device_tracker.device_2")
device_registry = dr.async_get(hass)
device = device_registry.async_get(entry.device_id)
assert device.sw_version == event["version_to"]
async def test_remove_clients(hass, aioclient_mock, mock_unifi_websocket):
"""Test the remove_items function with some clients.""" """Test the remove_items function with some clients."""
client_1 = { client_1 = {
"essid": "ssid", "essid": "ssid",
@ -399,7 +370,7 @@ async def test_remove_clients(hass, aioclient_mock, mock_unifi_websocket):
async def test_remove_client_but_keep_device_entry( async def test_remove_client_but_keep_device_entry(
hass, aioclient_mock, mock_unifi_websocket hass, aioclient_mock, mock_unifi_websocket, mock_device_registry
): ):
"""Test that unifi entity base remove config entry id from a multi integration device registry entry.""" """Test that unifi entity base remove config entry id from a multi integration device registry entry."""
client_1 = { client_1 = {
@ -424,7 +395,7 @@ async def test_remove_client_but_keep_device_entry(
"unique_id", "unique_id",
device_id=device_entry.id, device_id=device_entry.id,
) )
assert len(device_entry.config_entries) == 2 assert len(device_entry.config_entries) == 3
mock_unifi_websocket( mock_unifi_websocket(
data={ data={
@ -438,10 +409,12 @@ async def test_remove_client_but_keep_device_entry(
assert len(hass.states.async_entity_ids(TRACKER_DOMAIN)) == 0 assert len(hass.states.async_entity_ids(TRACKER_DOMAIN)) == 0
device_entry = device_registry.async_get(other_entity.device_id) device_entry = device_registry.async_get(other_entity.device_id)
assert len(device_entry.config_entries) == 1 assert len(device_entry.config_entries) == 2
async def test_controller_state_change(hass, aioclient_mock, mock_unifi_websocket): async def test_controller_state_change(
hass, aioclient_mock, mock_unifi_websocket, mock_device_registry
):
"""Verify entities state reflect on controller becoming unavailable.""" """Verify entities state reflect on controller becoming unavailable."""
client = { client = {
"essid": "ssid", "essid": "ssid",
@ -495,7 +468,7 @@ async def test_controller_state_change(hass, aioclient_mock, mock_unifi_websocke
async def test_controller_state_change_client_to_listen_on_all_state_changes( async def test_controller_state_change_client_to_listen_on_all_state_changes(
hass, aioclient_mock, mock_unifi_websocket hass, aioclient_mock, mock_unifi_websocket, mock_device_registry
): ):
"""Verify entities state reflect on controller becoming unavailable.""" """Verify entities state reflect on controller becoming unavailable."""
client = { client = {
@ -579,7 +552,7 @@ async def test_controller_state_change_client_to_listen_on_all_state_changes(
assert hass.states.get("device_tracker.client").state == STATE_HOME assert hass.states.get("device_tracker.client").state == STATE_HOME
async def test_option_track_clients(hass, aioclient_mock): async def test_option_track_clients(hass, aioclient_mock, mock_device_registry):
"""Test the tracking of clients can be turned off.""" """Test the tracking of clients can be turned off."""
wireless_client = { wireless_client = {
"essid": "ssid", "essid": "ssid",
@ -645,7 +618,7 @@ async def test_option_track_clients(hass, aioclient_mock):
assert hass.states.get("device_tracker.device") assert hass.states.get("device_tracker.device")
async def test_option_track_wired_clients(hass, aioclient_mock): async def test_option_track_wired_clients(hass, aioclient_mock, mock_device_registry):
"""Test the tracking of wired clients can be turned off.""" """Test the tracking of wired clients can be turned off."""
wireless_client = { wireless_client = {
"essid": "ssid", "essid": "ssid",
@ -711,7 +684,7 @@ async def test_option_track_wired_clients(hass, aioclient_mock):
assert hass.states.get("device_tracker.device") assert hass.states.get("device_tracker.device")
async def test_option_track_devices(hass, aioclient_mock): async def test_option_track_devices(hass, aioclient_mock, mock_device_registry):
"""Test the tracking of devices can be turned off.""" """Test the tracking of devices can be turned off."""
client = { client = {
"hostname": "client", "hostname": "client",
@ -764,7 +737,9 @@ async def test_option_track_devices(hass, aioclient_mock):
assert hass.states.get("device_tracker.device") assert hass.states.get("device_tracker.device")
async def test_option_ssid_filter(hass, aioclient_mock, mock_unifi_websocket): async def test_option_ssid_filter(
hass, aioclient_mock, mock_unifi_websocket, mock_device_registry
):
"""Test the SSID filter works. """Test the SSID filter works.
Client will travel from a supported SSID to an unsupported ssid. Client will travel from a supported SSID to an unsupported ssid.
@ -896,7 +871,7 @@ async def test_option_ssid_filter(hass, aioclient_mock, mock_unifi_websocket):
async def test_wireless_client_go_wired_issue( async def test_wireless_client_go_wired_issue(
hass, aioclient_mock, mock_unifi_websocket hass, aioclient_mock, mock_unifi_websocket, mock_device_registry
): ):
"""Test the solution to catch wireless device go wired UniFi issue. """Test the solution to catch wireless device go wired UniFi issue.
@ -979,7 +954,9 @@ async def test_wireless_client_go_wired_issue(
assert client_state.attributes["is_wired"] is False assert client_state.attributes["is_wired"] is False
async def test_option_ignore_wired_bug(hass, aioclient_mock, mock_unifi_websocket): async def test_option_ignore_wired_bug(
hass, aioclient_mock, mock_unifi_websocket, mock_device_registry
):
"""Test option to ignore wired bug.""" """Test option to ignore wired bug."""
client = { client = {
"ap_mac": "00:00:00:00:02:01", "ap_mac": "00:00:00:00:02:01",
@ -1061,7 +1038,7 @@ async def test_option_ignore_wired_bug(hass, aioclient_mock, mock_unifi_websocke
assert client_state.attributes["is_wired"] is False assert client_state.attributes["is_wired"] is False
async def test_restoring_client(hass, aioclient_mock): async def test_restoring_client(hass, aioclient_mock, mock_device_registry):
"""Verify clients are restored from clients_all if they ever was registered to entity registry.""" """Verify clients are restored from clients_all if they ever was registered to entity registry."""
client = { client = {
"hostname": "client", "hostname": "client",
@ -1115,7 +1092,7 @@ async def test_restoring_client(hass, aioclient_mock):
assert not hass.states.get("device_tracker.not_restored") assert not hass.states.get("device_tracker.not_restored")
async def test_dont_track_clients(hass, aioclient_mock): async def test_dont_track_clients(hass, aioclient_mock, mock_device_registry):
"""Test don't track clients config works.""" """Test don't track clients config works."""
wireless_client = { wireless_client = {
"essid": "ssid", "essid": "ssid",
@ -1175,7 +1152,7 @@ async def test_dont_track_clients(hass, aioclient_mock):
assert hass.states.get("device_tracker.device") assert hass.states.get("device_tracker.device")
async def test_dont_track_devices(hass, aioclient_mock): async def test_dont_track_devices(hass, aioclient_mock, mock_device_registry):
"""Test don't track devices config works.""" """Test don't track devices config works."""
client = { client = {
"hostname": "client", "hostname": "client",
@ -1224,7 +1201,7 @@ async def test_dont_track_devices(hass, aioclient_mock):
assert hass.states.get("device_tracker.device") assert hass.states.get("device_tracker.device")
async def test_dont_track_wired_clients(hass, aioclient_mock): async def test_dont_track_wired_clients(hass, aioclient_mock, mock_device_registry):
"""Test don't track wired clients config works.""" """Test don't track wired clients config works."""
wireless_client = { wireless_client = {
"essid": "ssid", "essid": "ssid",

View File

@ -5,7 +5,6 @@ from unittest.mock import patch
from aiounifi.controller import MESSAGE_CLIENT_REMOVED, MESSAGE_EVENT from aiounifi.controller import MESSAGE_CLIENT_REMOVED, MESSAGE_EVENT
from homeassistant import config_entries, core from homeassistant import config_entries, core
from homeassistant.components.device_tracker import DOMAIN as TRACKER_DOMAIN
from homeassistant.components.switch import DOMAIN as SWITCH_DOMAIN from homeassistant.components.switch import DOMAIN as SWITCH_DOMAIN
from homeassistant.components.unifi.const import ( from homeassistant.components.unifi.const import (
CONF_BLOCK_CLIENT, CONF_BLOCK_CLIENT,
@ -784,8 +783,6 @@ async def test_ignore_multiple_poe_clients_on_same_port(hass, aioclient_mock):
devices_response=[DEVICE_1], devices_response=[DEVICE_1],
) )
assert len(hass.states.async_entity_ids(TRACKER_DOMAIN)) == 3
switch_1 = hass.states.get("switch.poe_client_1") switch_1 = hass.states.get("switch.poe_client_1")
switch_2 = hass.states.get("switch.poe_client_2") switch_2 = hass.states.get("switch.poe_client_2")
assert switch_1 is None assert switch_1 is None

View File

@ -18,7 +18,7 @@ class MockScannerEntity(ScannerEntity):
self.connected = False self.connected = False
self._hostname = "test.hostname.org" self._hostname = "test.hostname.org"
self._ip_address = "0.0.0.0" self._ip_address = "0.0.0.0"
self._mac_address = "ad:de:ef:be:ed:fe:" self._mac_address = "ad:de:ef:be:ed:fe"
@property @property
def source_type(self): def source_type(self):