From 4a3b4cf3460e539bd213d7b4593c51e4fffb67a1 Mon Sep 17 00:00:00 2001 From: Jason Hu Date: Mon, 4 Mar 2019 15:55:26 -0800 Subject: [PATCH] Resolve race condition when HA auth provider is loading (#21619) * Resolve race condition when HA auth provider is loading * Fix * Add more tests * Lint --- homeassistant/auth/mfa_modules/notify.py | 20 +++++++++----- homeassistant/auth/mfa_modules/totp.py | 14 +++++++--- homeassistant/auth/providers/homeassistant.py | 17 ++++++++---- tests/auth/mfa_modules/test_notify.py | 24 +++++++++++++++++ tests/auth/mfa_modules/test_totp.py | 24 +++++++++++++++++ tests/auth/providers/test_homeassistant.py | 27 +++++++++++++++++++ 6 files changed, 110 insertions(+), 16 deletions(-) diff --git a/homeassistant/auth/mfa_modules/notify.py b/homeassistant/auth/mfa_modules/notify.py index 3c26f8b4bde..310abff9484 100644 --- a/homeassistant/auth/mfa_modules/notify.py +++ b/homeassistant/auth/mfa_modules/notify.py @@ -2,6 +2,7 @@ Sending HOTP through notify service """ +import asyncio import logging from collections import OrderedDict from typing import Any, Dict, Optional, List @@ -90,6 +91,7 @@ class NotifyAuthModule(MultiFactorAuthModule): self._include = config.get(CONF_INCLUDE, []) self._exclude = config.get(CONF_EXCLUDE, []) self._message_template = config[CONF_MESSAGE] + self._init_lock = asyncio.Lock() @property def input_schema(self) -> vol.Schema: @@ -98,15 +100,19 @@ class NotifyAuthModule(MultiFactorAuthModule): async def _async_load(self) -> None: """Load stored data.""" - data = await self._user_store.async_load() + async with self._init_lock: + if self._user_settings is not None: + return - if data is None: - data = {STORAGE_USERS: {}} + data = await self._user_store.async_load() - self._user_settings = { - user_id: NotifySetting(**setting) - for user_id, setting in data.get(STORAGE_USERS, {}).items() - } + if data is None: + data = {STORAGE_USERS: {}} + + self._user_settings = { + user_id: NotifySetting(**setting) + for user_id, setting in data.get(STORAGE_USERS, {}).items() + } async def _async_save(self) -> None: """Save data.""" diff --git a/homeassistant/auth/mfa_modules/totp.py b/homeassistant/auth/mfa_modules/totp.py index 68f4e1d0596..dc51152f565 100644 --- a/homeassistant/auth/mfa_modules/totp.py +++ b/homeassistant/auth/mfa_modules/totp.py @@ -1,4 +1,5 @@ """Time-based One Time Password auth module.""" +import asyncio import logging from io import BytesIO from typing import Any, Dict, Optional, Tuple # noqa: F401 @@ -68,6 +69,7 @@ class TotpAuthModule(MultiFactorAuthModule): self._users = None # type: Optional[Dict[str, str]] self._user_store = hass.helpers.storage.Store( STORAGE_VERSION, STORAGE_KEY, private=True) + self._init_lock = asyncio.Lock() @property def input_schema(self) -> vol.Schema: @@ -76,12 +78,16 @@ class TotpAuthModule(MultiFactorAuthModule): async def _async_load(self) -> None: """Load stored data.""" - data = await self._user_store.async_load() + async with self._init_lock: + if self._users is not None: + return - if data is None: - data = {STORAGE_USERS: {}} + data = await self._user_store.async_load() - self._users = data.get(STORAGE_USERS, {}) + if data is None: + data = {STORAGE_USERS: {}} + + self._users = data.get(STORAGE_USERS, {}) async def _async_save(self) -> None: """Save data.""" diff --git a/homeassistant/auth/providers/homeassistant.py b/homeassistant/auth/providers/homeassistant.py index b22f93f11f1..2187d272800 100644 --- a/homeassistant/auth/providers/homeassistant.py +++ b/homeassistant/auth/providers/homeassistant.py @@ -1,4 +1,5 @@ """Home Assistant auth provider.""" +import asyncio import base64 from collections import OrderedDict import logging @@ -204,15 +205,21 @@ class HassAuthProvider(AuthProvider): DEFAULT_TITLE = 'Home Assistant Local' - data = None + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize an Home Assistant auth provider.""" + super().__init__(*args, **kwargs) + self.data = None # type: Optional[Data] + self._init_lock = asyncio.Lock() async def async_initialize(self) -> None: """Initialize the auth provider.""" - if self.data is not None: - return + async with self._init_lock: + if self.data is not None: + return - self.data = Data(self.hass) - await self.data.async_load() + data = Data(self.hass) + await data.async_load() + self.data = data async def async_login_flow( self, context: Optional[Dict]) -> LoginFlow: diff --git a/tests/auth/mfa_modules/test_notify.py b/tests/auth/mfa_modules/test_notify.py index 748b5507824..c0680024dae 100644 --- a/tests/auth/mfa_modules/test_notify.py +++ b/tests/auth/mfa_modules/test_notify.py @@ -1,4 +1,5 @@ """Test the HMAC-based One Time Password (MFA) auth module.""" +import asyncio from unittest.mock import patch from homeassistant import data_entry_flow @@ -395,3 +396,26 @@ async def test_not_raise_exception_when_service_not_exist(hass): # wait service call finished await hass.async_block_till_done() + + +async def test_race_condition_in_data_loading(hass): + """Test race condition in the data loading.""" + counter = 0 + + async def mock_load(_): + """Mock homeassistant.helpers.storage.Store.async_load.""" + nonlocal counter + counter += 1 + await asyncio.sleep(0) + + notify_auth_module = await auth_mfa_module_from_config(hass, { + 'type': 'notify' + }) + with patch('homeassistant.helpers.storage.Store.async_load', + new=mock_load): + task1 = notify_auth_module.async_validate('user', {'code': 'value'}) + task2 = notify_auth_module.async_validate('user', {'code': 'value'}) + results = await asyncio.gather(task1, task2, return_exceptions=True) + assert counter == 1 + assert results[0] is False + assert results[1] is False diff --git a/tests/auth/mfa_modules/test_totp.py b/tests/auth/mfa_modules/test_totp.py index d400fe80672..35ab21ae6de 100644 --- a/tests/auth/mfa_modules/test_totp.py +++ b/tests/auth/mfa_modules/test_totp.py @@ -1,4 +1,5 @@ """Test the Time-based One Time Password (MFA) auth module.""" +import asyncio from unittest.mock import patch from homeassistant import data_entry_flow @@ -128,3 +129,26 @@ async def test_login_flow_validates_mfa(hass): result['flow_id'], {'code': MOCK_CODE}) assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY assert result['data'].id == 'mock-user' + + +async def test_race_condition_in_data_loading(hass): + """Test race condition in the data loading.""" + counter = 0 + + async def mock_load(_): + """Mock of homeassistant.helpers.storage.Store.async_load.""" + nonlocal counter + counter += 1 + await asyncio.sleep(0) + + totp_auth_module = await auth_mfa_module_from_config(hass, { + 'type': 'totp' + }) + with patch('homeassistant.helpers.storage.Store.async_load', + new=mock_load): + task1 = totp_auth_module.async_validate('user', {'code': 'value'}) + task2 = totp_auth_module.async_validate('user', {'code': 'value'}) + results = await asyncio.gather(task1, task2, return_exceptions=True) + assert counter == 1 + assert results[0] is False + assert results[1] is False diff --git a/tests/auth/providers/test_homeassistant.py b/tests/auth/providers/test_homeassistant.py index ffc4d67f21d..c466a1fa42b 100644 --- a/tests/auth/providers/test_homeassistant.py +++ b/tests/auth/providers/test_homeassistant.py @@ -1,4 +1,5 @@ """Test the Home Assistant local auth provider.""" +import asyncio from unittest.mock import Mock, patch import pytest @@ -288,3 +289,29 @@ async def test_legacy_get_or_create_credentials(hass, legacy_data): 'username': 'hello ' }) assert credentials1 is not credentials3 + + +async def test_race_condition_in_data_loading(hass): + """Test race condition in the hass_auth.Data loading. + + Ref issue: https://github.com/home-assistant/home-assistant/issues/21569 + """ + counter = 0 + + async def mock_load(_): + """Mock of homeassistant.helpers.storage.Store.async_load.""" + nonlocal counter + counter += 1 + await asyncio.sleep(0) + + provider = hass_auth.HassAuthProvider(hass, auth_store.AuthStore(hass), + {'type': 'homeassistant'}) + with patch('homeassistant.helpers.storage.Store.async_load', + new=mock_load): + task1 = provider.async_validate_login('user', 'pass') + task2 = provider.async_validate_login('user', 'pass') + results = await asyncio.gather(task1, task2, return_exceptions=True) + assert counter == 1 + assert isinstance(results[0], hass_auth.InvalidAuth) + # results[1] will be a TypeError if race condition occurred + assert isinstance(results[1], hass_auth.InvalidAuth)