diff --git a/homeassistant/components/unifi/device_tracker.py b/homeassistant/components/unifi/device_tracker.py index 5cfb8624116..8c4e54e0129 100644 --- a/homeassistant/components/unifi/device_tracker.py +++ b/homeassistant/components/unifi/device_tracker.py @@ -6,12 +6,13 @@ import voluptuous as vol from homeassistant import config_entries from homeassistant.components import unifi -from homeassistant.components.device_tracker import PLATFORM_SCHEMA +from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA from homeassistant.components.device_tracker.config_entry import ScannerEntity from homeassistant.components.device_tracker.const import SOURCE_TYPE_ROUTER from homeassistant.core import callback from homeassistant.const import ( CONF_HOST, CONF_USERNAME, CONF_PASSWORD, CONF_PORT, CONF_VERIFY_SSL) +from homeassistant.helpers import entity_registry from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC from homeassistant.helpers.dispatcher import async_dispatcher_connect @@ -80,6 +81,23 @@ async def async_setup_entry(hass, config_entry, async_add_entities): controller = hass.data[unifi.DOMAIN][controller_id] tracked = {} + registry = await entity_registry.async_get_registry(hass) + + # Restore clients that is not a part of active clients list. + for entity in registry.entities.values(): + + if entity.config_entry_id == config_entry.entry_id and \ + entity.domain == DOMAIN: + + mac, _ = entity.unique_id.split('-', 1) + + if mac in controller.api.clients or \ + mac not in controller.api.clients_all: + continue + + client = controller.api.clients_all[mac] + controller.api.clients.process_raw([client.raw]) + @callback def update_controller(): """Update the values of the controller.""" diff --git a/tests/components/unifi/test_device_tracker.py b/tests/components/unifi/test_device_tracker.py index 3a209e610ce..f8be08975fa 100644 --- a/tests/components/unifi/test_device_tracker.py +++ b/tests/components/unifi/test_device_tracker.py @@ -6,7 +6,7 @@ from datetime import timedelta import pytest -from aiounifi.clients import Clients +from aiounifi.clients import Clients, ClientsAll from aiounifi.devices import Devices from homeassistant import config_entries @@ -15,6 +15,7 @@ from homeassistant.components.unifi.const import ( CONF_CONTROLLER, CONF_SITE_ID, UNIFI_CONFIG) from homeassistant.const import ( CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME, CONF_VERIFY_SSL) +from homeassistant.helpers import entity_registry from homeassistant.setup import async_setup_component import homeassistant.components.device_tracker as device_tracker @@ -75,6 +76,7 @@ def mock_controller(hass): controller.mock_client_responses = deque() controller.mock_device_responses = deque() + controller.mock_client_all_responses = deque() async def mock_request(method, path, **kwargs): kwargs['method'] = method @@ -84,10 +86,13 @@ def mock_controller(hass): return controller.mock_client_responses.popleft() if path == 's/{site}/stat/device': return controller.mock_device_responses.popleft() + if path == 's/{site}/rest/user': + return controller.mock_client_all_responses.popleft() return None controller.api.clients = Clients({}, mock_request) controller.api.devices = Devices({}, mock_request) + controller.api.clients_all = ClientsAll({}, mock_request) return controller @@ -98,7 +103,7 @@ async def setup_controller(hass, mock_controller): hass.data[unifi.DOMAIN] = {CONTROLLER_ID: mock_controller} config_entry = config_entries.ConfigEntry( 1, unifi.DOMAIN, 'Mock Title', ENTRY_CONFIG, 'test', - config_entries.CONN_CLASS_LOCAL_POLL) + config_entries.CONN_CLASS_LOCAL_POLL, entry_id=1) mock_controller.config_entry = config_entry await mock_controller.async_update() @@ -159,3 +164,30 @@ async def test_tracked_devices(hass, mock_controller): device_1 = hass.states.get('device_tracker.client_1') assert device_1.state == 'home' + + +async def test_restoring_client(hass, mock_controller): + """Test the update_items function with some clients.""" + mock_controller.mock_client_responses.append([CLIENT_2]) + mock_controller.mock_device_responses.append({}) + mock_controller.mock_client_all_responses.append([CLIENT_1]) + mock_controller.unifi_config = { + unifi.CONF_BLOCK_CLIENT: True + } + + registry = await entity_registry.async_get_registry(hass) + registry.async_get_or_create( + device_tracker.DOMAIN, unifi_dt.UNIFI_DOMAIN, + '{}-mock-site'.format(CLIENT_1['mac']), + suggested_object_id=CLIENT_1['hostname'], config_entry_id=1) + registry.async_get_or_create( + device_tracker.DOMAIN, unifi_dt.UNIFI_DOMAIN, + '{}-mock-site'.format(CLIENT_2['mac']), + suggested_object_id=CLIENT_2['hostname'], config_entry_id=1) + + await setup_controller(hass, mock_controller) + assert len(mock_controller.mock_requests) == 3 + assert len(hass.states.async_all()) == 4 + + device_1 = hass.states.get('device_tracker.client_1') + assert device_1 is not None