From 6f43285a28157653495dc5e2e159da1c47f80114 Mon Sep 17 00:00:00 2001 From: Aidan Timson Date: Mon, 7 Sep 2020 20:26:58 +0100 Subject: [PATCH] Force token expires_in to float (#39489) --- .../helpers/config_entry_oauth2_flow.py | 10 ++++ .../helpers/test_config_entry_oauth2_flow.py | 56 +++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/homeassistant/helpers/config_entry_oauth2_flow.py b/homeassistant/helpers/config_entry_oauth2_flow.py index b845db966bb..da86c222c13 100644 --- a/homeassistant/helpers/config_entry_oauth2_flow.py +++ b/homeassistant/helpers/config_entry_oauth2_flow.py @@ -25,6 +25,8 @@ from homeassistant.helpers.network import get_url from .aiohttp_client import async_get_clientsession +_LOGGER = logging.getLogger(__name__) + DATA_JWT_SECRET = "oauth2_jwt_secret" DATA_VIEW_REGISTERED = "oauth2_view_reg" DATA_IMPLEMENTATIONS = "oauth2_impl" @@ -77,6 +79,8 @@ class AbstractOAuth2Implementation(ABC): async def async_refresh_token(self, token: dict) -> dict: """Refresh a token and update expires info.""" new_token = await self._async_refresh_token(token) + # Force int for non-compliant oauth2 providers + new_token["expires_in"] = int(new_token["expires_in"]) new_token["expires_at"] = time.time() + new_token["expires_in"] return new_token @@ -257,6 +261,12 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta): ) -> Dict[str, Any]: """Create config entry from external data.""" token = await self.flow_impl.async_resolve_external_data(self.external_data) + # Force int for non-compliant oauth2 providers + try: + token["expires_in"] = int(token["expires_in"]) + except ValueError as err: + _LOGGER.warning("Error converting expires_in to int: %s", err) + return self.async_abort(reason="oauth_error") token["expires_at"] = time.time() + token["expires_in"] self.logger.info("Successfully authenticated") diff --git a/tests/helpers/test_config_entry_oauth2_flow.py b/tests/helpers/test_config_entry_oauth2_flow.py index 227f3e366f3..dc34b0f7876 100644 --- a/tests/helpers/test_config_entry_oauth2_flow.py +++ b/tests/helpers/test_config_entry_oauth2_flow.py @@ -128,6 +128,62 @@ async def test_abort_if_authorization_timeout(hass, flow_handler, local_impl): assert result["reason"] == "authorize_url_timeout" +async def test_abort_if_oauth_error( + hass, flow_handler, local_impl, aiohttp_client, aioclient_mock, current_request +): + """Check bad oauth token.""" + await async_process_ha_core_config( + hass, + {"external_url": "https://example.com"}, + ) + + flow_handler.async_register_implementation(hass, local_impl) + config_entry_oauth2_flow.async_register_implementation( + hass, TEST_DOMAIN, MockOAuth2Implementation() + ) + + result = await hass.config_entries.flow.async_init( + TEST_DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["step_id"] == "pick_implementation" + + # Pick implementation + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input={"implementation": TEST_DOMAIN} + ) + + state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + + assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP + assert result["url"] == ( + f"{AUTHORIZE_URL}?response_type=code&client_id={CLIENT_ID}" + "&redirect_uri=https://example.com/auth/external/callback" + f"&state={state}&scope=read+write" + ) + + client = await aiohttp_client(hass.http.app) + resp = await client.get(f"/auth/external/callback?code=abcd&state={state}") + assert resp.status == 200 + assert resp.headers["content-type"] == "text/html; charset=utf-8" + + aioclient_mock.post( + TOKEN_URL, + json={ + "refresh_token": REFRESH_TOKEN, + "access_token": ACCESS_TOKEN_1, + "type": "bearer", + "expires_in": "badnumber", + }, + ) + + result = await hass.config_entries.flow.async_configure(result["flow_id"]) + + assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT + assert result["reason"] == "oauth_error" + + async def test_step_discovery(hass, flow_handler, local_impl): """Check flow triggers from discovery.""" await async_process_ha_core_config(