Force token expires_in to float (#39489)

This commit is contained in:
Aidan Timson 2020-09-07 20:26:58 +01:00 committed by GitHub
parent d0e44893f5
commit 6f43285a28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 0 deletions

View File

@ -25,6 +25,8 @@ from homeassistant.helpers.network import get_url
from .aiohttp_client import async_get_clientsession
_LOGGER = logging.getLogger(__name__)
DATA_JWT_SECRET = "oauth2_jwt_secret"
DATA_VIEW_REGISTERED = "oauth2_view_reg"
DATA_IMPLEMENTATIONS = "oauth2_impl"
@ -77,6 +79,8 @@ class AbstractOAuth2Implementation(ABC):
async def async_refresh_token(self, token: dict) -> dict:
"""Refresh a token and update expires info."""
new_token = await self._async_refresh_token(token)
# Force int for non-compliant oauth2 providers
new_token["expires_in"] = int(new_token["expires_in"])
new_token["expires_at"] = time.time() + new_token["expires_in"]
return new_token
@ -257,6 +261,12 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
) -> Dict[str, Any]:
"""Create config entry from external data."""
token = await self.flow_impl.async_resolve_external_data(self.external_data)
# Force int for non-compliant oauth2 providers
try:
token["expires_in"] = int(token["expires_in"])
except ValueError as err:
_LOGGER.warning("Error converting expires_in to int: %s", err)
return self.async_abort(reason="oauth_error")
token["expires_at"] = time.time() + token["expires_in"]
self.logger.info("Successfully authenticated")

View File

@ -128,6 +128,62 @@ async def test_abort_if_authorization_timeout(hass, flow_handler, local_impl):
assert result["reason"] == "authorize_url_timeout"
async def test_abort_if_oauth_error(
hass, flow_handler, local_impl, aiohttp_client, aioclient_mock, current_request
):
"""Check bad oauth token."""
await async_process_ha_core_config(
hass,
{"external_url": "https://example.com"},
)
flow_handler.async_register_implementation(hass, local_impl)
config_entry_oauth2_flow.async_register_implementation(
hass, TEST_DOMAIN, MockOAuth2Implementation()
)
result = await hass.config_entries.flow.async_init(
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "pick_implementation"
# Pick implementation
result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input={"implementation": TEST_DOMAIN}
)
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
assert result["url"] == (
f"{AUTHORIZE_URL}?response_type=code&client_id={CLIENT_ID}"
"&redirect_uri=https://example.com/auth/external/callback"
f"&state={state}&scope=read+write"
)
client = await aiohttp_client(hass.http.app)
resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")
assert resp.status == 200
assert resp.headers["content-type"] == "text/html; charset=utf-8"
aioclient_mock.post(
TOKEN_URL,
json={
"refresh_token": REFRESH_TOKEN,
"access_token": ACCESS_TOKEN_1,
"type": "bearer",
"expires_in": "badnumber",
},
)
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result["reason"] == "oauth_error"
async def test_step_discovery(hass, flow_handler, local_impl):
"""Check flow triggers from discovery."""
await async_process_ha_core_config(