diff --git a/homeassistant/components/device_tracker/unifi.py b/homeassistant/components/device_tracker/unifi.py index 59b538cd824..54aa9a5972c 100644 --- a/homeassistant/components/device_tracker/unifi.py +++ b/homeassistant/components/device_tracker/unifi.py @@ -21,11 +21,13 @@ _LOGGER = logging.getLogger(__name__) CONF_PORT = 'port' CONF_SITE_ID = 'site_id' CONF_DETECTION_TIME = 'detection_time' +CONF_SSID_FILTER = 'ssid_filter' DEFAULT_HOST = 'localhost' DEFAULT_PORT = 8443 DEFAULT_VERIFY_SSL = True DEFAULT_DETECTION_TIME = timedelta(seconds=300) +DEFAULT_SSID_FILTER = None NOTIFICATION_ID = 'unifi_notification' NOTIFICATION_TITLE = 'Unifi Device Tracker Setup' @@ -39,7 +41,9 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): vol.Any( cv.boolean, cv.isfile), vol.Optional(CONF_DETECTION_TIME, default=DEFAULT_DETECTION_TIME): vol.All( - cv.time_period, cv.positive_timedelta) + cv.time_period, cv.positive_timedelta), + vol.Optional(CONF_SSID_FILTER, default=DEFAULT_SSID_FILTER): vol.All( + cv.ensure_list, [cv.string]) }) @@ -54,6 +58,7 @@ def get_scanner(hass, config): port = config[DOMAIN].get(CONF_PORT) verify_ssl = config[DOMAIN].get(CONF_VERIFY_SSL) detection_time = config[DOMAIN].get(CONF_DETECTION_TIME) + ssid_filter = config[DOMAIN].get(CONF_SSID_FILTER) try: ctrl = Controller(host, username, password, port, version='v4', @@ -69,16 +74,18 @@ def get_scanner(hass, config): notification_id=NOTIFICATION_ID) return False - return UnifiScanner(ctrl, detection_time) + return UnifiScanner(ctrl, detection_time, ssid_filter) class UnifiScanner(DeviceScanner): """Provide device_tracker support from Unifi WAP client data.""" - def __init__(self, controller, detection_time: timedelta) -> None: + def __init__(self, controller, detection_time: timedelta, + ssid_filter) -> None: """Initialize the scanner.""" self._detection_time = detection_time self._controller = controller + self._ssid_filter = ssid_filter self._update() def _update(self): @@ -90,6 +97,11 @@ class UnifiScanner(DeviceScanner): _LOGGER.error("Failed to scan clients: %s", ex) clients = [] + # Filter clients to provided SSID list + if self._ssid_filter: + clients = filter(lambda x: x['essid'] in self._ssid_filter, + clients) + self._clients = { client['mac']: client for client in clients diff --git a/tests/components/device_tracker/test_unifi.py b/tests/components/device_tracker/test_unifi.py index 083315b4c71..ccc58d728ed 100644 --- a/tests/components/device_tracker/test_unifi.py +++ b/tests/components/device_tracker/test_unifi.py @@ -53,7 +53,8 @@ def test_config_valid_verify_ssl(hass, mock_scanner, mock_ctrl): assert mock_scanner.call_count == 1 assert mock_scanner.call_args == mock.call(mock_ctrl.return_value, - DEFAULT_DETECTION_TIME) + DEFAULT_DETECTION_TIME, + None) def test_config_minimal(hass, mock_scanner, mock_ctrl): @@ -74,7 +75,8 @@ def test_config_minimal(hass, mock_scanner, mock_ctrl): assert mock_scanner.call_count == 1 assert mock_scanner.call_args == mock.call(mock_ctrl.return_value, - DEFAULT_DETECTION_TIME) + DEFAULT_DETECTION_TIME, + None) def test_config_full(hass, mock_scanner, mock_ctrl): @@ -100,7 +102,8 @@ def test_config_full(hass, mock_scanner, mock_ctrl): assert mock_scanner.call_count == 1 assert mock_scanner.call_args == mock.call(mock_ctrl.return_value, - DEFAULT_DETECTION_TIME) + DEFAULT_DETECTION_TIME, + None) def test_config_error(): @@ -148,11 +151,13 @@ def test_scanner_update(): """Test the scanner update.""" ctrl = mock.MagicMock() fake_clients = [ - {'mac': '123', 'last_seen': dt_util.as_timestamp(dt_util.utcnow())}, - {'mac': '234', 'last_seen': dt_util.as_timestamp(dt_util.utcnow())}, + {'mac': '123', 'essid': 'barnet', + 'last_seen': dt_util.as_timestamp(dt_util.utcnow())}, + {'mac': '234', 'essid': 'barnet', + 'last_seen': dt_util.as_timestamp(dt_util.utcnow())}, ] ctrl.get_clients.return_value = fake_clients - unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME) + unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None) assert ctrl.get_clients.call_count == 1 assert ctrl.get_clients.call_args == mock.call() @@ -162,36 +167,61 @@ def test_scanner_update_error(): ctrl = mock.MagicMock() ctrl.get_clients.side_effect = APIError( '/', 500, 'foo', {}, None) - unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME) + unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None) def test_scan_devices(): """Test the scanning for devices.""" ctrl = mock.MagicMock() fake_clients = [ - {'mac': '123', 'last_seen': dt_util.as_timestamp(dt_util.utcnow())}, - {'mac': '234', 'last_seen': dt_util.as_timestamp(dt_util.utcnow())}, + {'mac': '123', 'essid': 'barnet', + 'last_seen': dt_util.as_timestamp(dt_util.utcnow())}, + {'mac': '234', 'essid': 'barnet', + 'last_seen': dt_util.as_timestamp(dt_util.utcnow())}, ] ctrl.get_clients.return_value = fake_clients - scanner = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME) + scanner = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None) assert set(scanner.scan_devices()) == set(['123', '234']) +def test_scan_devices_filtered(): + """Test the scanning for devices based on SSID.""" + ctrl = mock.MagicMock() + fake_clients = [ + {'mac': '123', 'essid': 'foonet', + 'last_seen': dt_util.as_timestamp(dt_util.utcnow())}, + {'mac': '234', 'essid': 'foonet', + 'last_seen': dt_util.as_timestamp(dt_util.utcnow())}, + {'mac': '567', 'essid': 'notnet', + 'last_seen': dt_util.as_timestamp(dt_util.utcnow())}, + {'mac': '890', 'essid': 'barnet', + 'last_seen': dt_util.as_timestamp(dt_util.utcnow())}, + ] + + ssid_filter = ['foonet', 'barnet'] + ctrl.get_clients.return_value = fake_clients + scanner = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, ssid_filter) + assert set(scanner.scan_devices()) == set(['123', '234', '890']) + + def test_get_device_name(): """Test the getting of device names.""" ctrl = mock.MagicMock() fake_clients = [ {'mac': '123', 'hostname': 'foobar', + 'essid': 'barnet', 'last_seen': dt_util.as_timestamp(dt_util.utcnow())}, {'mac': '234', 'name': 'Nice Name', + 'essid': 'barnet', 'last_seen': dt_util.as_timestamp(dt_util.utcnow())}, {'mac': '456', + 'essid': 'barnet', 'last_seen': '1504786810'}, ] ctrl.get_clients.return_value = fake_clients - scanner = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME) + scanner = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None) assert scanner.get_device_name('123') == 'foobar' assert scanner.get_device_name('234') == 'Nice Name' assert scanner.get_device_name('456') is None