diff --git a/homeassistant/components/device_tracker/ubus.py b/homeassistant/components/device_tracker/ubus.py index 736c1ba3168..5eaa4bf2fca 100644 --- a/homeassistant/components/device_tracker/ubus.py +++ b/homeassistant/components/device_tracker/ubus.py @@ -11,10 +11,11 @@ import threading from datetime import timedelta import requests +import voluptuous as vol -from homeassistant.components.device_tracker import DOMAIN +import homeassistant.helpers.config_validation as cv +from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME -from homeassistant.helpers import validate_config from homeassistant.util import Throttle # Return cached results if last scan was less then this time ago. @@ -22,14 +23,15 @@ MIN_TIME_BETWEEN_SCANS = timedelta(seconds=5) _LOGGER = logging.getLogger(__name__) +PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ + vol.Required(CONF_HOST): cv.string, + vol.Required(CONF_PASSWORD): cv.string, + vol.Required(CONF_USERNAME): cv.string +}) + def get_scanner(hass, config): """Validate the configuration and return an ubus scanner.""" - if not validate_config(config, - {DOMAIN: [CONF_HOST, CONF_USERNAME, CONF_PASSWORD]}, - _LOGGER): - return None - scanner = UbusDeviceScanner(config[DOMAIN]) return scanner if scanner.success_init else None diff --git a/homeassistant/components/device_tracker/unifi.py b/homeassistant/components/device_tracker/unifi.py index 2ae3f76e5e6..d654c3e3eef 100644 --- a/homeassistant/components/device_tracker/unifi.py +++ b/homeassistant/components/device_tracker/unifi.py @@ -6,10 +6,11 @@ https://home-assistant.io/components/device_tracker.unifi/ """ import logging import urllib +import voluptuous as vol -from homeassistant.components.device_tracker import DOMAIN +import homeassistant.helpers.config_validation as cv +from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA from homeassistant.const import CONF_HOST, CONF_USERNAME, CONF_PASSWORD -from homeassistant.helpers import validate_config # Unifi package doesn't list urllib3 as a requirement REQUIREMENTS = ['urllib3', 'unifi==1.2.5'] @@ -18,28 +19,24 @@ _LOGGER = logging.getLogger(__name__) CONF_PORT = 'port' CONF_SITE_ID = 'site_id' +PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ + vol.Optional(CONF_HOST, default='localhost'): cv.string, + vol.Optional(CONF_SITE_ID, default='default'): cv.string, + vol.Required(CONF_PASSWORD): cv.string, + vol.Required(CONF_USERNAME): cv.string, + vol.Required(CONF_PORT, default=8443): cv.port +}) + def get_scanner(hass, config): """Setup Unifi device_tracker.""" from unifi.controller import Controller - if not validate_config(config, {DOMAIN: [CONF_USERNAME, - CONF_PASSWORD]}, - _LOGGER): - _LOGGER.error('Invalid configuration') - return False - - this_config = config[DOMAIN] - host = this_config.get(CONF_HOST, 'localhost') - username = this_config.get(CONF_USERNAME) - password = this_config.get(CONF_PASSWORD) - site_id = this_config.get(CONF_SITE_ID, 'default') - - try: - port = int(this_config.get(CONF_PORT, 8443)) - except ValueError: - _LOGGER.error('Invalid port (must be numeric like 8443)') - return False + host = config[DOMAIN].get(CONF_HOST) + username = config[DOMAIN].get(CONF_USERNAME) + password = config[DOMAIN].get(CONF_PASSWORD) + site_id = config[DOMAIN].get(CONF_SITE_ID) + port = config[DOMAIN].get(CONF_PORT) try: ctrl = Controller(host, username, password, port, 'v4', site_id) diff --git a/tests/components/device_tracker/test_unifi.py b/tests/components/device_tracker/test_unifi.py index e3f64cc84c3..8e43eb7485e 100644 --- a/tests/components/device_tracker/test_unifi.py +++ b/tests/components/device_tracker/test_unifi.py @@ -3,9 +3,12 @@ import unittest from unittest import mock import urllib -from homeassistant.components.device_tracker import unifi as unifi -from homeassistant.const import CONF_HOST, CONF_USERNAME, CONF_PASSWORD from unifi import controller +import voluptuous as vol + +from homeassistant.components.device_tracker import DOMAIN, unifi as unifi +from homeassistant.const import (CONF_HOST, CONF_USERNAME, CONF_PASSWORD, + CONF_PLATFORM) class TestUnifiScanner(unittest.TestCase): @@ -16,13 +19,14 @@ class TestUnifiScanner(unittest.TestCase): def test_config_minimal(self, mock_ctrl, mock_scanner): """Test the setup with minimal configuration.""" config = { - 'device_tracker': { + DOMAIN: unifi.PLATFORM_SCHEMA({ + CONF_PLATFORM: unifi.DOMAIN, CONF_USERNAME: 'foo', CONF_PASSWORD: 'password', - } + }) } result = unifi.get_scanner(None, config) - self.assertEqual(unifi.UnifiScanner.return_value, result) + self.assertEqual(mock_scanner.return_value, result) mock_ctrl.assert_called_once_with('localhost', 'foo', 'password', 8443, 'v4', 'default') mock_scanner.assert_called_once_with(mock_ctrl.return_value) @@ -32,49 +36,38 @@ class TestUnifiScanner(unittest.TestCase): def test_config_full(self, mock_ctrl, mock_scanner): """Test the setup with full configuration.""" config = { - 'device_tracker': { + DOMAIN: unifi.PLATFORM_SCHEMA({ + CONF_PLATFORM: unifi.DOMAIN, CONF_USERNAME: 'foo', CONF_PASSWORD: 'password', CONF_HOST: 'myhost', 'port': 123, 'site_id': 'abcdef01', - } + }) } result = unifi.get_scanner(None, config) - self.assertEqual(unifi.UnifiScanner.return_value, result) + self.assertEqual(mock_scanner.return_value, result) mock_ctrl.assert_called_once_with('myhost', 'foo', 'password', 123, 'v4', 'abcdef01') mock_scanner.assert_called_once_with(mock_ctrl.return_value) - @mock.patch('homeassistant.components.device_tracker.unifi.UnifiScanner') - @mock.patch.object(controller, 'Controller') - def test_config_error(self, mock_ctrl, mock_scanner): + def test_config_error(self): """Test for configuration errors.""" - config = { - 'device_tracker': { + with self.assertRaises(vol.Invalid): + unifi.PLATFORM_SCHEMA({ + # no username + CONF_PLATFORM: unifi.DOMAIN, CONF_HOST: 'myhost', 'port': 123, - } - } - result = unifi.get_scanner(None, config) - self.assertFalse(result) - self.assertFalse(mock_ctrl.called) - - @mock.patch('homeassistant.components.device_tracker.unifi.UnifiScanner') - @mock.patch.object(controller, 'Controller') - def test_config_badport(self, mock_ctrl, mock_scanner): - """Test the setup with a bad port.""" - config = { - 'device_tracker': { + }) + with self.assertRaises(vol.Invalid): + unifi.PLATFORM_SCHEMA({ + CONF_PLATFORM: unifi.DOMAIN, CONF_USERNAME: 'foo', CONF_PASSWORD: 'password', CONF_HOST: 'myhost', - 'port': 'foo', - } - } - result = unifi.get_scanner(None, config) - self.assertFalse(result) - self.assertFalse(mock_ctrl.called) + 'port': 'foo', # bad port! + }) @mock.patch('homeassistant.components.device_tracker.unifi.UnifiScanner') @mock.patch.object(controller, 'Controller') @@ -82,6 +75,7 @@ class TestUnifiScanner(unittest.TestCase): """Test for controller failure.""" config = { 'device_tracker': { + CONF_PLATFORM: unifi.DOMAIN, CONF_USERNAME: 'foo', CONF_PASSWORD: 'password', } @@ -91,7 +85,7 @@ class TestUnifiScanner(unittest.TestCase): result = unifi.get_scanner(None, config) self.assertFalse(result) - def test_scanner_update(self): + def test_scanner_update(self): # pylint: disable=no-self-use """Test the scanner update.""" ctrl = mock.MagicMock() fake_clients = [ @@ -102,7 +96,7 @@ class TestUnifiScanner(unittest.TestCase): unifi.UnifiScanner(ctrl) ctrl.get_clients.assert_called_once_with() - def test_scanner_update_error(self): + def test_scanner_update_error(self): # pylint: disable=no-self-use """Test the scanner update for error.""" ctrl = mock.MagicMock() ctrl.get_clients.side_effect = urllib.error.HTTPError(