diff --git a/homeassistant/auth/auth_store.py b/homeassistant/auth/auth_store.py index c6078e03f63..90c1b337f16 100644 --- a/homeassistant/auth/auth_store.py +++ b/homeassistant/auth/auth_store.py @@ -38,6 +38,7 @@ class AuthStore: self._perm_lookup = None # type: Optional[PermissionLookup] self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY, private=True) + self._lock = asyncio.Lock() async def async_get_groups(self) -> List[models.Group]: """Retrieve all users.""" @@ -271,6 +272,13 @@ class AuthStore: self._async_schedule_save() async def _async_load(self) -> None: + """Load the users.""" + async with self._lock: + if self._users is not None: + return + await self._async_load_task() + + async def _async_load_task(self) -> None: """Load the users.""" [ent_reg, data] = await asyncio.gather( self.hass.helpers.entity_registry.async_get_registry(), diff --git a/tests/auth/test_auth_store.py b/tests/auth/test_auth_store.py index 7e9df869a04..08530da324b 100644 --- a/tests/auth/test_auth_store.py +++ b/tests/auth/test_auth_store.py @@ -1,4 +1,8 @@ """Tests for the auth store.""" +import asyncio + +import asynctest + from homeassistant.auth import auth_store @@ -218,3 +222,21 @@ async def test_system_groups_store_id_and_name(hass, hass_storage): 'name': auth_store.GROUP_NAME_READ_ONLY, }, ] + + +async def test_loading_race_condition(hass): + """Test only one storage load called when concurrent loading occurred .""" + store = auth_store.AuthStore(hass) + with asynctest.patch( + 'homeassistant.helpers.entity_registry.async_get_registry', + ) as mock_registry, asynctest.patch( + 'homeassistant.helpers.storage.Store.async_load', + ) as mock_load: + results = await asyncio.gather( + store.async_get_users(), + store.async_get_users(), + ) + + mock_registry.assert_called_once_with(hass) + mock_load.assert_called_once_with() + assert results[0] == results[1] diff --git a/tests/helpers/test_area_registry.py b/tests/helpers/test_area_registry.py index 9f2801fe334..284cb2b3dbe 100644 --- a/tests/helpers/test_area_registry.py +++ b/tests/helpers/test_area_registry.py @@ -1,4 +1,7 @@ """Tests for the Area Registry.""" +import asyncio + +import asynctest import pytest from homeassistant.helpers import area_registry @@ -125,3 +128,17 @@ async def test_loading_area_from_storage(hass, hass_storage): registry = await area_registry.async_get_registry(hass) assert len(registry.areas) == 1 + + +async def test_loading_race_condition(hass): + """Test only one storage load called when concurrent loading occurred .""" + with asynctest.patch( + 'homeassistant.helpers.area_registry.AreaRegistry.async_load', + ) as mock_load: + results = await asyncio.gather( + area_registry.async_get_registry(hass), + area_registry.async_get_registry(hass), + ) + + mock_load.assert_called_once_with() + assert results[0] == results[1] diff --git a/tests/helpers/test_device_registry.py b/tests/helpers/test_device_registry.py index caf1dafdf8f..adfa05a021b 100644 --- a/tests/helpers/test_device_registry.py +++ b/tests/helpers/test_device_registry.py @@ -1,6 +1,8 @@ """Tests for the Device Registry.""" +import asyncio from unittest.mock import patch +import asynctest import pytest from homeassistant.helpers import device_registry @@ -370,3 +372,17 @@ async def test_update(registry): assert updated_entry != entry assert updated_entry.area_id == '12345A' assert updated_entry.name_by_user == 'Test Friendly Name' + + +async def test_loading_race_condition(hass): + """Test only one storage load called when concurrent loading occurred .""" + with asynctest.patch( + 'homeassistant.helpers.device_registry.DeviceRegistry.async_load', + ) as mock_load: + results = await asyncio.gather( + device_registry.async_get_registry(hass), + device_registry.async_get_registry(hass), + ) + + mock_load.assert_called_once_with() + assert results[0] == results[1] diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index b1c13a36c6d..3fb79f693bd 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -2,6 +2,7 @@ import asyncio from unittest.mock import patch +import asynctest import pytest from homeassistant.core import valid_entity_id @@ -19,7 +20,6 @@ def registry(hass): return mock_registry(hass) -@asyncio.coroutine def test_get_or_create_returns_same_entry(registry): """Make sure we do not duplicate entries.""" entry = registry.async_get_or_create('light', 'hue', '1234') @@ -30,7 +30,6 @@ def test_get_or_create_returns_same_entry(registry): assert entry.entity_id == 'light.hue_1234' -@asyncio.coroutine def test_get_or_create_suggested_object_id(registry): """Test that suggested_object_id works.""" entry = registry.async_get_or_create( @@ -39,7 +38,6 @@ def test_get_or_create_suggested_object_id(registry): assert entry.entity_id == 'light.beer' -@asyncio.coroutine def test_get_or_create_suggested_object_id_conflict_register(registry): """Test that we don't generate an entity id that is already registered.""" entry = registry.async_get_or_create( @@ -51,7 +49,6 @@ def test_get_or_create_suggested_object_id_conflict_register(registry): assert entry2.entity_id == 'light.beer_2' -@asyncio.coroutine def test_get_or_create_suggested_object_id_conflict_existing(hass, registry): """Test that we don't generate an entity id that currently exists.""" hass.states.async_set('light.hue_1234', 'on') @@ -59,7 +56,6 @@ def test_get_or_create_suggested_object_id_conflict_existing(hass, registry): assert entry.entity_id == 'light.hue_1234_2' -@asyncio.coroutine def test_create_triggers_save(hass, registry): """Test that registering entry triggers a save.""" with patch.object(registry, 'async_schedule_save') as mock_schedule_save: @@ -91,7 +87,6 @@ async def test_loading_saving_data(hass, registry): assert orig_entry2 == new_entry2 -@asyncio.coroutine def test_generate_entity_considers_registered_entities(registry): """Test that we don't create entity id that are already registered.""" entry = registry.async_get_or_create('light', 'hue', '1234') @@ -100,7 +95,6 @@ def test_generate_entity_considers_registered_entities(registry): 'light.hue_1234_2' -@asyncio.coroutine def test_generate_entity_considers_existing_entities(hass, registry): """Test that we don't create entity id that currently exists.""" hass.states.async_set('light.kitchen', 'on') @@ -108,7 +102,6 @@ def test_generate_entity_considers_existing_entities(hass, registry): 'light.kitchen_2' -@asyncio.coroutine def test_is_registered(registry): """Test that is_registered works.""" entry = registry.async_get_or_create('light', 'hue', '1234') @@ -166,7 +159,6 @@ async def test_loading_extra_values(hass, hass_storage): assert entry_disabled_user.disabled_by == entity_registry.DISABLED_USER -@asyncio.coroutine def test_async_get_entity_id(registry): """Test that entity_id is returned.""" entry = registry.async_get_or_create('light', 'hue', '1234') @@ -176,7 +168,7 @@ def test_async_get_entity_id(registry): assert registry.async_get_entity_id('light', 'hue', '123') is None -async def test_updating_config_entry_id(registry): +def test_updating_config_entry_id(registry): """Test that we update config entry id in registry.""" entry = registry.async_get_or_create( 'light', 'hue', '5678', config_entry_id='mock-id-1') @@ -186,7 +178,7 @@ async def test_updating_config_entry_id(registry): assert entry2.config_entry_id == 'mock-id-2' -async def test_removing_config_entry_id(registry): +def test_removing_config_entry_id(registry): """Test that we update config entry id in registry.""" entry = registry.async_get_or_create( 'light', 'hue', '5678', config_entry_id='mock-id-1') @@ -265,3 +257,17 @@ async def test_loading_invalid_entity_id(hass, hass_storage): 'test', 'super_platform', 'id-invalid-start') assert valid_entity_id(entity_invalid_start.entity_id) + + +async def test_loading_race_condition(hass): + """Test only one storage load called when concurrent loading occurred .""" + with asynctest.patch( + 'homeassistant.helpers.entity_registry.EntityRegistry.async_load', + ) as mock_load: + results = await asyncio.gather( + entity_registry.async_get_registry(hass), + entity_registry.async_get_registry(hass), + ) + + mock_load.assert_called_once_with() + assert results[0] == results[1]