diff --git a/homeassistant/components/smartthings/__init__.py b/homeassistant/components/smartthings/__init__.py index 3cf38c358bc..53ff6169c0a 100644 --- a/homeassistant/components/smartthings/__init__.py +++ b/homeassistant/components/smartthings/__init__.py @@ -14,16 +14,20 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.dispatcher import ( async_dispatcher_connect, async_dispatcher_send) from homeassistant.helpers.entity import Entity +from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.typing import ConfigType, HomeAssistantType from .config_flow import SmartThingsFlowHandler # noqa from .const import ( - CONF_APP_ID, CONF_INSTALLED_APP_ID, DATA_BROKERS, DATA_MANAGER, DOMAIN, - EVENT_BUTTON, SIGNAL_SMARTTHINGS_UPDATE, SUPPORTED_PLATFORMS) + CONF_APP_ID, CONF_INSTALLED_APP_ID, CONF_OAUTH_CLIENT_ID, + CONF_OAUTH_CLIENT_SECRET, CONF_REFRESH_TOKEN, DATA_BROKERS, DATA_MANAGER, + DOMAIN, EVENT_BUTTON, SIGNAL_SMARTTHINGS_UPDATE, SUPPORTED_PLATFORMS, + TOKEN_REFRESH_INTERVAL) from .smartapp import ( - setup_smartapp, setup_smartapp_endpoint, validate_installed_app) + setup_smartapp, setup_smartapp_endpoint, smartapp_sync_subscriptions, + validate_installed_app) -REQUIREMENTS = ['pysmartapp==0.3.0', 'pysmartthings==0.6.2'] +REQUIREMENTS = ['pysmartapp==0.3.0', 'pysmartthings==0.6.3'] DEPENDENCIES = ['webhook'] _LOGGER = logging.getLogger(__name__) @@ -35,6 +39,33 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType): return True +async def async_migrate_entry(hass: HomeAssistantType, entry: ConfigEntry): + """Handle migration of a previous version config entry. + + A config entry created under a previous version must go through the + integration setup again so we can properly retrieve the needed data + elements. Force this by removing the entry and triggering a new flow. + """ + from pysmartthings import SmartThings + + # Delete the installed app + api = SmartThings(async_get_clientsession(hass), + entry.data[CONF_ACCESS_TOKEN]) + await api.delete_installed_app(entry.data[CONF_INSTALLED_APP_ID]) + # Delete the entry + hass.async_create_task( + hass.config_entries.async_remove(entry.entry_id)) + # only create new flow if there isn't a pending one for SmartThings. + flows = hass.config_entries.flow.async_progress() + if not [flow for flow in flows if flow['handler'] == DOMAIN]: + hass.async_create_task( + hass.config_entries.flow.async_init( + DOMAIN, context={'source': 'import'})) + + # Return False because it could not be migrated. + return False + + async def async_setup_entry(hass: HomeAssistantType, entry: ConfigEntry): """Initialize config entry which represents an installed SmartApp.""" from pysmartthings import SmartThings @@ -62,6 +93,14 @@ async def async_setup_entry(hass: HomeAssistantType, entry: ConfigEntry): installed_app = await validate_installed_app( api, entry.data[CONF_INSTALLED_APP_ID]) + # Get SmartApp token to sync subscriptions + token = await api.generate_tokens( + entry.data[CONF_OAUTH_CLIENT_ID], + entry.data[CONF_OAUTH_CLIENT_SECRET], + entry.data[CONF_REFRESH_TOKEN]) + entry.data[CONF_REFRESH_TOKEN] = token.refresh_token + hass.config_entries.async_update_entry(entry) + # Get devices and their current status devices = await api.devices( location_ids=[installed_app.location_id]) @@ -71,18 +110,21 @@ async def async_setup_entry(hass: HomeAssistantType, entry: ConfigEntry): await device.status.refresh() except ClientResponseError: _LOGGER.debug("Unable to update status for device: %s (%s), " - "the device will be ignored", + "the device will be excluded", device.label, device.device_id, exc_info=True) devices.remove(device) await asyncio.gather(*[retrieve_device_status(d) for d in devices.copy()]) + # Sync device subscriptions + await smartapp_sync_subscriptions( + hass, token.access_token, installed_app.location_id, + installed_app.installed_app_id, devices) + # Setup device broker - broker = DeviceBroker(hass, devices, - installed_app.installed_app_id) - broker.event_handler_disconnect = \ - smart_app.connect_event(broker.event_handler) + broker = DeviceBroker(hass, entry, token, smart_app, devices) + broker.connect() hass.data[DOMAIN][DATA_BROKERS][entry.entry_id] = broker except ClientResponseError as ex: @@ -117,8 +159,8 @@ async def async_setup_entry(hass: HomeAssistantType, entry: ConfigEntry): async def async_unload_entry(hass: HomeAssistantType, entry: ConfigEntry): """Unload a config entry.""" broker = hass.data[DOMAIN][DATA_BROKERS].pop(entry.entry_id, None) - if broker and broker.event_handler_disconnect: - broker.event_handler_disconnect() + if broker: + broker.disconnect() tasks = [hass.config_entries.async_forward_entry_unload(entry, component) for component in SUPPORTED_PLATFORMS] @@ -128,14 +170,18 @@ async def async_unload_entry(hass: HomeAssistantType, entry: ConfigEntry): class DeviceBroker: """Manages an individual SmartThings config entry.""" - def __init__(self, hass: HomeAssistantType, devices: Iterable, - installed_app_id: str): + def __init__(self, hass: HomeAssistantType, entry: ConfigEntry, + token, smart_app, devices: Iterable): """Create a new instance of the DeviceBroker.""" self._hass = hass - self._installed_app_id = installed_app_id - self.assignments = self._assign_capabilities(devices) + self._entry = entry + self._installed_app_id = entry.data[CONF_INSTALLED_APP_ID] + self._smart_app = smart_app + self._token = token + self._event_disconnect = None + self._regenerate_token_remove = None + self._assignments = self._assign_capabilities(devices) self.devices = {device.device_id: device for device in devices} - self.event_handler_disconnect = None def _assign_capabilities(self, devices: Iterable): """Assign platforms to capabilities.""" @@ -158,17 +204,45 @@ class DeviceBroker: assignments[device.device_id] = slots return assignments + def connect(self): + """Connect handlers/listeners for device/lifecycle events.""" + # Setup interval to regenerate the refresh token on a periodic basis. + # Tokens expire in 30 days and once expired, cannot be recovered. + async def regenerate_refresh_token(now): + """Generate a new refresh token and update the config entry.""" + await self._token.refresh( + self._entry.data[CONF_OAUTH_CLIENT_ID], + self._entry.data[CONF_OAUTH_CLIENT_SECRET]) + self._entry.data[CONF_REFRESH_TOKEN] = self._token.refresh_token + self._hass.config_entries.async_update_entry(self._entry) + _LOGGER.debug('Regenerated refresh token for installed app: %s', + self._installed_app_id) + + self._regenerate_token_remove = async_track_time_interval( + self._hass, regenerate_refresh_token, TOKEN_REFRESH_INTERVAL) + + # Connect handler to incoming device events + self._event_disconnect = \ + self._smart_app.connect_event(self._event_handler) + + def disconnect(self): + """Disconnects handlers/listeners for device/lifecycle events.""" + if self._regenerate_token_remove: + self._regenerate_token_remove() + if self._event_disconnect: + self._event_disconnect() + def get_assigned(self, device_id: str, platform: str): """Get the capabilities assigned to the platform.""" - slots = self.assignments.get(device_id, {}) + slots = self._assignments.get(device_id, {}) return [key for key, value in slots.items() if value == platform] def any_assigned(self, device_id: str, platform: str): """Return True if the platform has any assigned capabilities.""" - slots = self.assignments.get(device_id, {}) + slots = self._assignments.get(device_id, {}) return any(value for value in slots.values() if value == platform) - async def event_handler(self, req, resp, app): + async def _event_handler(self, req, resp, app): """Broker for incoming events.""" from pysmartapp.event import EVENT_TYPE_DEVICE from pysmartthings import Capability, Attribute diff --git a/homeassistant/components/smartthings/config_flow.py b/homeassistant/components/smartthings/config_flow.py index 4663222c3b4..c290f0f8e55 100644 --- a/homeassistant/components/smartthings/config_flow.py +++ b/homeassistant/components/smartthings/config_flow.py @@ -9,7 +9,8 @@ from homeassistant.const import CONF_ACCESS_TOKEN from homeassistant.helpers.aiohttp_client import async_get_clientsession from .const import ( - CONF_APP_ID, CONF_INSTALLED_APP_ID, CONF_LOCATION_ID, DOMAIN, + APP_OAUTH_CLIENT_NAME, APP_OAUTH_SCOPES, CONF_APP_ID, CONF_INSTALLED_APPS, + CONF_LOCATION_ID, CONF_OAUTH_CLIENT_ID, CONF_OAUTH_CLIENT_SECRET, DOMAIN, VAL_UID_MATCHER) from .smartapp import ( create_app, find_app, setup_smartapp, setup_smartapp_endpoint, update_app) @@ -35,7 +36,7 @@ class SmartThingsFlowHandler(config_entries.ConfigFlow): b) Config entries setup for all installations """ - VERSION = 1 + VERSION = 2 CONNECTION_CLASS = config_entries.CONN_CLASS_CLOUD_PUSH def __init__(self): @@ -43,6 +44,8 @@ class SmartThingsFlowHandler(config_entries.ConfigFlow): self.access_token = None self.app_id = None self.api = None + self.oauth_client_secret = None + self.oauth_client_id = None async def async_step_import(self, user_input=None): """Occurs when a previously entry setup fails and is re-initiated.""" @@ -50,7 +53,7 @@ class SmartThingsFlowHandler(config_entries.ConfigFlow): async def async_step_user(self, user_input=None): """Get access token and validate it.""" - from pysmartthings import APIResponseError, SmartThings + from pysmartthings import APIResponseError, AppOAuth, SmartThings errors = {} if not self.hass.config.api.base_url.lower().startswith('https://'): @@ -83,10 +86,18 @@ class SmartThingsFlowHandler(config_entries.ConfigFlow): if app: await app.refresh() # load all attributes await update_app(self.hass, app) + # Get oauth client id/secret by regenerating it + app_oauth = AppOAuth(app.app_id) + app_oauth.client_name = APP_OAUTH_CLIENT_NAME + app_oauth.scope.extend(APP_OAUTH_SCOPES) + client = await self.api.generate_app_oauth(app_oauth) else: - app = await create_app(self.hass, self.api) + app, client = await create_app(self.hass, self.api) setup_smartapp(self.hass, app) self.app_id = app.app_id + self.oauth_client_secret = client.client_secret + self.oauth_client_id = client.client_id + except APIResponseError as ex: if ex.is_target_error(): errors['base'] = 'webhook_error' @@ -113,19 +124,23 @@ class SmartThingsFlowHandler(config_entries.ConfigFlow): async def async_step_wait_install(self, user_input=None): """Wait for SmartApp installation.""" - from pysmartthings import InstalledAppStatus - errors = {} if user_input is None: return self._show_step_wait_install(errors) # Find installed apps that were authorized - installed_apps = [app for app in await self.api.installed_apps( - installed_app_status=InstalledAppStatus.AUTHORIZED) - if app.app_id == self.app_id] + installed_apps = self.hass.data[DOMAIN][CONF_INSTALLED_APPS].copy() if not installed_apps: errors['base'] = 'app_not_installed' return self._show_step_wait_install(errors) + self.hass.data[DOMAIN][CONF_INSTALLED_APPS].clear() + + # Enrich the data + for installed_app in installed_apps: + installed_app[CONF_APP_ID] = self.app_id + installed_app[CONF_ACCESS_TOKEN] = self.access_token + installed_app[CONF_OAUTH_CLIENT_ID] = self.oauth_client_id + installed_app[CONF_OAUTH_CLIENT_SECRET] = self.oauth_client_secret # User may have installed the SmartApp in more than one SmartThings # location. Config flows are created for the additional installations @@ -133,21 +148,10 @@ class SmartThingsFlowHandler(config_entries.ConfigFlow): self.hass.async_create_task( self.hass.config_entries.flow.async_init( DOMAIN, context={'source': 'install'}, - data={ - CONF_APP_ID: installed_app.app_id, - CONF_INSTALLED_APP_ID: installed_app.installed_app_id, - CONF_LOCATION_ID: installed_app.location_id, - CONF_ACCESS_TOKEN: self.access_token - })) + data=installed_app)) - # return entity for the first one. - installed_app = installed_apps[0] - return await self.async_step_install({ - CONF_APP_ID: installed_app.app_id, - CONF_INSTALLED_APP_ID: installed_app.installed_app_id, - CONF_LOCATION_ID: installed_app.location_id, - CONF_ACCESS_TOKEN: self.access_token - }) + # Create config entity for the first one. + return await self.async_step_install(installed_apps[0]) def _show_step_user(self, errors): return self.async_show_form( diff --git a/homeassistant/components/smartthings/const.py b/homeassistant/components/smartthings/const.py index 27260b155d1..d423bcde44f 100644 --- a/homeassistant/components/smartthings/const.py +++ b/homeassistant/components/smartthings/const.py @@ -1,14 +1,20 @@ """Constants used by the SmartThings component and platforms.""" +from datetime import timedelta import re +APP_OAUTH_CLIENT_NAME = "Home Assistant" APP_OAUTH_SCOPES = [ 'r:devices:*' ] APP_NAME_PREFIX = 'homeassistant.' CONF_APP_ID = 'app_id' CONF_INSTALLED_APP_ID = 'installed_app_id' +CONF_INSTALLED_APPS = 'installed_apps' CONF_INSTANCE_ID = 'instance_id' CONF_LOCATION_ID = 'location_id' +CONF_OAUTH_CLIENT_ID = 'client_id' +CONF_OAUTH_CLIENT_SECRET = 'client_secret' +CONF_REFRESH_TOKEN = 'refresh_token' DATA_MANAGER = 'manager' DATA_BROKERS = 'brokers' DOMAIN = 'smartthings' @@ -29,6 +35,7 @@ SUPPORTED_PLATFORMS = [ 'binary_sensor', 'sensor' ] +TOKEN_REFRESH_INTERVAL = timedelta(days=14) VAL_UID = "^(?:([0-9a-fA-F]{32})|([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]" \ "{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}))$" VAL_UID_MATCHER = re.compile(VAL_UID) diff --git a/homeassistant/components/smartthings/smartapp.py b/homeassistant/components/smartthings/smartapp.py index 89043d4f76c..5527fda54f4 100644 --- a/homeassistant/components/smartthings/smartapp.py +++ b/homeassistant/components/smartthings/smartapp.py @@ -13,15 +13,16 @@ from uuid import uuid4 from aiohttp import web from homeassistant.components import webhook -from homeassistant.const import CONF_ACCESS_TOKEN, CONF_WEBHOOK_ID +from homeassistant.const import CONF_WEBHOOK_ID from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.dispatcher import ( async_dispatcher_connect, async_dispatcher_send) from homeassistant.helpers.typing import HomeAssistantType from .const import ( - APP_NAME_PREFIX, APP_OAUTH_SCOPES, CONF_APP_ID, CONF_INSTALLED_APP_ID, - CONF_INSTANCE_ID, CONF_LOCATION_ID, DATA_BROKERS, DATA_MANAGER, DOMAIN, + APP_NAME_PREFIX, APP_OAUTH_CLIENT_NAME, APP_OAUTH_SCOPES, CONF_APP_ID, + CONF_INSTALLED_APP_ID, CONF_INSTALLED_APPS, CONF_INSTANCE_ID, + CONF_LOCATION_ID, CONF_REFRESH_TOKEN, DATA_BROKERS, DATA_MANAGER, DOMAIN, SETTINGS_INSTANCE_ID, SIGNAL_SMARTAPP_PREFIX, STORAGE_KEY, STORAGE_VERSION) _LOGGER = logging.getLogger(__name__) @@ -83,7 +84,7 @@ async def create_app(hass: HomeAssistantType, api): app = App() for key, value in template.items(): setattr(app, key, value) - app = (await api.create_app(app))[0] + app, client = await api.create_app(app) _LOGGER.debug("Created SmartApp '%s' (%s)", app.app_name, app.app_id) # Set unique hass id in settings @@ -97,12 +98,12 @@ async def create_app(hass: HomeAssistantType, api): # Set oauth scopes oauth = AppOAuth(app.app_id) - oauth.client_name = 'Home Assistant' + oauth.client_name = APP_OAUTH_CLIENT_NAME oauth.scope.extend(APP_OAUTH_SCOPES) await api.update_app_oauth(oauth) _LOGGER.debug("Updated App OAuth for SmartApp '%s' (%s)", app.app_name, app.app_id) - return app + return app, client async def update_app(hass: HomeAssistantType, app): @@ -185,32 +186,24 @@ async def setup_smartapp_endpoint(hass: HomeAssistantType): DATA_MANAGER: manager, CONF_INSTANCE_ID: config[CONF_INSTANCE_ID], DATA_BROKERS: {}, - CONF_WEBHOOK_ID: config[CONF_WEBHOOK_ID] + CONF_WEBHOOK_ID: config[CONF_WEBHOOK_ID], + CONF_INSTALLED_APPS: [] } async def smartapp_sync_subscriptions( hass: HomeAssistantType, auth_token: str, location_id: str, - installed_app_id: str, *, skip_delete=False): + installed_app_id: str, devices): """Synchronize subscriptions of an installed up.""" from pysmartthings import ( - CAPABILITIES, SmartThings, SourceType, Subscription) + CAPABILITIES, SmartThings, SourceType, Subscription, + SubscriptionEntity + ) api = SmartThings(async_get_clientsession(hass), auth_token) - devices = await api.devices(location_ids=[location_id]) + tasks = [] - # Build set of capabilities and prune unsupported ones - capabilities = set() - for device in devices: - capabilities.update(device.capabilities) - capabilities.intersection_update(CAPABILITIES) - - # Remove all (except for installs) - if not skip_delete: - await api.delete_subscriptions(installed_app_id) - - # Create for each capability - async def create_subscription(target): + async def create_subscription(target: str): sub = Subscription() sub.installed_app_id = installed_app_id sub.location_id = location_id @@ -224,52 +217,89 @@ async def smartapp_sync_subscriptions( _LOGGER.exception("Failed to create subscription for '%s' under " "app '%s'", target, installed_app_id) - tasks = [create_subscription(c) for c in capabilities] - await asyncio.gather(*tasks) + async def delete_subscription(sub: SubscriptionEntity): + try: + await api.delete_subscription( + installed_app_id, sub.subscription_id) + _LOGGER.debug("Removed subscription for '%s' under app '%s' " + "because it was no longer needed", + sub.capability, installed_app_id) + except Exception: # pylint:disable=broad-except + _LOGGER.exception("Failed to remove subscription for '%s' under " + "app '%s'", sub.capability, installed_app_id) + + # Build set of capabilities and prune unsupported ones + capabilities = set() + for device in devices: + capabilities.update(device.capabilities) + capabilities.intersection_update(CAPABILITIES) + + # Get current subscriptions and find differences + subscriptions = await api.subscriptions(installed_app_id) + for subscription in subscriptions: + if subscription.capability in capabilities: + capabilities.remove(subscription.capability) + else: + # Delete the subscription + tasks.append(delete_subscription(subscription)) + + # Remaining capabilities need subscriptions created + tasks.extend([create_subscription(c) for c in capabilities]) + + if tasks: + await asyncio.gather(*tasks) + else: + _LOGGER.debug("Subscriptions for app '%s' are up-to-date", + installed_app_id) async def smartapp_install(hass: HomeAssistantType, req, resp, app): """ Handle when a SmartApp is installed by the user into a location. - Setup subscriptions using the access token SmartThings provided in the - event. An explicit subscription is required for each 'capability' in order - to receive the related attribute updates. Finally, create a config entry - representing the installation if this is not the first installation under - the account. + Create a config entry representing the installation if this is not + the first installation under the account, otherwise store the data + for the config flow. """ - await smartapp_sync_subscriptions( - hass, req.auth_token, req.location_id, req.installed_app_id, - skip_delete=True) - - # The permanent access token is copied from another config flow with the - # same parent app_id. If one is not found, that means the user is within - # the initial config flow and the entry at the conclusion. - access_token = next(( - entry.data.get(CONF_ACCESS_TOKEN) for entry + install_data = { + CONF_INSTALLED_APP_ID: req.installed_app_id, + CONF_LOCATION_ID: req.location_id, + CONF_REFRESH_TOKEN: req.refresh_token + } + # App attributes (client id/secret, etc...) are copied from another entry + # with the same parent app_id. If one is not found, the install data is + # stored for the config flow to retrieve during the wait step. + entry = next(( + entry for entry in hass.config_entries.async_entries(DOMAIN) if entry.data[CONF_APP_ID] == app.app_id), None) - if access_token: + if entry: + data = entry.data.copy() + data.update(install_data) # Add as job not needed because the current coroutine was invoked # from the dispatcher and is not being awaited. await hass.config_entries.flow.async_init( DOMAIN, context={'source': 'install'}, - data={ - CONF_APP_ID: app.app_id, - CONF_INSTALLED_APP_ID: req.installed_app_id, - CONF_LOCATION_ID: req.location_id, - CONF_ACCESS_TOKEN: access_token - }) + data=data) + else: + # Store the data where the flow can find it + hass.data[DOMAIN][CONF_INSTALLED_APPS].append(install_data) async def smartapp_update(hass: HomeAssistantType, req, resp, app): """ Handle when a SmartApp is updated (reconfigured) by the user. - Synchronize subscriptions to ensure we're up-to-date. + Store the refresh token in the config entry. """ - await smartapp_sync_subscriptions( - hass, req.auth_token, req.location_id, req.installed_app_id) + # Update refresh token in config entry + entry = next((entry for entry in hass.config_entries.async_entries(DOMAIN) + if entry.data.get(CONF_INSTALLED_APP_ID) == + req.installed_app_id), + None) + if entry: + entry.data[CONF_REFRESH_TOKEN] = req.refresh_token + hass.config_entries.async_update_entry(entry) _LOGGER.debug("SmartApp '%s' under parent app '%s' was updated", req.installed_app_id, app.app_id) diff --git a/requirements_all.txt b/requirements_all.txt index d2bbaefce38..8f850c49af3 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -1252,7 +1252,7 @@ pysma==0.3.1 pysmartapp==0.3.0 # homeassistant.components.smartthings -pysmartthings==0.6.2 +pysmartthings==0.6.3 # homeassistant.components.device_tracker.snmp # homeassistant.components.sensor.snmp diff --git a/requirements_test_all.txt b/requirements_test_all.txt index f35d582bcab..64c28534046 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -223,7 +223,7 @@ pyqwikswitch==0.8 pysmartapp==0.3.0 # homeassistant.components.smartthings -pysmartthings==0.6.2 +pysmartthings==0.6.3 # homeassistant.components.sonos pysonos==0.0.6 diff --git a/tests/components/smartthings/conftest.py b/tests/components/smartthings/conftest.py index ee892fb03b9..4622e49b0c6 100644 --- a/tests/components/smartthings/conftest.py +++ b/tests/components/smartthings/conftest.py @@ -4,8 +4,8 @@ from unittest.mock import Mock, patch from uuid import uuid4 from pysmartthings import ( - CLASSIFICATION_AUTOMATION, AppEntity, AppSettings, DeviceEntity, - InstalledApp, Location) + CLASSIFICATION_AUTOMATION, AppEntity, AppOAuthClient, AppSettings, + DeviceEntity, InstalledApp, Location, Subscription) from pysmartthings.api import Api import pytest @@ -13,8 +13,9 @@ from homeassistant.components import webhook from homeassistant.components.smartthings import DeviceBroker from homeassistant.components.smartthings.const import ( APP_NAME_PREFIX, CONF_APP_ID, CONF_INSTALLED_APP_ID, CONF_INSTANCE_ID, - CONF_LOCATION_ID, DATA_BROKERS, DOMAIN, SETTINGS_INSTANCE_ID, STORAGE_KEY, - STORAGE_VERSION) + CONF_LOCATION_ID, CONF_OAUTH_CLIENT_ID, CONF_OAUTH_CLIENT_SECRET, + CONF_REFRESH_TOKEN, DATA_BROKERS, DOMAIN, SETTINGS_INSTANCE_ID, + STORAGE_KEY, STORAGE_VERSION) from homeassistant.config_entries import ( CONN_CLASS_CLOUD_PUSH, SOURCE_USER, ConfigEntry) from homeassistant.const import CONF_ACCESS_TOKEN, CONF_WEBHOOK_ID @@ -26,9 +27,11 @@ from tests.common import mock_coro async def setup_platform(hass, platform: str, *devices): """Set up the SmartThings platform and prerequisites.""" hass.config.components.add(DOMAIN) - broker = DeviceBroker(hass, devices, '') - config_entry = ConfigEntry("1", DOMAIN, "Test", {}, + config_entry = ConfigEntry(2, DOMAIN, "Test", + {CONF_INSTALLED_APP_ID: str(uuid4())}, SOURCE_USER, CONN_CLASS_CLOUD_PUSH) + broker = DeviceBroker(hass, config_entry, Mock(), Mock(), devices) + hass.data[DOMAIN] = { DATA_BROKERS: { config_entry.entry_id: broker @@ -98,6 +101,15 @@ def app_fixture(hass, config_file): return app +@pytest.fixture(name="app_oauth_client") +def app_oauth_client_fixture(): + """Fixture for a single app's oauth.""" + return AppOAuthClient({ + 'oauthClientId': str(uuid4()), + 'oauthClientSecret': str(uuid4()) + }) + + @pytest.fixture(name='app_settings') def app_settings_fixture(app, config_file): """Fixture for an app settings.""" @@ -225,12 +237,25 @@ def config_entry_fixture(hass, installed_app, location): CONF_ACCESS_TOKEN: str(uuid4()), CONF_INSTALLED_APP_ID: installed_app.installed_app_id, CONF_APP_ID: installed_app.app_id, - CONF_LOCATION_ID: location.location_id + CONF_LOCATION_ID: location.location_id, + CONF_REFRESH_TOKEN: str(uuid4()), + CONF_OAUTH_CLIENT_ID: str(uuid4()), + CONF_OAUTH_CLIENT_SECRET: str(uuid4()) } - return ConfigEntry("1", DOMAIN, location.name, data, SOURCE_USER, + return ConfigEntry(2, DOMAIN, location.name, data, SOURCE_USER, CONN_CLASS_CLOUD_PUSH) +@pytest.fixture(name="subscription_factory") +def subscription_factory_fixture(): + """Fixture for creating mock subscriptions.""" + def _factory(capability): + sub = Subscription() + sub.capability = capability + return sub + return _factory + + @pytest.fixture(name="device_factory") def device_factory_fixture(): """Fixture for creating mock devices.""" diff --git a/tests/components/smartthings/test_binary_sensor.py b/tests/components/smartthings/test_binary_sensor.py index 4b47537fa19..6e60ee49ca6 100644 --- a/tests/components/smartthings/test_binary_sensor.py +++ b/tests/components/smartthings/test_binary_sensor.py @@ -6,31 +6,15 @@ real HTTP calls are not initiated during testing. """ from pysmartthings import ATTRIBUTES, CAPABILITIES, Attribute, Capability -from homeassistant.components.binary_sensor import DEVICE_CLASSES -from homeassistant.components.smartthings import DeviceBroker, binary_sensor +from homeassistant.components.binary_sensor import ( + DEVICE_CLASSES, DOMAIN as BINARY_SENSOR_DOMAIN) +from homeassistant.components.smartthings import binary_sensor from homeassistant.components.smartthings.const import ( - DATA_BROKERS, DOMAIN, SIGNAL_SMARTTHINGS_UPDATE) -from homeassistant.config_entries import ( - CONN_CLASS_CLOUD_PUSH, SOURCE_USER, ConfigEntry) + DOMAIN, SIGNAL_SMARTTHINGS_UPDATE) from homeassistant.const import ATTR_FRIENDLY_NAME from homeassistant.helpers.dispatcher import async_dispatcher_send - -async def _setup_platform(hass, *devices): - """Set up the SmartThings binary_sensor platform and prerequisites.""" - hass.config.components.add(DOMAIN) - broker = DeviceBroker(hass, devices, '') - config_entry = ConfigEntry("1", DOMAIN, "Test", {}, - SOURCE_USER, CONN_CLASS_CLOUD_PUSH) - hass.data[DOMAIN] = { - DATA_BROKERS: { - config_entry.entry_id: broker - } - } - await hass.config_entries.async_forward_entry_setup( - config_entry, 'binary_sensor') - await hass.async_block_till_done() - return config_entry +from .conftest import setup_platform async def test_mapping_integrity(): @@ -56,7 +40,7 @@ async def test_entity_state(hass, device_factory): """Tests the state attributes properly match the light types.""" device = device_factory('Motion Sensor 1', [Capability.motion_sensor], {Attribute.motion: 'inactive'}) - await _setup_platform(hass, device) + await setup_platform(hass, BINARY_SENSOR_DOMAIN, device) state = hass.states.get('binary_sensor.motion_sensor_1_motion') assert state.state == 'off' assert state.attributes[ATTR_FRIENDLY_NAME] ==\ @@ -71,7 +55,7 @@ async def test_entity_and_device_attributes(hass, device_factory): entity_registry = await hass.helpers.entity_registry.async_get_registry() device_registry = await hass.helpers.device_registry.async_get_registry() # Act - await _setup_platform(hass, device) + await setup_platform(hass, BINARY_SENSOR_DOMAIN, device) # Assert entry = entity_registry.async_get('binary_sensor.motion_sensor_1_motion') assert entry @@ -89,7 +73,7 @@ async def test_update_from_signal(hass, device_factory): # Arrange device = device_factory('Motion Sensor 1', [Capability.motion_sensor], {Attribute.motion: 'inactive'}) - await _setup_platform(hass, device) + await setup_platform(hass, BINARY_SENSOR_DOMAIN, device) device.status.apply_attribute_update( 'main', Capability.motion_sensor, Attribute.motion, 'active') # Act @@ -107,7 +91,7 @@ async def test_unload_config_entry(hass, device_factory): # Arrange device = device_factory('Motion Sensor 1', [Capability.motion_sensor], {Attribute.motion: 'inactive'}) - config_entry = await _setup_platform(hass, device) + config_entry = await setup_platform(hass, BINARY_SENSOR_DOMAIN, device) # Act await hass.config_entries.async_forward_entry_unload( config_entry, 'binary_sensor') diff --git a/tests/components/smartthings/test_config_flow.py b/tests/components/smartthings/test_config_flow.py index 7d335703131..28aa759a359 100644 --- a/tests/components/smartthings/test_config_flow.py +++ b/tests/components/smartthings/test_config_flow.py @@ -8,6 +8,9 @@ from pysmartthings import APIResponseError from homeassistant import data_entry_flow from homeassistant.components.smartthings.config_flow import ( SmartThingsFlowHandler) +from homeassistant.components.smartthings.const import ( + CONF_INSTALLED_APP_ID, CONF_INSTALLED_APPS, CONF_LOCATION_ID, + CONF_REFRESH_TOKEN, DOMAIN) from homeassistant.config_entries import ConfigEntry from tests.common import mock_coro @@ -171,14 +174,16 @@ async def test_unknown_error(hass, smartthings_mock): assert result['errors'] == {'base': 'app_setup_error'} -async def test_app_created_then_show_wait_form(hass, app, smartthings_mock): +async def test_app_created_then_show_wait_form( + hass, app, app_oauth_client, smartthings_mock): """Test SmartApp is created when one does not exist and shows wait form.""" flow = SmartThingsFlowHandler() flow.hass = hass smartthings = smartthings_mock.return_value smartthings.apps.return_value = mock_coro(return_value=[]) - smartthings.create_app.return_value = mock_coro(return_value=(app, None)) + smartthings.create_app.return_value = \ + mock_coro(return_value=(app, app_oauth_client)) smartthings.update_app_settings.return_value = mock_coro() smartthings.update_app_oauth.return_value = mock_coro() @@ -189,13 +194,15 @@ async def test_app_created_then_show_wait_form(hass, app, smartthings_mock): async def test_app_updated_then_show_wait_form( - hass, app, smartthings_mock): + hass, app, app_oauth_client, smartthings_mock): """Test SmartApp is updated when an existing is already created.""" flow = SmartThingsFlowHandler() flow.hass = hass api = smartthings_mock.return_value api.apps.return_value = mock_coro(return_value=[app]) + api.generate_app_oauth.return_value = \ + mock_coro(return_value=app_oauth_client) result = await flow.async_step_user({'access_token': str(uuid4())}) @@ -219,8 +226,6 @@ async def test_wait_form_displayed_after_checking(hass, smartthings_mock): flow = SmartThingsFlowHandler() flow.hass = hass flow.access_token = str(uuid4()) - flow.api = smartthings_mock.return_value - flow.api.installed_apps.return_value = mock_coro(return_value=[]) result = await flow.async_step_wait_install({}) @@ -235,19 +240,29 @@ async def test_config_entry_created_when_installed( flow = SmartThingsFlowHandler() flow.hass = hass flow.access_token = str(uuid4()) - flow.api = smartthings_mock.return_value flow.app_id = installed_app.app_id - flow.api.installed_apps.return_value = \ - mock_coro(return_value=[installed_app]) + flow.api = smartthings_mock.return_value + flow.oauth_client_id = str(uuid4()) + flow.oauth_client_secret = str(uuid4()) + data = { + CONF_REFRESH_TOKEN: str(uuid4()), + CONF_LOCATION_ID: installed_app.location_id, + CONF_INSTALLED_APP_ID: installed_app.installed_app_id + } + hass.data[DOMAIN][CONF_INSTALLED_APPS].append(data) result = await flow.async_step_wait_install({}) + assert not hass.data[DOMAIN][CONF_INSTALLED_APPS] assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY assert result['data']['app_id'] == installed_app.app_id assert result['data']['installed_app_id'] == \ installed_app.installed_app_id assert result['data']['location_id'] == installed_app.location_id assert result['data']['access_token'] == flow.access_token + assert result['data']['refresh_token'] == data[CONF_REFRESH_TOKEN] + assert result['data']['client_secret'] == flow.oauth_client_secret + assert result['data']['client_id'] == flow.oauth_client_id assert result['title'] == location.name @@ -259,17 +274,31 @@ async def test_multiple_config_entry_created_when_installed( flow.access_token = str(uuid4()) flow.app_id = app.app_id flow.api = smartthings_mock.return_value - flow.api.installed_apps.return_value = \ - mock_coro(return_value=installed_apps) + flow.oauth_client_id = str(uuid4()) + flow.oauth_client_secret = str(uuid4()) + for installed_app in installed_apps: + data = { + CONF_REFRESH_TOKEN: str(uuid4()), + CONF_LOCATION_ID: installed_app.location_id, + CONF_INSTALLED_APP_ID: installed_app.installed_app_id + } + hass.data[DOMAIN][CONF_INSTALLED_APPS].append(data) + install_data = hass.data[DOMAIN][CONF_INSTALLED_APPS].copy() result = await flow.async_step_wait_install({}) + assert not hass.data[DOMAIN][CONF_INSTALLED_APPS] + assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY assert result['data']['app_id'] == installed_apps[0].app_id assert result['data']['installed_app_id'] == \ installed_apps[0].installed_app_id assert result['data']['location_id'] == installed_apps[0].location_id assert result['data']['access_token'] == flow.access_token + assert result['data']['refresh_token'] == \ + install_data[0][CONF_REFRESH_TOKEN] + assert result['data']['client_secret'] == flow.oauth_client_secret + assert result['data']['client_id'] == flow.oauth_client_id assert result['title'] == locations[0].name await hass.async_block_till_done() @@ -280,4 +309,6 @@ async def test_multiple_config_entry_created_when_installed( installed_apps[1].installed_app_id assert entries[0].data['location_id'] == installed_apps[1].location_id assert entries[0].data['access_token'] == flow.access_token + assert entries[0].data['client_secret'] == flow.oauth_client_secret + assert entries[0].data['client_id'] == flow.oauth_client_id assert entries[0].title == locations[1].name diff --git a/tests/components/smartthings/test_fan.py b/tests/components/smartthings/test_fan.py index db8d9b512de..644c0823fd5 100644 --- a/tests/components/smartthings/test_fan.py +++ b/tests/components/smartthings/test_fan.py @@ -7,31 +7,15 @@ real HTTP calls are not initiated during testing. from pysmartthings import Attribute, Capability from homeassistant.components.fan import ( - ATTR_SPEED, ATTR_SPEED_LIST, SPEED_HIGH, SPEED_LOW, SPEED_MEDIUM, - SPEED_OFF, SUPPORT_SET_SPEED) -from homeassistant.components.smartthings import DeviceBroker, fan + ATTR_SPEED, ATTR_SPEED_LIST, DOMAIN as FAN_DOMAIN, SPEED_HIGH, SPEED_LOW, + SPEED_MEDIUM, SPEED_OFF, SUPPORT_SET_SPEED) +from homeassistant.components.smartthings import fan from homeassistant.components.smartthings.const import ( - DATA_BROKERS, DOMAIN, SIGNAL_SMARTTHINGS_UPDATE) -from homeassistant.config_entries import ( - CONN_CLASS_CLOUD_PUSH, SOURCE_USER, ConfigEntry) + DOMAIN, SIGNAL_SMARTTHINGS_UPDATE) from homeassistant.const import ATTR_ENTITY_ID, ATTR_SUPPORTED_FEATURES from homeassistant.helpers.dispatcher import async_dispatcher_send - -async def _setup_platform(hass, *devices): - """Set up the SmartThings fan platform and prerequisites.""" - hass.config.components.add(DOMAIN) - broker = DeviceBroker(hass, devices, '') - config_entry = ConfigEntry("1", DOMAIN, "Test", {}, - SOURCE_USER, CONN_CLASS_CLOUD_PUSH) - hass.data[DOMAIN] = { - DATA_BROKERS: { - config_entry.entry_id: broker - } - } - await hass.config_entries.async_forward_entry_setup(config_entry, 'fan') - await hass.async_block_till_done() - return config_entry +from .conftest import setup_platform async def test_async_setup_platform(): @@ -45,7 +29,7 @@ async def test_entity_state(hass, device_factory): "Fan 1", capabilities=[Capability.switch, Capability.fan_speed], status={Attribute.switch: 'on', Attribute.fan_speed: 2}) - await _setup_platform(hass, device) + await setup_platform(hass, FAN_DOMAIN, device) # Dimmer 1 state = hass.states.get('fan.fan_1') @@ -63,11 +47,10 @@ async def test_entity_and_device_attributes(hass, device_factory): "Fan 1", capabilities=[Capability.switch, Capability.fan_speed], status={Attribute.switch: 'on', Attribute.fan_speed: 2}) - await _setup_platform(hass, device) + # Act + await setup_platform(hass, FAN_DOMAIN, device) entity_registry = await hass.helpers.entity_registry.async_get_registry() device_registry = await hass.helpers.device_registry.async_get_registry() - # Act - await _setup_platform(hass, device) # Assert entry = entity_registry.async_get("fan.fan_1") assert entry @@ -88,7 +71,7 @@ async def test_turn_off(hass, device_factory): "Fan 1", capabilities=[Capability.switch, Capability.fan_speed], status={Attribute.switch: 'on', Attribute.fan_speed: 2}) - await _setup_platform(hass, device) + await setup_platform(hass, FAN_DOMAIN, device) # Act await hass.services.async_call( 'fan', 'turn_off', {'entity_id': 'fan.fan_1'}, @@ -106,7 +89,7 @@ async def test_turn_on(hass, device_factory): "Fan 1", capabilities=[Capability.switch, Capability.fan_speed], status={Attribute.switch: 'off', Attribute.fan_speed: 0}) - await _setup_platform(hass, device) + await setup_platform(hass, FAN_DOMAIN, device) # Act await hass.services.async_call( 'fan', 'turn_on', {ATTR_ENTITY_ID: "fan.fan_1"}, @@ -124,7 +107,7 @@ async def test_turn_on_with_speed(hass, device_factory): "Fan 1", capabilities=[Capability.switch, Capability.fan_speed], status={Attribute.switch: 'off', Attribute.fan_speed: 0}) - await _setup_platform(hass, device) + await setup_platform(hass, FAN_DOMAIN, device) # Act await hass.services.async_call( 'fan', 'turn_on', @@ -145,7 +128,7 @@ async def test_set_speed(hass, device_factory): "Fan 1", capabilities=[Capability.switch, Capability.fan_speed], status={Attribute.switch: 'off', Attribute.fan_speed: 0}) - await _setup_platform(hass, device) + await setup_platform(hass, FAN_DOMAIN, device) # Act await hass.services.async_call( 'fan', 'set_speed', @@ -166,7 +149,7 @@ async def test_update_from_signal(hass, device_factory): "Fan 1", capabilities=[Capability.switch, Capability.fan_speed], status={Attribute.switch: 'off', Attribute.fan_speed: 0}) - await _setup_platform(hass, device) + await setup_platform(hass, FAN_DOMAIN, device) await device.switch_on(True) # Act async_dispatcher_send(hass, SIGNAL_SMARTTHINGS_UPDATE, @@ -185,7 +168,7 @@ async def test_unload_config_entry(hass, device_factory): "Fan 1", capabilities=[Capability.switch, Capability.fan_speed], status={Attribute.switch: 'off', Attribute.fan_speed: 0}) - config_entry = await _setup_platform(hass, device) + config_entry = await setup_platform(hass, FAN_DOMAIN, device) # Act await hass.config_entries.async_forward_entry_unload( config_entry, 'fan') diff --git a/tests/components/smartthings/test_init.py b/tests/components/smartthings/test_init.py index 014cfe7da98..0e35ef80fc2 100644 --- a/tests/components/smartthings/test_init.py +++ b/tests/components/smartthings/test_init.py @@ -8,14 +8,33 @@ import pytest from homeassistant.components import smartthings from homeassistant.components.smartthings.const import ( - DATA_BROKERS, DOMAIN, EVENT_BUTTON, SIGNAL_SMARTTHINGS_UPDATE, - SUPPORTED_PLATFORMS) + CONF_INSTALLED_APP_ID, CONF_REFRESH_TOKEN, DATA_BROKERS, DOMAIN, + EVENT_BUTTON, SIGNAL_SMARTTHINGS_UPDATE, SUPPORTED_PLATFORMS) from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.helpers.dispatcher import async_dispatcher_connect from tests.common import mock_coro +async def test_migration_creates_new_flow( + hass, smartthings_mock, config_entry): + """Test migration deletes app and creates new flow.""" + config_entry.version = 1 + setattr(hass.config_entries, '_entries', [config_entry]) + api = smartthings_mock.return_value + api.delete_installed_app.return_value = mock_coro() + + await smartthings.async_migrate_entry(hass, config_entry) + + assert api.delete_installed_app.call_count == 1 + await hass.async_block_till_done() + assert not hass.config_entries.async_entries(DOMAIN) + flows = hass.config_entries.flow.async_progress() + assert len(flows) == 1 + assert flows[0]['handler'] == 'smartthings' + assert flows[0]['context'] == {'source': 'import'} + + async def test_unrecoverable_api_errors_create_new_flow( hass, config_entry, smartthings_mock): """ @@ -101,14 +120,22 @@ async def test_unauthorized_installed_app_raises_not_ready( async def test_config_entry_loads_platforms( hass, config_entry, app, installed_app, - device, smartthings_mock): + device, smartthings_mock, subscription_factory): """Test config entry loads properly and proxies to platforms.""" setattr(hass.config_entries, '_entries', [config_entry]) api = smartthings_mock.return_value api.app.return_value = mock_coro(return_value=app) api.installed_app.return_value = mock_coro(return_value=installed_app) - api.devices.return_value = mock_coro(return_value=[device]) + api.devices.side_effect = \ + lambda *args, **kwargs: mock_coro(return_value=[device]) + mock_token = Mock() + mock_token.access_token.return_value = str(uuid4()) + mock_token.refresh_token.return_value = str(uuid4()) + api.generate_tokens.return_value = mock_coro(return_value=mock_token) + subscriptions = [subscription_factory(capability) + for capability in device.capabilities] + api.subscriptions.return_value = mock_coro(return_value=subscriptions) with patch.object(hass.config_entries, 'async_forward_entry_setup', return_value=mock_coro()) as forward_mock: @@ -120,8 +147,12 @@ async def test_config_entry_loads_platforms( async def test_unload_entry(hass, config_entry): """Test entries are unloaded correctly.""" - broker = Mock() - broker.event_handler_disconnect = Mock() + connect_disconnect = Mock() + smart_app = Mock() + smart_app.connect_event.return_value = connect_disconnect + broker = smartthings.DeviceBroker( + hass, config_entry, Mock(), smart_app, []) + broker.connect() hass.data[DOMAIN][DATA_BROKERS][config_entry.entry_id] = broker with patch.object(hass.config_entries, 'async_forward_entry_unload', @@ -129,15 +160,41 @@ async def test_unload_entry(hass, config_entry): return_value=True )) as forward_mock: assert await smartthings.async_unload_entry(hass, config_entry) - assert broker.event_handler_disconnect.call_count == 1 + + assert connect_disconnect.call_count == 1 assert config_entry.entry_id not in hass.data[DOMAIN][DATA_BROKERS] # Assert platforms unloaded await hass.async_block_till_done() assert forward_mock.call_count == len(SUPPORTED_PLATFORMS) +async def test_broker_regenerates_token( + hass, config_entry): + """Test the device broker regenerates the refresh token.""" + token = Mock() + token.refresh_token = str(uuid4()) + token.refresh.return_value = mock_coro() + stored_action = None + + def async_track_time_interval(hass, action, interval): + nonlocal stored_action + stored_action = action + + with patch('homeassistant.components.smartthings' + '.async_track_time_interval', + new=async_track_time_interval): + broker = smartthings.DeviceBroker( + hass, config_entry, token, Mock(), []) + broker.connect() + + assert stored_action + await stored_action(None) # pylint:disable=not-callable + assert token.refresh.call_count == 1 + assert config_entry.data[CONF_REFRESH_TOKEN] == token.refresh_token + + async def test_event_handler_dispatches_updated_devices( - hass, device_factory, event_request_factory): + hass, config_entry, device_factory, event_request_factory): """Test the event handler dispatches updated devices.""" devices = [ device_factory('Bedroom 1 Switch', ['switch']), @@ -147,6 +204,7 @@ async def test_event_handler_dispatches_updated_devices( device_ids = [devices[0].device_id, devices[1].device_id, devices[2].device_id] request = event_request_factory(device_ids) + config_entry.data[CONF_INSTALLED_APP_ID] = request.installed_app_id called = False def signal(ids): @@ -154,10 +212,13 @@ async def test_event_handler_dispatches_updated_devices( called = True assert device_ids == ids async_dispatcher_connect(hass, SIGNAL_SMARTTHINGS_UPDATE, signal) - broker = smartthings.DeviceBroker( - hass, devices, request.installed_app_id) - await broker.event_handler(request, None, None) + broker = smartthings.DeviceBroker( + hass, config_entry, Mock(), Mock(), devices) + broker.connect() + + # pylint:disable=protected-access + await broker._event_handler(request, None, None) await hass.async_block_till_done() assert called @@ -166,7 +227,7 @@ async def test_event_handler_dispatches_updated_devices( async def test_event_handler_ignores_other_installed_app( - hass, device_factory, event_request_factory): + hass, config_entry, device_factory, event_request_factory): """Test the event handler dispatches updated devices.""" device = device_factory('Bedroom 1 Switch', ['switch']) request = event_request_factory([device.device_id]) @@ -176,21 +237,26 @@ async def test_event_handler_ignores_other_installed_app( nonlocal called called = True async_dispatcher_connect(hass, SIGNAL_SMARTTHINGS_UPDATE, signal) - broker = smartthings.DeviceBroker(hass, [device], str(uuid4())) + broker = smartthings.DeviceBroker( + hass, config_entry, Mock(), Mock(), [device]) + broker.connect() - await broker.event_handler(request, None, None) + # pylint:disable=protected-access + await broker._event_handler(request, None, None) await hass.async_block_till_done() assert not called async def test_event_handler_fires_button_events( - hass, device_factory, event_factory, event_request_factory): + hass, config_entry, device_factory, event_factory, + event_request_factory): """Test the event handler fires button events.""" device = device_factory('Button 1', ['button']) event = event_factory(device.device_id, capability='button', attribute='button', value='pushed') request = event_request_factory(events=[event]) + config_entry.data[CONF_INSTALLED_APP_ID] = request.installed_app_id called = False def handler(evt): @@ -205,8 +271,11 @@ async def test_event_handler_fires_button_events( } hass.bus.async_listen(EVENT_BUTTON, handler) broker = smartthings.DeviceBroker( - hass, [device], request.installed_app_id) - await broker.event_handler(request, None, None) + hass, config_entry, Mock(), Mock(), [device]) + broker.connect() + + # pylint:disable=protected-access + await broker._event_handler(request, None, None) await hass.async_block_till_done() assert called diff --git a/tests/components/smartthings/test_light.py b/tests/components/smartthings/test_light.py index 72bc5da9063..d31507925d6 100644 --- a/tests/components/smartthings/test_light.py +++ b/tests/components/smartthings/test_light.py @@ -9,15 +9,16 @@ import pytest from homeassistant.components.light import ( ATTR_BRIGHTNESS, ATTR_COLOR_TEMP, ATTR_HS_COLOR, ATTR_TRANSITION, - SUPPORT_BRIGHTNESS, SUPPORT_COLOR, SUPPORT_COLOR_TEMP, SUPPORT_TRANSITION) -from homeassistant.components.smartthings import DeviceBroker, light + DOMAIN as LIGHT_DOMAIN, SUPPORT_BRIGHTNESS, SUPPORT_COLOR, + SUPPORT_COLOR_TEMP, SUPPORT_TRANSITION) +from homeassistant.components.smartthings import light from homeassistant.components.smartthings.const import ( - DATA_BROKERS, DOMAIN, SIGNAL_SMARTTHINGS_UPDATE) -from homeassistant.config_entries import ( - CONN_CLASS_CLOUD_PUSH, SOURCE_USER, ConfigEntry) + DOMAIN, SIGNAL_SMARTTHINGS_UPDATE) from homeassistant.const import ATTR_ENTITY_ID, ATTR_SUPPORTED_FEATURES from homeassistant.helpers.dispatcher import async_dispatcher_send +from .conftest import setup_platform + @pytest.fixture(name="light_devices") def light_devices_fixture(device_factory): @@ -44,22 +45,6 @@ def light_devices_fixture(device_factory): ] -async def _setup_platform(hass, *devices): - """Set up the SmartThings light platform and prerequisites.""" - hass.config.components.add(DOMAIN) - broker = DeviceBroker(hass, devices, '') - config_entry = ConfigEntry("1", DOMAIN, "Test", {}, - SOURCE_USER, CONN_CLASS_CLOUD_PUSH) - hass.data[DOMAIN] = { - DATA_BROKERS: { - config_entry.entry_id: broker - } - } - await hass.config_entries.async_forward_entry_setup(config_entry, 'light') - await hass.async_block_till_done() - return config_entry - - async def test_async_setup_platform(): """Test setup platform does nothing (it uses config entries).""" await light.async_setup_platform(None, None, None) @@ -67,7 +52,7 @@ async def test_async_setup_platform(): async def test_entity_state(hass, light_devices): """Tests the state attributes properly match the light types.""" - await _setup_platform(hass, *light_devices) + await setup_platform(hass, LIGHT_DOMAIN, *light_devices) # Dimmer 1 state = hass.states.get('light.dimmer_1') @@ -101,7 +86,7 @@ async def test_entity_and_device_attributes(hass, device_factory): entity_registry = await hass.helpers.entity_registry.async_get_registry() device_registry = await hass.helpers.device_registry.async_get_registry() # Act - await _setup_platform(hass, device) + await setup_platform(hass, LIGHT_DOMAIN, device) # Assert entry = entity_registry.async_get("light.light_1") assert entry @@ -118,7 +103,7 @@ async def test_entity_and_device_attributes(hass, device_factory): async def test_turn_off(hass, light_devices): """Test the light turns of successfully.""" # Arrange - await _setup_platform(hass, *light_devices) + await setup_platform(hass, LIGHT_DOMAIN, *light_devices) # Act await hass.services.async_call( 'light', 'turn_off', {'entity_id': 'light.color_dimmer_2'}, @@ -132,7 +117,7 @@ async def test_turn_off(hass, light_devices): async def test_turn_off_with_transition(hass, light_devices): """Test the light turns of successfully with transition.""" # Arrange - await _setup_platform(hass, *light_devices) + await setup_platform(hass, LIGHT_DOMAIN, *light_devices) # Act await hass.services.async_call( 'light', 'turn_off', @@ -147,7 +132,7 @@ async def test_turn_off_with_transition(hass, light_devices): async def test_turn_on(hass, light_devices): """Test the light turns of successfully.""" # Arrange - await _setup_platform(hass, *light_devices) + await setup_platform(hass, LIGHT_DOMAIN, *light_devices) # Act await hass.services.async_call( 'light', 'turn_on', {ATTR_ENTITY_ID: "light.color_dimmer_1"}, @@ -161,7 +146,7 @@ async def test_turn_on(hass, light_devices): async def test_turn_on_with_brightness(hass, light_devices): """Test the light turns on to the specified brightness.""" # Arrange - await _setup_platform(hass, *light_devices) + await setup_platform(hass, LIGHT_DOMAIN, *light_devices) # Act await hass.services.async_call( 'light', 'turn_on', @@ -185,7 +170,7 @@ async def test_turn_on_with_minimal_brightness(hass, light_devices): set the level to zero, which turns off the lights in SmartThings. """ # Arrange - await _setup_platform(hass, *light_devices) + await setup_platform(hass, LIGHT_DOMAIN, *light_devices) # Act await hass.services.async_call( 'light', 'turn_on', @@ -203,7 +188,7 @@ async def test_turn_on_with_minimal_brightness(hass, light_devices): async def test_turn_on_with_color(hass, light_devices): """Test the light turns on with color.""" # Arrange - await _setup_platform(hass, *light_devices) + await setup_platform(hass, LIGHT_DOMAIN, *light_devices) # Act await hass.services.async_call( 'light', 'turn_on', @@ -220,7 +205,7 @@ async def test_turn_on_with_color(hass, light_devices): async def test_turn_on_with_color_temp(hass, light_devices): """Test the light turns on with color temp.""" # Arrange - await _setup_platform(hass, *light_devices) + await setup_platform(hass, LIGHT_DOMAIN, *light_devices) # Act await hass.services.async_call( 'light', 'turn_on', @@ -244,7 +229,7 @@ async def test_update_from_signal(hass, device_factory): status={Attribute.switch: 'off', Attribute.level: 100, Attribute.hue: 76.0, Attribute.saturation: 55.0, Attribute.color_temperature: 4500}) - await _setup_platform(hass, device) + await setup_platform(hass, LIGHT_DOMAIN, device) await device.switch_on(True) # Act async_dispatcher_send(hass, SIGNAL_SMARTTHINGS_UPDATE, @@ -266,7 +251,7 @@ async def test_unload_config_entry(hass, device_factory): status={Attribute.switch: 'off', Attribute.level: 100, Attribute.hue: 76.0, Attribute.saturation: 55.0, Attribute.color_temperature: 4500}) - config_entry = await _setup_platform(hass, device) + config_entry = await setup_platform(hass, LIGHT_DOMAIN, device) # Act await hass.config_entries.async_forward_entry_unload( config_entry, 'light') diff --git a/tests/components/smartthings/test_smartapp.py b/tests/components/smartthings/test_smartapp.py index 162a8f9a4e5..46bd1f42f7f 100644 --- a/tests/components/smartthings/test_smartapp.py +++ b/tests/components/smartthings/test_smartapp.py @@ -5,7 +5,9 @@ from uuid import uuid4 from pysmartthings import AppEntity, Capability from homeassistant.components.smartthings import smartapp -from homeassistant.components.smartthings.const import DATA_MANAGER, DOMAIN +from homeassistant.components.smartthings.const import ( + CONF_INSTALLED_APP_ID, CONF_INSTALLED_APPS, CONF_LOCATION_ID, + CONF_REFRESH_TOKEN, DATA_MANAGER, DOMAIN) from tests.common import mock_coro @@ -35,31 +37,26 @@ async def test_update_app_updated_needed(hass, app): assert mock_app.classifications == app.classifications -async def test_smartapp_install_abort_if_no_other( +async def test_smartapp_install_store_if_no_other( hass, smartthings_mock, device_factory): """Test aborts if no other app was configured already.""" # Arrange - api = smartthings_mock.return_value - api.create_subscription.return_value = mock_coro() app = Mock() app.app_id = uuid4() request = Mock() - request.installed_app_id = uuid4() - request.auth_token = uuid4() - request.location_id = uuid4() - devices = [ - device_factory('', [Capability.battery, 'ping']), - device_factory('', [Capability.switch, Capability.switch_level]), - device_factory('', [Capability.switch]) - ] - api.devices = Mock() - api.devices.return_value = mock_coro(return_value=devices) + request.installed_app_id = str(uuid4()) + request.auth_token = str(uuid4()) + request.location_id = str(uuid4()) + request.refresh_token = str(uuid4()) # Act await smartapp.smartapp_install(hass, request, None, app) # Assert entries = hass.config_entries.async_entries('smartthings') assert not entries - assert api.create_subscription.call_count == 3 + data = hass.data[DOMAIN][CONF_INSTALLED_APPS][0] + assert data[CONF_REFRESH_TOKEN] == request.refresh_token + assert data[CONF_LOCATION_ID] == request.location_id + assert data[CONF_INSTALLED_APP_ID] == request.installed_app_id async def test_smartapp_install_creates_flow( @@ -68,12 +65,12 @@ async def test_smartapp_install_creates_flow( # Arrange setattr(hass.config_entries, '_entries', [config_entry]) api = smartthings_mock.return_value - api.create_subscription.return_value = mock_coro() app = Mock() app.app_id = config_entry.data['app_id'] request = Mock() request.installed_app_id = str(uuid4()) request.auth_token = str(uuid4()) + request.refresh_token = str(uuid4()) request.location_id = location.location_id devices = [ device_factory('', [Capability.battery, 'ping']), @@ -88,42 +85,42 @@ async def test_smartapp_install_creates_flow( await hass.async_block_till_done() entries = hass.config_entries.async_entries('smartthings') assert len(entries) == 2 - assert api.create_subscription.call_count == 3 assert entries[1].data['app_id'] == app.app_id assert entries[1].data['installed_app_id'] == request.installed_app_id assert entries[1].data['location_id'] == request.location_id assert entries[1].data['access_token'] == \ config_entry.data['access_token'] + assert entries[1].data['refresh_token'] == request.refresh_token + assert entries[1].data['client_secret'] == \ + config_entry.data['client_secret'] + assert entries[1].data['client_id'] == config_entry.data['client_id'] assert entries[1].title == location.name -async def test_smartapp_update_syncs_subs( - hass, smartthings_mock, config_entry, location, device_factory): - """Test update synchronizes subscriptions.""" +async def test_smartapp_update_saves_token( + hass, smartthings_mock, location, device_factory): + """Test update saves token.""" # Arrange - setattr(hass.config_entries, '_entries', [config_entry]) + entry = Mock() + entry.data = { + 'installed_app_id': str(uuid4()), + 'app_id': str(uuid4()) + } + entry.domain = DOMAIN + + setattr(hass.config_entries, '_entries', [entry]) app = Mock() - app.app_id = config_entry.data['app_id'] - api = smartthings_mock.return_value - api.delete_subscriptions = Mock() - api.delete_subscriptions.return_value = mock_coro() - api.create_subscription.return_value = mock_coro() + app.app_id = entry.data['app_id'] request = Mock() - request.installed_app_id = str(uuid4()) + request.installed_app_id = entry.data['installed_app_id'] request.auth_token = str(uuid4()) + request.refresh_token = str(uuid4()) request.location_id = location.location_id - devices = [ - device_factory('', [Capability.battery, 'ping']), - device_factory('', [Capability.switch, Capability.switch_level]), - device_factory('', [Capability.switch]) - ] - api.devices = Mock() - api.devices.return_value = mock_coro(return_value=devices) + # Act await smartapp.smartapp_update(hass, request, None, app) # Assert - assert api.create_subscription.call_count == 3 - assert api.delete_subscriptions.call_count == 1 + assert entry.data[CONF_REFRESH_TOKEN] == request.refresh_token async def test_smartapp_uninstall(hass, config_entry): @@ -152,3 +149,83 @@ async def test_smartapp_webhook(hass): result = await smartapp.smartapp_webhook(hass, '', request) assert result.body == b'{}' + + +async def test_smartapp_sync_subscriptions( + hass, smartthings_mock, device_factory, subscription_factory): + """Test synchronization adds and removes.""" + api = smartthings_mock.return_value + api.delete_subscription.side_effect = lambda loc_id, sub_id: mock_coro() + api.create_subscription.side_effect = lambda sub: mock_coro() + subscriptions = [ + subscription_factory(Capability.thermostat), + subscription_factory(Capability.switch), + subscription_factory(Capability.switch_level) + ] + api.subscriptions.return_value = mock_coro(return_value=subscriptions) + devices = [ + device_factory('', [Capability.battery, 'ping']), + device_factory('', [Capability.switch, Capability.switch_level]), + device_factory('', [Capability.switch]) + ] + + await smartapp.smartapp_sync_subscriptions( + hass, str(uuid4()), str(uuid4()), str(uuid4()), devices) + + assert api.subscriptions.call_count == 1 + assert api.delete_subscription.call_count == 1 + assert api.create_subscription.call_count == 1 + + +async def test_smartapp_sync_subscriptions_up_to_date( + hass, smartthings_mock, device_factory, subscription_factory): + """Test synchronization does nothing when current.""" + api = smartthings_mock.return_value + api.delete_subscription.side_effect = lambda loc_id, sub_id: mock_coro() + api.create_subscription.side_effect = lambda sub: mock_coro() + subscriptions = [ + subscription_factory(Capability.battery), + subscription_factory(Capability.switch), + subscription_factory(Capability.switch_level) + ] + api.subscriptions.return_value = mock_coro(return_value=subscriptions) + devices = [ + device_factory('', [Capability.battery, 'ping']), + device_factory('', [Capability.switch, Capability.switch_level]), + device_factory('', [Capability.switch]) + ] + + await smartapp.smartapp_sync_subscriptions( + hass, str(uuid4()), str(uuid4()), str(uuid4()), devices) + + assert api.subscriptions.call_count == 1 + assert api.delete_subscription.call_count == 0 + assert api.create_subscription.call_count == 0 + + +async def test_smartapp_sync_subscriptions_handles_exceptions( + hass, smartthings_mock, device_factory, subscription_factory): + """Test synchronization does nothing when current.""" + api = smartthings_mock.return_value + api.delete_subscription.side_effect = \ + lambda loc_id, sub_id: mock_coro(exception=Exception) + api.create_subscription.side_effect = \ + lambda sub: mock_coro(exception=Exception) + subscriptions = [ + subscription_factory(Capability.battery), + subscription_factory(Capability.switch), + subscription_factory(Capability.switch_level) + ] + api.subscriptions.return_value = mock_coro(return_value=subscriptions) + devices = [ + device_factory('', [Capability.thermostat, 'ping']), + device_factory('', [Capability.switch, Capability.switch_level]), + device_factory('', [Capability.switch]) + ] + + await smartapp.smartapp_sync_subscriptions( + hass, str(uuid4()), str(uuid4()), str(uuid4()), devices) + + assert api.subscriptions.call_count == 1 + assert api.delete_subscription.call_count == 1 + assert api.create_subscription.call_count == 1 diff --git a/tests/components/smartthings/test_switch.py b/tests/components/smartthings/test_switch.py index 3f2bedd4f13..6ad87b7ad53 100644 --- a/tests/components/smartthings/test_switch.py +++ b/tests/components/smartthings/test_switch.py @@ -6,28 +6,13 @@ real HTTP calls are not initiated during testing. """ from pysmartthings import Attribute, Capability -from homeassistant.components.smartthings import DeviceBroker, switch +from homeassistant.components.smartthings import switch from homeassistant.components.smartthings.const import ( - DATA_BROKERS, DOMAIN, SIGNAL_SMARTTHINGS_UPDATE) -from homeassistant.config_entries import ( - CONN_CLASS_CLOUD_PUSH, SOURCE_USER, ConfigEntry) + DOMAIN, SIGNAL_SMARTTHINGS_UPDATE) +from homeassistant.components.switch import DOMAIN as SWITCH_DOMAIN from homeassistant.helpers.dispatcher import async_dispatcher_send - -async def _setup_platform(hass, *devices): - """Set up the SmartThings switch platform and prerequisites.""" - hass.config.components.add(DOMAIN) - broker = DeviceBroker(hass, devices, '') - config_entry = ConfigEntry("1", DOMAIN, "Test", {}, - SOURCE_USER, CONN_CLASS_CLOUD_PUSH) - hass.data[DOMAIN] = { - DATA_BROKERS: { - config_entry.entry_id: broker - } - } - await hass.config_entries.async_forward_entry_setup(config_entry, 'switch') - await hass.async_block_till_done() - return config_entry +from .conftest import setup_platform async def test_async_setup_platform(): @@ -43,7 +28,7 @@ async def test_entity_and_device_attributes(hass, device_factory): entity_registry = await hass.helpers.entity_registry.async_get_registry() device_registry = await hass.helpers.device_registry.async_get_registry() # Act - await _setup_platform(hass, device) + await setup_platform(hass, SWITCH_DOMAIN, device) # Assert entry = entity_registry.async_get('switch.switch_1') assert entry @@ -62,7 +47,7 @@ async def test_turn_off(hass, device_factory): # Arrange device = device_factory('Switch_1', [Capability.switch], {Attribute.switch: 'on'}) - await _setup_platform(hass, device) + await setup_platform(hass, SWITCH_DOMAIN, device) # Act await hass.services.async_call( 'switch', 'turn_off', {'entity_id': 'switch.switch_1'}, @@ -78,7 +63,7 @@ async def test_turn_on(hass, device_factory): # Arrange device = device_factory('Switch_1', [Capability.switch], {Attribute.switch: 'off'}) - await _setup_platform(hass, device) + await setup_platform(hass, SWITCH_DOMAIN, device) # Act await hass.services.async_call( 'switch', 'turn_on', {'entity_id': 'switch.switch_1'}, @@ -94,7 +79,7 @@ async def test_update_from_signal(hass, device_factory): # Arrange device = device_factory('Switch_1', [Capability.switch], {Attribute.switch: 'off'}) - await _setup_platform(hass, device) + await setup_platform(hass, SWITCH_DOMAIN, device) await device.switch_on(True) # Act async_dispatcher_send(hass, SIGNAL_SMARTTHINGS_UPDATE, @@ -111,7 +96,7 @@ async def test_unload_config_entry(hass, device_factory): # Arrange device = device_factory('Switch 1', [Capability.switch], {Attribute.switch: 'on'}) - config_entry = await _setup_platform(hass, device) + config_entry = await setup_platform(hass, SWITCH_DOMAIN, device) # Act await hass.config_entries.async_forward_entry_unload( config_entry, 'switch')