Resolve race condition when HA auth provider is loading (#21619)

* Resolve race condition when HA auth provider is loading

* Fix

* Add more tests

* Lint
This commit is contained in:
Jason Hu 2019-03-04 15:55:26 -08:00 committed by Paulus Schoutsen
parent 7a7080055e
commit 4a3b4cf346
6 changed files with 110 additions and 16 deletions

View File

@ -2,6 +2,7 @@
Sending HOTP through notify service
"""
import asyncio
import logging
from collections import OrderedDict
from typing import Any, Dict, Optional, List
@ -90,6 +91,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
self._include = config.get(CONF_INCLUDE, [])
self._exclude = config.get(CONF_EXCLUDE, [])
self._message_template = config[CONF_MESSAGE]
self._init_lock = asyncio.Lock()
@property
def input_schema(self) -> vol.Schema:
@ -98,15 +100,19 @@ class NotifyAuthModule(MultiFactorAuthModule):
async def _async_load(self) -> None:
"""Load stored data."""
data = await self._user_store.async_load()
async with self._init_lock:
if self._user_settings is not None:
return
if data is None:
data = {STORAGE_USERS: {}}
data = await self._user_store.async_load()
self._user_settings = {
user_id: NotifySetting(**setting)
for user_id, setting in data.get(STORAGE_USERS, {}).items()
}
if data is None:
data = {STORAGE_USERS: {}}
self._user_settings = {
user_id: NotifySetting(**setting)
for user_id, setting in data.get(STORAGE_USERS, {}).items()
}
async def _async_save(self) -> None:
"""Save data."""

View File

@ -1,4 +1,5 @@
"""Time-based One Time Password auth module."""
import asyncio
import logging
from io import BytesIO
from typing import Any, Dict, Optional, Tuple # noqa: F401
@ -68,6 +69,7 @@ class TotpAuthModule(MultiFactorAuthModule):
self._users = None # type: Optional[Dict[str, str]]
self._user_store = hass.helpers.storage.Store(
STORAGE_VERSION, STORAGE_KEY, private=True)
self._init_lock = asyncio.Lock()
@property
def input_schema(self) -> vol.Schema:
@ -76,12 +78,16 @@ class TotpAuthModule(MultiFactorAuthModule):
async def _async_load(self) -> None:
"""Load stored data."""
data = await self._user_store.async_load()
async with self._init_lock:
if self._users is not None:
return
if data is None:
data = {STORAGE_USERS: {}}
data = await self._user_store.async_load()
self._users = data.get(STORAGE_USERS, {})
if data is None:
data = {STORAGE_USERS: {}}
self._users = data.get(STORAGE_USERS, {})
async def _async_save(self) -> None:
"""Save data."""

View File

@ -1,4 +1,5 @@
"""Home Assistant auth provider."""
import asyncio
import base64
from collections import OrderedDict
import logging
@ -204,15 +205,21 @@ class HassAuthProvider(AuthProvider):
DEFAULT_TITLE = 'Home Assistant Local'
data = None
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize an Home Assistant auth provider."""
super().__init__(*args, **kwargs)
self.data = None # type: Optional[Data]
self._init_lock = asyncio.Lock()
async def async_initialize(self) -> None:
"""Initialize the auth provider."""
if self.data is not None:
return
async with self._init_lock:
if self.data is not None:
return
self.data = Data(self.hass)
await self.data.async_load()
data = Data(self.hass)
await data.async_load()
self.data = data
async def async_login_flow(
self, context: Optional[Dict]) -> LoginFlow:

View File

@ -1,4 +1,5 @@
"""Test the HMAC-based One Time Password (MFA) auth module."""
import asyncio
from unittest.mock import patch
from homeassistant import data_entry_flow
@ -395,3 +396,26 @@ async def test_not_raise_exception_when_service_not_exist(hass):
# wait service call finished
await hass.async_block_till_done()
async def test_race_condition_in_data_loading(hass):
"""Test race condition in the data loading."""
counter = 0
async def mock_load(_):
"""Mock homeassistant.helpers.storage.Store.async_load."""
nonlocal counter
counter += 1
await asyncio.sleep(0)
notify_auth_module = await auth_mfa_module_from_config(hass, {
'type': 'notify'
})
with patch('homeassistant.helpers.storage.Store.async_load',
new=mock_load):
task1 = notify_auth_module.async_validate('user', {'code': 'value'})
task2 = notify_auth_module.async_validate('user', {'code': 'value'})
results = await asyncio.gather(task1, task2, return_exceptions=True)
assert counter == 1
assert results[0] is False
assert results[1] is False

View File

@ -1,4 +1,5 @@
"""Test the Time-based One Time Password (MFA) auth module."""
import asyncio
from unittest.mock import patch
from homeassistant import data_entry_flow
@ -128,3 +129,26 @@ async def test_login_flow_validates_mfa(hass):
result['flow_id'], {'code': MOCK_CODE})
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result['data'].id == 'mock-user'
async def test_race_condition_in_data_loading(hass):
"""Test race condition in the data loading."""
counter = 0
async def mock_load(_):
"""Mock of homeassistant.helpers.storage.Store.async_load."""
nonlocal counter
counter += 1
await asyncio.sleep(0)
totp_auth_module = await auth_mfa_module_from_config(hass, {
'type': 'totp'
})
with patch('homeassistant.helpers.storage.Store.async_load',
new=mock_load):
task1 = totp_auth_module.async_validate('user', {'code': 'value'})
task2 = totp_auth_module.async_validate('user', {'code': 'value'})
results = await asyncio.gather(task1, task2, return_exceptions=True)
assert counter == 1
assert results[0] is False
assert results[1] is False

View File

@ -1,4 +1,5 @@
"""Test the Home Assistant local auth provider."""
import asyncio
from unittest.mock import Mock, patch
import pytest
@ -288,3 +289,29 @@ async def test_legacy_get_or_create_credentials(hass, legacy_data):
'username': 'hello '
})
assert credentials1 is not credentials3
async def test_race_condition_in_data_loading(hass):
"""Test race condition in the hass_auth.Data loading.
Ref issue: https://github.com/home-assistant/home-assistant/issues/21569
"""
counter = 0
async def mock_load(_):
"""Mock of homeassistant.helpers.storage.Store.async_load."""
nonlocal counter
counter += 1
await asyncio.sleep(0)
provider = hass_auth.HassAuthProvider(hass, auth_store.AuthStore(hass),
{'type': 'homeassistant'})
with patch('homeassistant.helpers.storage.Store.async_load',
new=mock_load):
task1 = provider.async_validate_login('user', 'pass')
task2 = provider.async_validate_login('user', 'pass')
results = await asyncio.gather(task1, task2, return_exceptions=True)
assert counter == 1
assert isinstance(results[0], hass_auth.InvalidAuth)
# results[1] will be a TypeError if race condition occurred
assert isinstance(results[1], hass_auth.InvalidAuth)