Resolve race condition when HA auth provider is loading (#21619)

* Resolve race condition when HA auth provider is loading

* Fix

* Add more tests

* Lint
This commit is contained in:
Jason Hu 2019-03-04 15:55:26 -08:00 committed by Paulus Schoutsen
parent 7a7080055e
commit 4a3b4cf346
6 changed files with 110 additions and 16 deletions

View File

@ -2,6 +2,7 @@
Sending HOTP through notify service Sending HOTP through notify service
""" """
import asyncio
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, Optional, List from typing import Any, Dict, Optional, List
@ -90,6 +91,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
self._include = config.get(CONF_INCLUDE, []) self._include = config.get(CONF_INCLUDE, [])
self._exclude = config.get(CONF_EXCLUDE, []) self._exclude = config.get(CONF_EXCLUDE, [])
self._message_template = config[CONF_MESSAGE] self._message_template = config[CONF_MESSAGE]
self._init_lock = asyncio.Lock()
@property @property
def input_schema(self) -> vol.Schema: def input_schema(self) -> vol.Schema:
@ -98,6 +100,10 @@ class NotifyAuthModule(MultiFactorAuthModule):
async def _async_load(self) -> None: async def _async_load(self) -> None:
"""Load stored data.""" """Load stored data."""
async with self._init_lock:
if self._user_settings is not None:
return
data = await self._user_store.async_load() data = await self._user_store.async_load()
if data is None: if data is None:

View File

@ -1,4 +1,5 @@
"""Time-based One Time Password auth module.""" """Time-based One Time Password auth module."""
import asyncio
import logging import logging
from io import BytesIO from io import BytesIO
from typing import Any, Dict, Optional, Tuple # noqa: F401 from typing import Any, Dict, Optional, Tuple # noqa: F401
@ -68,6 +69,7 @@ class TotpAuthModule(MultiFactorAuthModule):
self._users = None # type: Optional[Dict[str, str]] self._users = None # type: Optional[Dict[str, str]]
self._user_store = hass.helpers.storage.Store( self._user_store = hass.helpers.storage.Store(
STORAGE_VERSION, STORAGE_KEY, private=True) STORAGE_VERSION, STORAGE_KEY, private=True)
self._init_lock = asyncio.Lock()
@property @property
def input_schema(self) -> vol.Schema: def input_schema(self) -> vol.Schema:
@ -76,6 +78,10 @@ class TotpAuthModule(MultiFactorAuthModule):
async def _async_load(self) -> None: async def _async_load(self) -> None:
"""Load stored data.""" """Load stored data."""
async with self._init_lock:
if self._users is not None:
return
data = await self._user_store.async_load() data = await self._user_store.async_load()
if data is None: if data is None:

View File

@ -1,4 +1,5 @@
"""Home Assistant auth provider.""" """Home Assistant auth provider."""
import asyncio
import base64 import base64
from collections import OrderedDict from collections import OrderedDict
import logging import logging
@ -204,15 +205,21 @@ class HassAuthProvider(AuthProvider):
DEFAULT_TITLE = 'Home Assistant Local' DEFAULT_TITLE = 'Home Assistant Local'
data = None def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize an Home Assistant auth provider."""
super().__init__(*args, **kwargs)
self.data = None # type: Optional[Data]
self._init_lock = asyncio.Lock()
async def async_initialize(self) -> None: async def async_initialize(self) -> None:
"""Initialize the auth provider.""" """Initialize the auth provider."""
async with self._init_lock:
if self.data is not None: if self.data is not None:
return return
self.data = Data(self.hass) data = Data(self.hass)
await self.data.async_load() await data.async_load()
self.data = data
async def async_login_flow( async def async_login_flow(
self, context: Optional[Dict]) -> LoginFlow: self, context: Optional[Dict]) -> LoginFlow:

View File

@ -1,4 +1,5 @@
"""Test the HMAC-based One Time Password (MFA) auth module.""" """Test the HMAC-based One Time Password (MFA) auth module."""
import asyncio
from unittest.mock import patch from unittest.mock import patch
from homeassistant import data_entry_flow from homeassistant import data_entry_flow
@ -395,3 +396,26 @@ async def test_not_raise_exception_when_service_not_exist(hass):
# wait service call finished # wait service call finished
await hass.async_block_till_done() await hass.async_block_till_done()
async def test_race_condition_in_data_loading(hass):
"""Test race condition in the data loading."""
counter = 0
async def mock_load(_):
"""Mock homeassistant.helpers.storage.Store.async_load."""
nonlocal counter
counter += 1
await asyncio.sleep(0)
notify_auth_module = await auth_mfa_module_from_config(hass, {
'type': 'notify'
})
with patch('homeassistant.helpers.storage.Store.async_load',
new=mock_load):
task1 = notify_auth_module.async_validate('user', {'code': 'value'})
task2 = notify_auth_module.async_validate('user', {'code': 'value'})
results = await asyncio.gather(task1, task2, return_exceptions=True)
assert counter == 1
assert results[0] is False
assert results[1] is False

View File

@ -1,4 +1,5 @@
"""Test the Time-based One Time Password (MFA) auth module.""" """Test the Time-based One Time Password (MFA) auth module."""
import asyncio
from unittest.mock import patch from unittest.mock import patch
from homeassistant import data_entry_flow from homeassistant import data_entry_flow
@ -128,3 +129,26 @@ async def test_login_flow_validates_mfa(hass):
result['flow_id'], {'code': MOCK_CODE}) result['flow_id'], {'code': MOCK_CODE})
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result['data'].id == 'mock-user' assert result['data'].id == 'mock-user'
async def test_race_condition_in_data_loading(hass):
"""Test race condition in the data loading."""
counter = 0
async def mock_load(_):
"""Mock of homeassistant.helpers.storage.Store.async_load."""
nonlocal counter
counter += 1
await asyncio.sleep(0)
totp_auth_module = await auth_mfa_module_from_config(hass, {
'type': 'totp'
})
with patch('homeassistant.helpers.storage.Store.async_load',
new=mock_load):
task1 = totp_auth_module.async_validate('user', {'code': 'value'})
task2 = totp_auth_module.async_validate('user', {'code': 'value'})
results = await asyncio.gather(task1, task2, return_exceptions=True)
assert counter == 1
assert results[0] is False
assert results[1] is False

View File

@ -1,4 +1,5 @@
"""Test the Home Assistant local auth provider.""" """Test the Home Assistant local auth provider."""
import asyncio
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
@ -288,3 +289,29 @@ async def test_legacy_get_or_create_credentials(hass, legacy_data):
'username': 'hello ' 'username': 'hello '
}) })
assert credentials1 is not credentials3 assert credentials1 is not credentials3
async def test_race_condition_in_data_loading(hass):
"""Test race condition in the hass_auth.Data loading.
Ref issue: https://github.com/home-assistant/home-assistant/issues/21569
"""
counter = 0
async def mock_load(_):
"""Mock of homeassistant.helpers.storage.Store.async_load."""
nonlocal counter
counter += 1
await asyncio.sleep(0)
provider = hass_auth.HassAuthProvider(hass, auth_store.AuthStore(hass),
{'type': 'homeassistant'})
with patch('homeassistant.helpers.storage.Store.async_load',
new=mock_load):
task1 = provider.async_validate_login('user', 'pass')
task2 = provider.async_validate_login('user', 'pass')
results = await asyncio.gather(task1, task2, return_exceptions=True)
assert counter == 1
assert isinstance(results[0], hass_auth.InvalidAuth)
# results[1] will be a TypeError if race condition occurred
assert isinstance(results[1], hass_auth.InvalidAuth)