diff --git a/homeassistant/components/unifi/controller.py b/homeassistant/components/unifi/controller.py index f2fd5760471..33e7fc3836b 100644 --- a/homeassistant/components/unifi/controller.py +++ b/homeassistant/components/unifi/controller.py @@ -72,6 +72,8 @@ class UniFiController: self._site_name = None self._site_role = None + self.entities = {} + @property def controller_id(self): """Return the controller ID.""" diff --git a/homeassistant/components/unifi/device_tracker.py b/homeassistant/components/unifi/device_tracker.py index 5e50a75409f..2ef5161a94e 100644 --- a/homeassistant/components/unifi/device_tracker.py +++ b/homeassistant/components/unifi/device_tracker.py @@ -1,10 +1,11 @@ """Track devices using UniFi controllers.""" import logging -from homeassistant.components.device_tracker import DOMAIN as DEVICE_TRACKER_DOMAIN +from homeassistant.components.device_tracker import DOMAIN from homeassistant.components.device_tracker.config_entry import ScannerEntity from homeassistant.components.device_tracker.const import SOURCE_TYPE_ROUTER from homeassistant.components.unifi.config_flow import get_controller_from_config_entry +from homeassistant.components.unifi.unifi_entity_base import UniFiBase from homeassistant.core import callback from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC from homeassistant.helpers.dispatcher import async_dispatcher_connect @@ -38,30 +39,26 @@ CLIENT_STATIC_ATTRIBUTES = [ "oui", ] +CLIENT_TRACKER = "client" +DEVICE_TRACKER = "device" + async def async_setup_entry(hass, config_entry, async_add_entities): """Set up device tracker for UniFi component.""" controller = get_controller_from_config_entry(hass, config_entry) - tracked = {} - - option_track_clients = controller.option_track_clients - option_track_devices = controller.option_track_devices - option_track_wired_clients = controller.option_track_wired_clients - option_ssid_filter = controller.option_ssid_filter - - entity_registry = await hass.helpers.entity_registry.async_get_registry() + controller.entities[DOMAIN] = {CLIENT_TRACKER: set(), DEVICE_TRACKER: set()} # Restore clients that is not a part of active clients list. + entity_registry = await hass.helpers.entity_registry.async_get_registry() for entity in entity_registry.entities.values(): if ( entity.config_entry_id == config_entry.entry_id - and entity.domain == DEVICE_TRACKER_DOMAIN + and entity.domain == DOMAIN and "-" in entity.unique_id ): mac, _ = entity.unique_id.split("-", 1) - if mac in controller.api.clients or mac not in controller.api.clients_all: continue @@ -74,99 +71,19 @@ async def async_setup_entry(hass, config_entry, async_add_entities): @callback def items_added(): """Update the values of the controller.""" - nonlocal option_track_clients - nonlocal option_track_devices + if controller.option_track_clients or controller.option_track_devices: + add_entities(controller, async_add_entities) - if not option_track_clients and not option_track_devices: - return - - add_entities(controller, async_add_entities, tracked) - - controller.listeners.append( - async_dispatcher_connect(hass, controller.signal_update, items_added) - ) - - @callback - def items_removed(mac_addresses: set) -> None: - """Items have been removed from the controller.""" - remove_entities(controller, mac_addresses, tracked, entity_registry) - - controller.listeners.append( - async_dispatcher_connect(hass, controller.signal_remove, items_removed) - ) - - @callback - def options_updated(): - """Manage entities affected by config entry options.""" - nonlocal option_track_clients - nonlocal option_track_devices - nonlocal option_track_wired_clients - nonlocal option_ssid_filter - - update = False - remove = set() - - for current_option, config_entry_option, tracker_class in ( - (option_track_clients, controller.option_track_clients, UniFiClientTracker), - (option_track_devices, controller.option_track_devices, UniFiDeviceTracker), - ): - if current_option == config_entry_option: - continue - - if config_entry_option: - update = True - else: - for mac, entity in tracked.items(): - if isinstance(entity, tracker_class): - remove.add(mac) - - if ( - controller.option_track_clients - and option_track_wired_clients != controller.option_track_wired_clients - ): - - if controller.option_track_wired_clients: - update = True - else: - for mac, entity in tracked.items(): - if isinstance(entity, UniFiClientTracker) and entity.is_wired: - remove.add(mac) - - if option_ssid_filter != controller.option_ssid_filter: - update = True - - if controller.option_ssid_filter: - for mac, entity in tracked.items(): - if ( - isinstance(entity, UniFiClientTracker) - and not entity.is_wired - and entity.client.essid not in controller.option_ssid_filter - ): - remove.add(mac) - - option_track_clients = controller.option_track_clients - option_track_devices = controller.option_track_devices - option_track_wired_clients = controller.option_track_wired_clients - option_ssid_filter = controller.option_ssid_filter - - remove_entities(controller, remove, tracked, entity_registry) - - if update: - items_added() - - controller.listeners.append( - async_dispatcher_connect( - hass, controller.signal_options_update, options_updated - ) - ) + for signal in (controller.signal_update, controller.signal_options_update): + controller.listeners.append(async_dispatcher_connect(hass, signal, items_added)) items_added() @callback -def add_entities(controller, async_add_entities, tracked): +def add_entities(controller, async_add_entities): """Add new tracker entities from the controller.""" - new_tracked = [] + trackers = [] for items, tracker_class, track in ( (controller.api.clients, UniFiClientTracker, controller.option_track_clients), @@ -175,46 +92,36 @@ def add_entities(controller, async_add_entities, tracked): if not track: continue - for item_id in items: + for mac in items: - if item_id in tracked: + if mac in controller.entities[DOMAIN][tracker_class.TYPE]: continue + item = items[mac] + if tracker_class is UniFiClientTracker: - client = items[item_id] - if not controller.option_track_wired_clients and client.is_wired: - continue + if item.is_wired: + if not controller.option_track_wired_clients: + continue + else: + if ( + controller.option_ssid_filter + and item.essid not in controller.option_ssid_filter + ): + continue - if ( - controller.option_ssid_filter - and not client.is_wired - and client.essid not in controller.option_ssid_filter - ): - continue + trackers.append(tracker_class(item, controller)) - tracked[item_id] = tracker_class(items[item_id], controller) - new_tracked.append(tracked[item_id]) - - if new_tracked: - async_add_entities(new_tracked) - - -@callback -def remove_entities(controller, mac_addresses, tracked, entity_registry): - """Remove select tracked entities.""" - for mac in mac_addresses: - - if mac not in tracked: - continue - - entity = tracked.pop(mac) - controller.hass.async_create_task(entity.async_remove()) + if trackers: + async_add_entities(trackers) class UniFiClientTracker(UniFiClient, ScannerEntity): """Representation of a network client.""" + TYPE = CLIENT_TRACKER + def __init__(self, client, controller): """Set up tracked client.""" super().__init__(client, controller) @@ -315,34 +222,52 @@ class UniFiClientTracker(UniFiClient, ScannerEntity): return attributes + async def options_updated(self) -> None: + """Config entry options are updated, remove entity if option is disabled.""" + if not self.controller.option_track_clients: + await self.async_remove() -class UniFiDeviceTracker(ScannerEntity): + elif self.is_wired: + if not self.controller.option_track_wired_clients: + await self.async_remove() + else: + if ( + self.controller.option_ssid_filter + and self.client.essid not in self.controller.option_ssid_filter + ): + await self.async_remove() + + +class UniFiDeviceTracker(UniFiBase, ScannerEntity): """Representation of a network infrastructure device.""" + TYPE = DEVICE_TRACKER + def __init__(self, device, controller): """Set up tracked device.""" + super().__init__(controller) self.device = device - self.controller = controller + + @property + def mac(self): + """Return MAC of device.""" + return self.device.mac async def async_added_to_hass(self): """Subscribe to device events.""" + await super().async_added_to_hass() LOGGER.debug("New device %s (%s)", self.entity_id, self.device.mac) self.device.register_callback(self.async_update_callback) - self.async_on_remove( - async_dispatcher_connect( - self.hass, self.controller.signal_reachable, self.async_update_callback - ) - ) async def async_will_remove_from_hass(self) -> None: """Disconnect device object when removed.""" + await super().async_will_remove_from_hass() self.device.remove_callback(self.async_update_callback) @callback def async_update_callback(self): """Update the sensor's state.""" LOGGER.debug("Updating device %s (%s)", self.entity_id, self.device.mac) - self.async_write_ha_state() @property @@ -410,7 +335,7 @@ class UniFiDeviceTracker(ScannerEntity): return attributes - @property - def should_poll(self): - """No polling needed.""" - return True + async def options_updated(self) -> None: + """Config entry options are updated, remove entity if option is disabled.""" + if not self.controller.option_track_devices: + await self.async_remove() diff --git a/homeassistant/components/unifi/sensor.py b/homeassistant/components/unifi/sensor.py index 0eff1eeea35..32f92b4def1 100644 --- a/homeassistant/components/unifi/sensor.py +++ b/homeassistant/components/unifi/sensor.py @@ -1,6 +1,7 @@ """Support for bandwidth sensors with UniFi clients.""" import logging +from homeassistant.components.sensor import DOMAIN from homeassistant.components.unifi.config_flow import get_controller_from_config_entry from homeassistant.const import DATA_MEGABYTES from homeassistant.core import callback @@ -10,6 +11,9 @@ from .unifi_client import UniFiClient LOGGER = logging.getLogger(__name__) +RX_SENSOR = "rx" +TX_SENSOR = "tx" + async def async_setup_platform(hass, config, async_add_entities, discovery_info=None): """Sensor platform doesn't support configuration through configuration.yaml.""" @@ -18,144 +22,74 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info= async def async_setup_entry(hass, config_entry, async_add_entities): """Set up sensors for UniFi integration.""" controller = get_controller_from_config_entry(hass, config_entry) - sensors = {} - - option_allow_bandwidth_sensors = controller.option_allow_bandwidth_sensors - - entity_registry = await hass.helpers.entity_registry.async_get_registry() + controller.entities[DOMAIN] = {RX_SENSOR: set(), TX_SENSOR: set()} @callback def items_added(): """Update the values of the controller.""" - nonlocal option_allow_bandwidth_sensors + if controller.option_allow_bandwidth_sensors: + add_entities(controller, async_add_entities) - if not option_allow_bandwidth_sensors: - return - - add_entities(controller, async_add_entities, sensors) - - controller.listeners.append( - async_dispatcher_connect(hass, controller.signal_update, items_added) - ) - - @callback - def items_removed(mac_addresses: set) -> None: - """Items have been removed from the controller.""" - remove_entities(controller, mac_addresses, sensors, entity_registry) - - controller.listeners.append( - async_dispatcher_connect(hass, controller.signal_remove, items_removed) - ) - - @callback - def options_updated(): - """Update the values of the controller.""" - nonlocal option_allow_bandwidth_sensors - - if option_allow_bandwidth_sensors != controller.option_allow_bandwidth_sensors: - option_allow_bandwidth_sensors = controller.option_allow_bandwidth_sensors - - if option_allow_bandwidth_sensors: - items_added() - - else: - for sensor in sensors.values(): - hass.async_create_task(sensor.async_remove()) - - sensors.clear() - - controller.listeners.append( - async_dispatcher_connect( - hass, controller.signal_options_update, options_updated - ) - ) + for signal in (controller.signal_update, controller.signal_options_update): + controller.listeners.append(async_dispatcher_connect(hass, signal, items_added)) items_added() @callback -def add_entities(controller, async_add_entities, sensors): +def add_entities(controller, async_add_entities): """Add new sensor entities from the controller.""" - new_sensors = [] + sensors = [] - for client_id in controller.api.clients: - for direction, sensor_class in ( - ("rx", UniFiRxBandwidthSensor), - ("tx", UniFiTxBandwidthSensor), - ): - item_id = f"{direction}-{client_id}" + for mac in controller.api.clients: + for sensor_class in (UniFiRxBandwidthSensor, UniFiTxBandwidthSensor): + if mac not in controller.entities[DOMAIN][sensor_class.TYPE]: + sensors.append(sensor_class(controller.api.clients[mac], controller)) - if item_id in sensors: - continue - - sensors[item_id] = sensor_class( - controller.api.clients[client_id], controller - ) - new_sensors.append(sensors[item_id]) - - if new_sensors: - async_add_entities(new_sensors) + if sensors: + async_add_entities(sensors) -@callback -def remove_entities(controller, mac_addresses, sensors, entity_registry): - """Remove select sensor entities.""" - for mac in mac_addresses: - - for direction in ("rx", "tx"): - item_id = f"{direction}-{mac}" - - if item_id not in sensors: - continue - - entity = sensors.pop(item_id) - controller.hass.async_create_task(entity.async_remove()) - - -class UniFiRxBandwidthSensor(UniFiClient): - """Receiving bandwidth sensor.""" +class UniFiBandwidthSensor(UniFiClient): + """UniFi bandwidth sensor base class.""" @property - def state(self): - """Return the state of the sensor.""" - if self._is_wired: - return self.client.wired_rx_bytes / 1000000 - return self.client.raw.get("rx_bytes", 0) / 1000000 - - @property - def name(self): + def name(self) -> str: """Return the name of the client.""" - name = self.client.name or self.client.hostname - return f"{name} RX" + return f"{super().name} {self.TYPE.upper()}" @property - def unique_id(self): - """Return a unique identifier for this bandwidth sensor.""" - return f"rx-{self.client.mac}" - - @property - def unit_of_measurement(self): + def unit_of_measurement(self) -> str: """Return the unit of measurement of this entity.""" return DATA_MEGABYTES + async def options_updated(self) -> None: + """Config entry options are updated, remove entity if option is disabled.""" + if not self.controller.option_allow_bandwidth_sensors: + await self.async_remove() -class UniFiTxBandwidthSensor(UniFiRxBandwidthSensor): - """Transmitting bandwidth sensor.""" + +class UniFiRxBandwidthSensor(UniFiBandwidthSensor): + """Receiving bandwidth sensor.""" + + TYPE = RX_SENSOR @property - def state(self): + def state(self) -> int: + """Return the state of the sensor.""" + if self._is_wired: + return self.client.wired_rx_bytes / 1000000 + return self.client.rx_bytes / 1000000 + + +class UniFiTxBandwidthSensor(UniFiBandwidthSensor): + """Transmitting bandwidth sensor.""" + + TYPE = TX_SENSOR + + @property + def state(self) -> int: """Return the state of the sensor.""" if self._is_wired: return self.client.wired_tx_bytes / 1000000 - return self.client.raw.get("tx_bytes", 0) / 1000000 - - @property - def name(self): - """Return the name of the client.""" - name = self.client.name or self.client.hostname - return f"{name} TX" - - @property - def unique_id(self): - """Return a unique identifier for this bandwidth sensor.""" - return f"tx-{self.client.mac}" + return self.client.tx_bytes / 1000000 diff --git a/homeassistant/components/unifi/switch.py b/homeassistant/components/unifi/switch.py index a0b7d865a1b..7257cafd2fc 100644 --- a/homeassistant/components/unifi/switch.py +++ b/homeassistant/components/unifi/switch.py @@ -1,7 +1,7 @@ """Support for devices connected to UniFi POE.""" import logging -from homeassistant.components.switch import SwitchDevice +from homeassistant.components.switch import DOMAIN, SwitchDevice from homeassistant.components.unifi.config_flow import get_controller_from_config_entry from homeassistant.core import callback from homeassistant.helpers.dispatcher import async_dispatcher_connect @@ -11,6 +11,9 @@ from .unifi_client import UniFiClient LOGGER = logging.getLogger(__name__) +BLOCK_SWITCH = "block" +POE_SWITCH = "poe" + async def async_setup_platform(hass, config, async_add_entities, discovery_info=None): """Component doesn't support configuration through configuration.yaml.""" @@ -22,24 +25,20 @@ async def async_setup_entry(hass, config_entry, async_add_entities): Switches are controlling network access and switch ports with POE. """ controller = get_controller_from_config_entry(hass, config_entry) + controller.entities[DOMAIN] = {BLOCK_SWITCH: set(), POE_SWITCH: set()} if controller.site_role != "admin": return - switches = {} switches_off = [] - option_block_clients = controller.option_block_clients - option_poe_clients = controller.option_poe_clients - - entity_registry = await hass.helpers.entity_registry.async_get_registry() - # Restore clients that is not a part of active clients list. + entity_registry = await hass.helpers.entity_registry.async_get_registry() for entity in entity_registry.entities.values(): if ( entity.config_entry_id == config_entry.entry_id - and entity.unique_id.startswith("poe-") + and entity.unique_id.startswith(f"{POE_SWITCH}-") ): _, mac = entity.unique_id.split("-", 1) @@ -57,110 +56,53 @@ async def async_setup_entry(hass, config_entry, async_add_entities): @callback def items_added(): """Update the values of the controller.""" - add_entities(controller, async_add_entities, switches, switches_off) + if controller.option_block_clients or controller.option_poe_clients: + add_entities(controller, async_add_entities, switches_off) - controller.listeners.append( - async_dispatcher_connect(hass, controller.signal_update, items_added) - ) - - @callback - def items_removed(mac_addresses: set) -> None: - """Items have been removed from the controller.""" - remove_entities(controller, mac_addresses, switches, entity_registry) - - controller.listeners.append( - async_dispatcher_connect(hass, controller.signal_remove, items_removed) - ) - - @callback - def options_updated(): - """Manage entities affected by config entry options.""" - nonlocal option_block_clients - nonlocal option_poe_clients - - update = set() - remove = set() - - if option_block_clients != controller.option_block_clients: - option_block_clients = controller.option_block_clients - - for block_client_id, entity in switches.items(): - if not isinstance(entity, UniFiBlockClientSwitch): - continue - - if entity.client.mac in option_block_clients: - update.add(block_client_id) - else: - remove.add(block_client_id) - - if option_poe_clients != controller.option_poe_clients: - option_poe_clients = controller.option_poe_clients - - if option_poe_clients: - update.add("poe_clients_enabled") - else: - for poe_client_id, entity in switches.items(): - if isinstance(entity, UniFiPOEClientSwitch): - remove.add(poe_client_id) - - for client_id in remove: - entity = switches.pop(client_id) - hass.async_create_task(entity.async_remove()) - - if len(update) != len(option_block_clients): - items_added() - - controller.listeners.append( - async_dispatcher_connect( - hass, controller.signal_options_update, options_updated - ) - ) + for signal in (controller.signal_update, controller.signal_options_update): + controller.listeners.append(async_dispatcher_connect(hass, signal, items_added)) items_added() switches_off.clear() @callback -def add_entities(controller, async_add_entities, switches, switches_off): +def add_entities(controller, async_add_entities, switches_off): """Add new switch entities from the controller.""" - new_switches = [] - devices = controller.api.devices + switches = [] - for client_id in controller.option_block_clients: + for mac in controller.option_block_clients: - client = None - block_client_id = f"block-{client_id}" - - if block_client_id in switches: + if mac in controller.entities[DOMAIN][BLOCK_SWITCH]: continue - if client_id in controller.api.clients: - client = controller.api.clients[client_id] + client = None - elif client_id in controller.api.clients_all: - client = controller.api.clients_all[client_id] + if mac in controller.api.clients: + client = controller.api.clients[mac] + + elif mac in controller.api.clients_all: + client = controller.api.clients_all[mac] if not client: continue - switches[block_client_id] = UniFiBlockClientSwitch(client, controller) - new_switches.append(switches[block_client_id]) + switches.append(UniFiBlockClientSwitch(client, controller)) if controller.option_poe_clients: - for client_id in controller.api.clients: + devices = controller.api.devices - poe_client_id = f"poe-{client_id}" + for mac in controller.api.clients: - if poe_client_id in switches: + poe_client_id = f"{POE_SWITCH}-{mac}" + + if mac in controller.entities[DOMAIN][POE_SWITCH]: continue - client = controller.api.clients[client_id] + client = controller.api.clients[mac] - if poe_client_id in switches_off: - pass - # Network device with active POE - elif ( - client_id in controller.wireless_clients + if poe_client_id not in switches_off and ( + mac in controller.wireless_clients or client.sw_mac not in devices or not devices[client.sw_mac].ports[client.sw_port].port_poe or not devices[client.sw_mac].ports[client.sw_port].poe_enable @@ -187,31 +129,17 @@ def add_entities(controller, async_add_entities, switches, switches_off): if multi_clients_on_port: continue - switches[poe_client_id] = UniFiPOEClientSwitch(client, controller) - new_switches.append(switches[poe_client_id]) + switches.append(UniFiPOEClientSwitch(client, controller)) - if new_switches: - async_add_entities(new_switches) - - -@callback -def remove_entities(controller, mac_addresses, switches, entity_registry): - """Remove select switch entities.""" - for mac in mac_addresses: - - for switch_type in ("block", "poe"): - item_id = f"{switch_type}-{mac}" - - if item_id not in switches: - continue - - entity = switches.pop(item_id) - controller.hass.async_create_task(entity.async_remove()) + if switches: + async_add_entities(switches) class UniFiPOEClientSwitch(UniFiClient, SwitchDevice, RestoreEntity): """Representation of a client that uses POE.""" + TYPE = POE_SWITCH + def __init__(self, client, controller): """Set up POE switch.""" super().__init__(client, controller) @@ -225,7 +153,6 @@ class UniFiPOEClientSwitch(UniFiClient, SwitchDevice, RestoreEntity): await super().async_added_to_hass() state = await self.async_get_last_state() - if state is None: return @@ -238,11 +165,6 @@ class UniFiPOEClientSwitch(UniFiClient, SwitchDevice, RestoreEntity): if not self.client.sw_port: self.client.raw["sw_port"] = state.attributes["port"] - @property - def unique_id(self): - """Return a unique identifier for this switch.""" - return f"poe-{self.client.mac}" - @property def is_on(self): """Return true if POE is active.""" @@ -301,14 +223,16 @@ class UniFiPOEClientSwitch(UniFiClient, SwitchDevice, RestoreEntity): self.client.sw_port, ) + async def options_updated(self) -> None: + """Config entry options are updated, remove entity if option is disabled.""" + if not self.controller.option_poe_clients: + await self.async_remove() + class UniFiBlockClientSwitch(UniFiClient, SwitchDevice): """Representation of a blockable client.""" - @property - def unique_id(self): - """Return a unique identifier for this switch.""" - return f"block-{self.client.mac}" + TYPE = BLOCK_SWITCH @property def is_on(self): @@ -329,3 +253,8 @@ class UniFiBlockClientSwitch(UniFiClient, SwitchDevice): if self.is_blocked: return "mdi:network-off" return "mdi:network" + + async def options_updated(self) -> None: + """Config entry options are updated, remove entity if option is disabled.""" + if self.client.mac not in self.controller.option_block_clients: + await self.async_remove() diff --git a/homeassistant/components/unifi/unifi_client.py b/homeassistant/components/unifi/unifi_client.py index a30dc21854d..97c53d27f55 100644 --- a/homeassistant/components/unifi/unifi_client.py +++ b/homeassistant/components/unifi/unifi_client.py @@ -15,10 +15,9 @@ from aiounifi.events import ( WIRELESS_CLIENT_UNBLOCKED, ) +from homeassistant.components.unifi.unifi_entity_base import UniFiBase from homeassistant.core import callback from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC -from homeassistant.helpers.dispatcher import async_dispatcher_connect -from homeassistant.helpers.entity import Entity LOGGER = logging.getLogger(__name__) @@ -32,31 +31,33 @@ WIRELESS_CLIENT = ( ) -class UniFiClient(Entity): +class UniFiClient(UniFiBase): """Base class for UniFi clients.""" def __init__(self, client, controller) -> None: """Set up client.""" + super().__init__(controller) self.client = client - self.controller = controller self._is_wired = self.client.mac not in controller.wireless_clients self.is_blocked = self.client.blocked self.wired_connection = None self.wireless_connection = None + @property + def mac(self): + """Return MAC of client.""" + return self.client.mac + async def async_added_to_hass(self) -> None: """Client entity created.""" + await super().async_added_to_hass() LOGGER.debug("New client %s (%s)", self.entity_id, self.client.mac) self.client.register_callback(self.async_update_callback) - self.async_on_remove( - async_dispatcher_connect( - self.hass, self.controller.signal_reachable, self.async_update_callback - ) - ) async def async_will_remove_from_hass(self) -> None: """Disconnect client object when removed.""" + await super().async_will_remove_from_hass() self.client.remove_callback(self.async_update_callback) @callback @@ -93,6 +94,11 @@ class UniFiClient(Entity): return self.client.is_wired return self._is_wired + @property + def unique_id(self): + """Return a unique identifier for this switch.""" + return f"{self.TYPE}-{self.client.mac}" + @property def name(self) -> str: """Return the name of the client.""" @@ -107,8 +113,3 @@ class UniFiClient(Entity): def device_info(self) -> dict: """Return a client description for device registry.""" return {"connections": {(CONNECTION_NETWORK_MAC, self.client.mac)}} - - @property - def should_poll(self) -> bool: - """No polling needed.""" - return True diff --git a/homeassistant/components/unifi/unifi_entity_base.py b/homeassistant/components/unifi/unifi_entity_base.py new file mode 100644 index 00000000000..53c51d2476a --- /dev/null +++ b/homeassistant/components/unifi/unifi_entity_base.py @@ -0,0 +1,80 @@ +"""Base class for UniFi entities.""" + +from homeassistant.core import callback +from homeassistant.helpers.dispatcher import async_dispatcher_connect +from homeassistant.helpers.entity import Entity +from homeassistant.helpers.entity_registry import async_entries_for_device + + +class UniFiBase(Entity): + """UniFi entity base class.""" + + TYPE = "" + + def __init__(self, controller) -> None: + """Set up UniFi entity base.""" + self.controller = controller + + @property + def mac(self): + """Return MAC of entity.""" + raise NotImplementedError + + async def async_added_to_hass(self) -> None: + """Entity created.""" + self.controller.entities[self.platform.domain][self.TYPE].add(self.mac) + for signal, method in ( + (self.controller.signal_reachable, self.async_update_callback), + (self.controller.signal_options_update, self.options_updated), + (self.controller.signal_remove, self.remove_item), + ): + self.async_on_remove(async_dispatcher_connect(self.hass, signal, method)) + + async def async_will_remove_from_hass(self) -> None: + """Disconnect object when removed.""" + self.controller.entities[self.platform.domain][self.TYPE].remove(self.mac) + + async def async_remove(self): + """Clean up when removing entity. + + Remove entity if no entry in entity registry exist. + Remove entity registry entry if no entry in device registry exist. + Remove device registry entry if there is only one linked entity (this entity). + Remove entity registry entry if there are more than one entity linked to the device registry entry. + """ + entity_registry = await self.hass.helpers.entity_registry.async_get_registry() + entity_entry = entity_registry.async_get(self.entity_id) + if not entity_entry: + await super().async_remove() + return + + device_registry = await self.hass.helpers.device_registry.async_get_registry() + device_entry = device_registry.async_get(entity_entry.device_id) + if not device_entry: + entity_registry.async_remove(self.entity_id) + return + + if len(async_entries_for_device(entity_registry, entity_entry.device_id)) == 1: + device_registry.async_remove_device(device_entry.id) + return + + entity_registry.async_remove(self.entity_id) + + @callback + def async_update_callback(self): + """Update the entity's state.""" + raise NotImplementedError + + async def options_updated(self) -> None: + """Config entry options are updated, remove entity if option is disabled.""" + raise NotImplementedError + + async def remove_item(self, mac_addresses: set) -> None: + """Remove entity if MAC is part of set.""" + if self.mac in mac_addresses: + await self.async_remove() + + @property + def should_poll(self) -> bool: + """No polling needed.""" + return True diff --git a/tests/components/unifi/test_controller.py b/tests/components/unifi/test_controller.py index d78947f3134..844eaa5d222 100644 --- a/tests/components/unifi/test_controller.py +++ b/tests/components/unifi/test_controller.py @@ -207,7 +207,7 @@ async def test_reset_after_successful_setup(hass): """Calling reset when the entry has been setup.""" controller = await setup_unifi_integration(hass) - assert len(controller.listeners) == 9 + assert len(controller.listeners) == 6 result = await controller.async_reset() await hass.async_block_till_done()