mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 16:57:53 +00:00
Resolve race condition when HA auth provider is loading (#21619)
* Resolve race condition when HA auth provider is loading * Fix * Add more tests * Lint
This commit is contained in:
parent
7a7080055e
commit
4a3b4cf346
@ -2,6 +2,7 @@
|
||||
|
||||
Sending HOTP through notify service
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, Optional, List
|
||||
@ -90,6 +91,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
|
||||
self._include = config.get(CONF_INCLUDE, [])
|
||||
self._exclude = config.get(CONF_EXCLUDE, [])
|
||||
self._message_template = config[CONF_MESSAGE]
|
||||
self._init_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def input_schema(self) -> vol.Schema:
|
||||
@ -98,15 +100,19 @@ class NotifyAuthModule(MultiFactorAuthModule):
|
||||
|
||||
async def _async_load(self) -> None:
|
||||
"""Load stored data."""
|
||||
data = await self._user_store.async_load()
|
||||
async with self._init_lock:
|
||||
if self._user_settings is not None:
|
||||
return
|
||||
|
||||
if data is None:
|
||||
data = {STORAGE_USERS: {}}
|
||||
data = await self._user_store.async_load()
|
||||
|
||||
self._user_settings = {
|
||||
user_id: NotifySetting(**setting)
|
||||
for user_id, setting in data.get(STORAGE_USERS, {}).items()
|
||||
}
|
||||
if data is None:
|
||||
data = {STORAGE_USERS: {}}
|
||||
|
||||
self._user_settings = {
|
||||
user_id: NotifySetting(**setting)
|
||||
for user_id, setting in data.get(STORAGE_USERS, {}).items()
|
||||
}
|
||||
|
||||
async def _async_save(self) -> None:
|
||||
"""Save data."""
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Time-based One Time Password auth module."""
|
||||
import asyncio
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, Optional, Tuple # noqa: F401
|
||||
@ -68,6 +69,7 @@ class TotpAuthModule(MultiFactorAuthModule):
|
||||
self._users = None # type: Optional[Dict[str, str]]
|
||||
self._user_store = hass.helpers.storage.Store(
|
||||
STORAGE_VERSION, STORAGE_KEY, private=True)
|
||||
self._init_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def input_schema(self) -> vol.Schema:
|
||||
@ -76,12 +78,16 @@ class TotpAuthModule(MultiFactorAuthModule):
|
||||
|
||||
async def _async_load(self) -> None:
|
||||
"""Load stored data."""
|
||||
data = await self._user_store.async_load()
|
||||
async with self._init_lock:
|
||||
if self._users is not None:
|
||||
return
|
||||
|
||||
if data is None:
|
||||
data = {STORAGE_USERS: {}}
|
||||
data = await self._user_store.async_load()
|
||||
|
||||
self._users = data.get(STORAGE_USERS, {})
|
||||
if data is None:
|
||||
data = {STORAGE_USERS: {}}
|
||||
|
||||
self._users = data.get(STORAGE_USERS, {})
|
||||
|
||||
async def _async_save(self) -> None:
|
||||
"""Save data."""
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Home Assistant auth provider."""
|
||||
import asyncio
|
||||
import base64
|
||||
from collections import OrderedDict
|
||||
import logging
|
||||
@ -204,15 +205,21 @@ class HassAuthProvider(AuthProvider):
|
||||
|
||||
DEFAULT_TITLE = 'Home Assistant Local'
|
||||
|
||||
data = None
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Initialize an Home Assistant auth provider."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.data = None # type: Optional[Data]
|
||||
self._init_lock = asyncio.Lock()
|
||||
|
||||
async def async_initialize(self) -> None:
|
||||
"""Initialize the auth provider."""
|
||||
if self.data is not None:
|
||||
return
|
||||
async with self._init_lock:
|
||||
if self.data is not None:
|
||||
return
|
||||
|
||||
self.data = Data(self.hass)
|
||||
await self.data.async_load()
|
||||
data = Data(self.hass)
|
||||
await data.async_load()
|
||||
self.data = data
|
||||
|
||||
async def async_login_flow(
|
||||
self, context: Optional[Dict]) -> LoginFlow:
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Test the HMAC-based One Time Password (MFA) auth module."""
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
@ -395,3 +396,26 @@ async def test_not_raise_exception_when_service_not_exist(hass):
|
||||
|
||||
# wait service call finished
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_race_condition_in_data_loading(hass):
|
||||
"""Test race condition in the data loading."""
|
||||
counter = 0
|
||||
|
||||
async def mock_load(_):
|
||||
"""Mock homeassistant.helpers.storage.Store.async_load."""
|
||||
nonlocal counter
|
||||
counter += 1
|
||||
await asyncio.sleep(0)
|
||||
|
||||
notify_auth_module = await auth_mfa_module_from_config(hass, {
|
||||
'type': 'notify'
|
||||
})
|
||||
with patch('homeassistant.helpers.storage.Store.async_load',
|
||||
new=mock_load):
|
||||
task1 = notify_auth_module.async_validate('user', {'code': 'value'})
|
||||
task2 = notify_auth_module.async_validate('user', {'code': 'value'})
|
||||
results = await asyncio.gather(task1, task2, return_exceptions=True)
|
||||
assert counter == 1
|
||||
assert results[0] is False
|
||||
assert results[1] is False
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Test the Time-based One Time Password (MFA) auth module."""
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
@ -128,3 +129,26 @@ async def test_login_flow_validates_mfa(hass):
|
||||
result['flow_id'], {'code': MOCK_CODE})
|
||||
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
assert result['data'].id == 'mock-user'
|
||||
|
||||
|
||||
async def test_race_condition_in_data_loading(hass):
|
||||
"""Test race condition in the data loading."""
|
||||
counter = 0
|
||||
|
||||
async def mock_load(_):
|
||||
"""Mock of homeassistant.helpers.storage.Store.async_load."""
|
||||
nonlocal counter
|
||||
counter += 1
|
||||
await asyncio.sleep(0)
|
||||
|
||||
totp_auth_module = await auth_mfa_module_from_config(hass, {
|
||||
'type': 'totp'
|
||||
})
|
||||
with patch('homeassistant.helpers.storage.Store.async_load',
|
||||
new=mock_load):
|
||||
task1 = totp_auth_module.async_validate('user', {'code': 'value'})
|
||||
task2 = totp_auth_module.async_validate('user', {'code': 'value'})
|
||||
results = await asyncio.gather(task1, task2, return_exceptions=True)
|
||||
assert counter == 1
|
||||
assert results[0] is False
|
||||
assert results[1] is False
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Test the Home Assistant local auth provider."""
|
||||
import asyncio
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
@ -288,3 +289,29 @@ async def test_legacy_get_or_create_credentials(hass, legacy_data):
|
||||
'username': 'hello '
|
||||
})
|
||||
assert credentials1 is not credentials3
|
||||
|
||||
|
||||
async def test_race_condition_in_data_loading(hass):
|
||||
"""Test race condition in the hass_auth.Data loading.
|
||||
|
||||
Ref issue: https://github.com/home-assistant/home-assistant/issues/21569
|
||||
"""
|
||||
counter = 0
|
||||
|
||||
async def mock_load(_):
|
||||
"""Mock of homeassistant.helpers.storage.Store.async_load."""
|
||||
nonlocal counter
|
||||
counter += 1
|
||||
await asyncio.sleep(0)
|
||||
|
||||
provider = hass_auth.HassAuthProvider(hass, auth_store.AuthStore(hass),
|
||||
{'type': 'homeassistant'})
|
||||
with patch('homeassistant.helpers.storage.Store.async_load',
|
||||
new=mock_load):
|
||||
task1 = provider.async_validate_login('user', 'pass')
|
||||
task2 = provider.async_validate_login('user', 'pass')
|
||||
results = await asyncio.gather(task1, task2, return_exceptions=True)
|
||||
assert counter == 1
|
||||
assert isinstance(results[0], hass_auth.InvalidAuth)
|
||||
# results[1] will be a TypeError if race condition occurred
|
||||
assert isinstance(results[1], hass_auth.InvalidAuth)
|
||||
|
Loading…
x
Reference in New Issue
Block a user