mirror of
https://github.com/home-assistant/core.git
synced 2025-07-25 06:07:17 +00:00
Force token expires_in to float (#39489)
This commit is contained in:
parent
d0e44893f5
commit
6f43285a28
@ -25,6 +25,8 @@ from homeassistant.helpers.network import get_url
|
|||||||
|
|
||||||
from .aiohttp_client import async_get_clientsession
|
from .aiohttp_client import async_get_clientsession
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
DATA_JWT_SECRET = "oauth2_jwt_secret"
|
DATA_JWT_SECRET = "oauth2_jwt_secret"
|
||||||
DATA_VIEW_REGISTERED = "oauth2_view_reg"
|
DATA_VIEW_REGISTERED = "oauth2_view_reg"
|
||||||
DATA_IMPLEMENTATIONS = "oauth2_impl"
|
DATA_IMPLEMENTATIONS = "oauth2_impl"
|
||||||
@ -77,6 +79,8 @@ class AbstractOAuth2Implementation(ABC):
|
|||||||
async def async_refresh_token(self, token: dict) -> dict:
|
async def async_refresh_token(self, token: dict) -> dict:
|
||||||
"""Refresh a token and update expires info."""
|
"""Refresh a token and update expires info."""
|
||||||
new_token = await self._async_refresh_token(token)
|
new_token = await self._async_refresh_token(token)
|
||||||
|
# Force int for non-compliant oauth2 providers
|
||||||
|
new_token["expires_in"] = int(new_token["expires_in"])
|
||||||
new_token["expires_at"] = time.time() + new_token["expires_in"]
|
new_token["expires_at"] = time.time() + new_token["expires_in"]
|
||||||
return new_token
|
return new_token
|
||||||
|
|
||||||
@ -257,6 +261,12 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
|
|||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Create config entry from external data."""
|
"""Create config entry from external data."""
|
||||||
token = await self.flow_impl.async_resolve_external_data(self.external_data)
|
token = await self.flow_impl.async_resolve_external_data(self.external_data)
|
||||||
|
# Force int for non-compliant oauth2 providers
|
||||||
|
try:
|
||||||
|
token["expires_in"] = int(token["expires_in"])
|
||||||
|
except ValueError as err:
|
||||||
|
_LOGGER.warning("Error converting expires_in to int: %s", err)
|
||||||
|
return self.async_abort(reason="oauth_error")
|
||||||
token["expires_at"] = time.time() + token["expires_in"]
|
token["expires_at"] = time.time() + token["expires_in"]
|
||||||
|
|
||||||
self.logger.info("Successfully authenticated")
|
self.logger.info("Successfully authenticated")
|
||||||
|
@ -128,6 +128,62 @@ async def test_abort_if_authorization_timeout(hass, flow_handler, local_impl):
|
|||||||
assert result["reason"] == "authorize_url_timeout"
|
assert result["reason"] == "authorize_url_timeout"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_abort_if_oauth_error(
|
||||||
|
hass, flow_handler, local_impl, aiohttp_client, aioclient_mock, current_request
|
||||||
|
):
|
||||||
|
"""Check bad oauth token."""
|
||||||
|
await async_process_ha_core_config(
|
||||||
|
hass,
|
||||||
|
{"external_url": "https://example.com"},
|
||||||
|
)
|
||||||
|
|
||||||
|
flow_handler.async_register_implementation(hass, local_impl)
|
||||||
|
config_entry_oauth2_flow.async_register_implementation(
|
||||||
|
hass, TEST_DOMAIN, MockOAuth2Implementation()
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
|
assert result["step_id"] == "pick_implementation"
|
||||||
|
|
||||||
|
# Pick implementation
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], user_input={"implementation": TEST_DOMAIN}
|
||||||
|
)
|
||||||
|
|
||||||
|
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||||
|
|
||||||
|
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
||||||
|
assert result["url"] == (
|
||||||
|
f"{AUTHORIZE_URL}?response_type=code&client_id={CLIENT_ID}"
|
||||||
|
"&redirect_uri=https://example.com/auth/external/callback"
|
||||||
|
f"&state={state}&scope=read+write"
|
||||||
|
)
|
||||||
|
|
||||||
|
client = await aiohttp_client(hass.http.app)
|
||||||
|
resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||||
|
assert resp.status == 200
|
||||||
|
assert resp.headers["content-type"] == "text/html; charset=utf-8"
|
||||||
|
|
||||||
|
aioclient_mock.post(
|
||||||
|
TOKEN_URL,
|
||||||
|
json={
|
||||||
|
"refresh_token": REFRESH_TOKEN,
|
||||||
|
"access_token": ACCESS_TOKEN_1,
|
||||||
|
"type": "bearer",
|
||||||
|
"expires_in": "badnumber",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||||
|
|
||||||
|
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||||
|
assert result["reason"] == "oauth_error"
|
||||||
|
|
||||||
|
|
||||||
async def test_step_discovery(hass, flow_handler, local_impl):
|
async def test_step_discovery(hass, flow_handler, local_impl):
|
||||||
"""Check flow triggers from discovery."""
|
"""Check flow triggers from discovery."""
|
||||||
await async_process_ha_core_config(
|
await async_process_ha_core_config(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user