diff --git a/homeassistant/components/alexa/auth.py b/homeassistant/components/alexa/auth.py index d4633d938ed..9f87a6d954e 100644 --- a/homeassistant/components/alexa/auth.py +++ b/homeassistant/components/alexa/auth.py @@ -56,6 +56,11 @@ class Auth: return await self._async_request_new_token(lwa_params) + @callback + def async_invalidate_access_token(self): + """Invalidate access token.""" + self._prefs[STORAGE_ACCESS_TOKEN] = None + async def async_get_access_token(self): """Perform access token or token refresh request.""" async with self._get_token_lock: diff --git a/homeassistant/components/alexa/config.py b/homeassistant/components/alexa/config.py index a22ebbcd30d..f98337d71c5 100644 --- a/homeassistant/components/alexa/config.py +++ b/homeassistant/components/alexa/config.py @@ -1,4 +1,6 @@ """Config helpers for Alexa.""" +from homeassistant.core import callback + from .state_report import async_enable_proactive_mode @@ -55,11 +57,17 @@ class AbstractConfig: unsub_func() self._unsub_proactive_report = None + @callback def should_expose(self, entity_id): """If an entity should be exposed.""" # pylint: disable=no-self-use return False + @callback + def async_invalidate_access_token(self): + """Invalidate access token.""" + raise NotImplementedError + async def async_get_access_token(self): """Get an access token.""" raise NotImplementedError diff --git a/homeassistant/components/alexa/smart_home_http.py b/homeassistant/components/alexa/smart_home_http.py index 7fdd4e3000a..ada00e8a326 100644 --- a/homeassistant/components/alexa/smart_home_http.py +++ b/homeassistant/components/alexa/smart_home_http.py @@ -57,6 +57,11 @@ class AlexaConfig(AbstractConfig): """If an entity should be exposed.""" return self._config[CONF_FILTER](entity_id) + @core.callback + def async_invalidate_access_token(self): + """Invalidate access token.""" + self._auth.async_invalidate_access_token() + async def async_get_access_token(self): """Get an access token.""" return await self._auth.async_get_access_token() diff --git a/homeassistant/components/alexa/state_report.py b/homeassistant/components/alexa/state_report.py index 7e842889977..1e22d5fc09f 100644 --- a/homeassistant/components/alexa/state_report.py +++ b/homeassistant/components/alexa/state_report.py @@ -51,7 +51,9 @@ async def async_enable_proactive_mode(hass, smart_home_config): ) -async def async_send_changereport_message(hass, config, alexa_entity): +async def async_send_changereport_message( + hass, config, alexa_entity, *, invalidate_access_token=True +): """Send a ChangeReport message for an Alexa entity. https://developer.amazon.com/docs/smarthome/state-reporting-for-a-smart-home-skill.html#report-state-with-changereport-events @@ -88,21 +90,30 @@ async def async_send_changereport_message(hass, config, alexa_entity): except (asyncio.TimeoutError, aiohttp.ClientError): _LOGGER.error("Timeout sending report to Alexa.") - return None + return response_text = await response.text() _LOGGER.debug("Sent: %s", json.dumps(message_serialized)) _LOGGER.debug("Received (%s): %s", response.status, response_text) - if response.status != 202: - response_json = json.loads(response_text) - _LOGGER.error( - "Error when sending ChangeReport to Alexa: %s: %s", - response_json["payload"]["code"], - response_json["payload"]["description"], + if response.status == 202 and not invalidate_access_token: + return + + response_json = json.loads(response_text) + + if response_json["payload"]["code"] == "INVALID_ACCESS_TOKEN_EXCEPTION": + config.async_invalidate_access_token() + return await async_send_changereport_message( + hass, config, alexa_entity, invalidate_access_token=False ) + _LOGGER.error( + "Error when sending ChangeReport to Alexa: %s: %s", + response_json["payload"]["code"], + response_json["payload"]["description"], + ) + async def async_send_add_or_update_message(hass, config, entity_ids): """Send an AddOrUpdateReport message for entities. diff --git a/homeassistant/components/cloud/alexa_config.py b/homeassistant/components/cloud/alexa_config.py index d31bcfdfc40..a1432f196bf 100644 --- a/homeassistant/components/cloud/alexa_config.py +++ b/homeassistant/components/cloud/alexa_config.py @@ -7,6 +7,7 @@ import aiohttp import async_timeout from hass_nabucasa import cloud_api +from homeassistant.core import callback from homeassistant.const import CLOUD_NEVER_EXPOSED_ENTITIES from homeassistant.helpers import entity_registry from homeassistant.helpers.event import async_call_later @@ -95,9 +96,14 @@ class AlexaConfig(alexa_config.AbstractConfig): entity_config = entity_configs.get(entity_id, {}) return entity_config.get(PREF_SHOULD_EXPOSE, DEFAULT_SHOULD_EXPOSE) + @callback + def async_invalidate_access_token(self): + """Invalidate access token.""" + self._token_valid = None + async def async_get_access_token(self): """Get an access token.""" - if self._token_valid is not None and self._token_valid < utcnow(): + if self._token_valid is not None and self._token_valid > utcnow(): return self._token resp = await cloud_api.async_alexa_access_token(self._cloud) diff --git a/tests/components/cloud/test_alexa_config.py b/tests/components/cloud/test_alexa_config.py index 22d8c64c3b0..c8e84016a28 100644 --- a/tests/components/cloud/test_alexa_config.py +++ b/tests/components/cloud/test_alexa_config.py @@ -1,6 +1,6 @@ """Test Alexa config.""" import contextlib -from unittest.mock import patch +from unittest.mock import patch, Mock from homeassistant.components.cloud import ALEXA_SCHEMA, alexa_config from homeassistant.util.dt import utcnow @@ -43,6 +43,42 @@ async def test_alexa_config_report_state(hass, cloud_prefs): assert conf.is_reporting_states is False +async def test_alexa_config_invalidate_token(hass, cloud_prefs, aioclient_mock): + """Test Alexa config should expose using prefs.""" + aioclient_mock.post( + "http://example/alexa_token", + json={ + "access_token": "mock-token", + "event_endpoint": "http://example.com/alexa_endpoint", + "expires_in": 30, + }, + ) + conf = alexa_config.AlexaConfig( + hass, + ALEXA_SCHEMA({}), + cloud_prefs, + Mock( + alexa_access_token_url="http://example/alexa_token", + run_executor=Mock(side_effect=mock_coro), + websession=hass.helpers.aiohttp_client.async_get_clientsession(), + ), + ) + + token = await conf.async_get_access_token() + assert token == "mock-token" + assert len(aioclient_mock.mock_calls) == 1 + + token = await conf.async_get_access_token() + assert token == "mock-token" + assert len(aioclient_mock.mock_calls) == 1 + assert conf._token_valid is not None + conf.async_invalidate_access_token() + assert conf._token_valid is None + token = await conf.async_get_access_token() + assert token == "mock-token" + assert len(aioclient_mock.mock_calls) == 2 + + @contextlib.contextmanager def patch_sync_helper(): """Patch sync helper.