diff --git a/homeassistant/components/cloud/__init__.py b/homeassistant/components/cloud/__init__.py index 54a221565b4..3bfc5909b0b 100644 --- a/homeassistant/components/cloud/__init__.py +++ b/homeassistant/components/cloud/__init__.py @@ -162,7 +162,7 @@ class Cloud: @property def subscription_expired(self): """Return a boolean if the subscription has expired.""" - return dt_util.utcnow() > self.expiration_date + timedelta(days=3) + return dt_util.utcnow() > self.expiration_date + timedelta(days=7) @property def expiration_date(self): diff --git a/homeassistant/components/cloud/auth_api.py b/homeassistant/components/cloud/auth_api.py index dcf7567482a..042b90bf9cb 100644 --- a/homeassistant/components/cloud/auth_api.py +++ b/homeassistant/components/cloud/auth_api.py @@ -113,6 +113,24 @@ def check_token(cloud): raise _map_aws_exception(err) +def renew_access_token(cloud): + """Renew access token.""" + from botocore.exceptions import ClientError + + cognito = _cognito( + cloud, + access_token=cloud.access_token, + refresh_token=cloud.refresh_token) + + try: + cognito.renew_access_token() + cloud.id_token = cognito.id_token + cloud.access_token = cognito.access_token + cloud.write_user_info() + except ClientError as err: + raise _map_aws_exception(err) + + def _authenticate(cloud, email, password): """Log in and return an authenticated Cognito instance.""" from botocore.exceptions import ClientError diff --git a/homeassistant/components/cloud/http_api.py b/homeassistant/components/cloud/http_api.py index 720ca00cf52..0df4a39406e 100644 --- a/homeassistant/components/cloud/http_api.py +++ b/homeassistant/components/cloud/http_api.py @@ -14,7 +14,7 @@ from homeassistant.components import websocket_api from . import auth_api from .const import DOMAIN, REQUEST_TIMEOUT -from .iot import STATE_DISCONNECTED +from .iot import STATE_DISCONNECTED, STATE_CONNECTED _LOGGER = logging.getLogger(__name__) @@ -249,13 +249,28 @@ async def websocket_subscription(hass, connection, msg): with async_timeout.timeout(REQUEST_TIMEOUT, loop=hass.loop): response = await cloud.fetch_subscription_info() - if response.status == 200: - connection.send_message(websocket_api.result_message( - msg['id'], await response.json())) - else: + if response.status != 200: connection.send_message(websocket_api.error_message( msg['id'], 'request_failed', 'Failed to request subscription')) + data = await response.json() + + # Check if a user is subscribed but local info is outdated + # In that case, let's refresh and reconnect + if data.get('provider') and cloud.iot.state != STATE_CONNECTED: + _LOGGER.debug( + "Found disconnected account with valid subscriotion, connecting") + await hass.async_add_executor_job( + auth_api.renew_access_token, cloud) + + # Cancel reconnect in progress + if cloud.iot.state != STATE_DISCONNECTED: + await cloud.iot.disconnect() + + hass.async_create_task(cloud.iot.connect()) + + connection.send_message(websocket_api.result_message(msg['id'], data)) + @websocket_api.async_response async def websocket_update_prefs(hass, connection, msg): diff --git a/tests/components/cloud/test_http_api.py b/tests/components/cloud/test_http_api.py index 5d4b356b9b2..e27760bd6ed 100644 --- a/tests/components/cloud/test_http_api.py +++ b/tests/components/cloud/test_http_api.py @@ -7,6 +7,7 @@ from jose import jwt from homeassistant.components.cloud import ( DOMAIN, auth_api, iot, STORAGE_ENABLE_GOOGLE, STORAGE_ENABLE_ALEXA) +from homeassistant.util import dt as dt_util from tests.common import mock_coro @@ -352,24 +353,89 @@ async def test_websocket_status_not_logged_in(hass, hass_ws_client): } -async def test_websocket_subscription(hass, hass_ws_client, aioclient_mock, - mock_auth): - """Test querying the status.""" - aioclient_mock.get(SUBSCRIPTION_INFO_URL, json={'return': 'value'}) +async def test_websocket_subscription_reconnect( + hass, hass_ws_client, aioclient_mock, mock_auth): + """Test querying the status and connecting because valid account.""" + aioclient_mock.get(SUBSCRIPTION_INFO_URL, json={'provider': 'stripe'}) + hass.data[DOMAIN].id_token = jwt.encode({ + 'email': 'hello@home-assistant.io', + 'custom:sub-exp': dt_util.utcnow().date().isoformat() + }, 'test') + client = await hass_ws_client(hass) + + with patch( + 'homeassistant.components.cloud.auth_api.renew_access_token' + ) as mock_renew, patch( + 'homeassistant.components.cloud.iot.CloudIoT.connect' + ) as mock_connect: + await client.send_json({ + 'id': 5, + 'type': 'cloud/subscription' + }) + response = await client.receive_json() + + assert response['result'] == { + 'provider': 'stripe' + } + assert len(mock_renew.mock_calls) == 1 + assert len(mock_connect.mock_calls) == 1 + + +async def test_websocket_subscription_no_reconnect_if_connected( + hass, hass_ws_client, aioclient_mock, mock_auth): + """Test querying the status and not reconnecting because still expired.""" + aioclient_mock.get(SUBSCRIPTION_INFO_URL, json={'provider': 'stripe'}) + hass.data[DOMAIN].iot.state = iot.STATE_CONNECTED + hass.data[DOMAIN].id_token = jwt.encode({ + 'email': 'hello@home-assistant.io', + 'custom:sub-exp': dt_util.utcnow().date().isoformat() + }, 'test') + client = await hass_ws_client(hass) + + with patch( + 'homeassistant.components.cloud.auth_api.renew_access_token' + ) as mock_renew, patch( + 'homeassistant.components.cloud.iot.CloudIoT.connect' + ) as mock_connect: + await client.send_json({ + 'id': 5, + 'type': 'cloud/subscription' + }) + response = await client.receive_json() + + assert response['result'] == { + 'provider': 'stripe' + } + assert len(mock_renew.mock_calls) == 0 + assert len(mock_connect.mock_calls) == 0 + + +async def test_websocket_subscription_no_reconnect_if_expired( + hass, hass_ws_client, aioclient_mock, mock_auth): + """Test querying the status and not reconnecting because still expired.""" + aioclient_mock.get(SUBSCRIPTION_INFO_URL, json={'provider': 'stripe'}) hass.data[DOMAIN].id_token = jwt.encode({ 'email': 'hello@home-assistant.io', 'custom:sub-exp': '2018-01-03' }, 'test') client = await hass_ws_client(hass) - await client.send_json({ - 'id': 5, - 'type': 'cloud/subscription' - }) - response = await client.receive_json() + + with patch( + 'homeassistant.components.cloud.auth_api.renew_access_token' + ) as mock_renew, patch( + 'homeassistant.components.cloud.iot.CloudIoT.connect' + ) as mock_connect: + await client.send_json({ + 'id': 5, + 'type': 'cloud/subscription' + }) + response = await client.receive_json() assert response['result'] == { - 'return': 'value' + 'provider': 'stripe' } + assert len(mock_renew.mock_calls) == 1 + assert len(mock_connect.mock_calls) == 1 async def test_websocket_subscription_fail(hass, hass_ws_client, diff --git a/tests/components/cloud/test_init.py b/tests/components/cloud/test_init.py index 8695830eae9..61518f0f0e8 100644 --- a/tests/components/cloud/test_init.py +++ b/tests/components/cloud/test_init.py @@ -155,14 +155,14 @@ def test_subscription_expired(hass): with patch.object(cl, '_decode_claims', return_value=token_val), \ patch('homeassistant.util.dt.utcnow', return_value=utcnow().replace( - year=2017, month=11, day=15, hour=23, minute=59, + year=2017, month=11, day=19, hour=23, minute=59, second=59)): assert not cl.subscription_expired with patch.object(cl, '_decode_claims', return_value=token_val), \ patch('homeassistant.util.dt.utcnow', return_value=utcnow().replace( - year=2017, month=11, day=16, hour=0, minute=0, + year=2017, month=11, day=20, hour=0, minute=0, second=0)): assert cl.subscription_expired