Refactor ConfigStore to avoid needing to pass config_dir (#114827)

Co-authored-by: Erik <erik@montnemery.com>
This commit is contained in:
J. Nick Koston 2024-04-04 09:30:10 -10:00 committed by GitHub
parent 56ef9500f7
commit cceea6dac2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 19 additions and 13 deletions

View File

@ -405,6 +405,7 @@ class HomeAssistant:
self.services = ServiceRegistry(self) self.services = ServiceRegistry(self)
self.states = StateMachine(self.bus, self.loop) self.states = StateMachine(self.bus, self.loop)
self.config = Config(self, config_dir) self.config = Config(self, config_dir)
self.config.async_initialize()
self.components = loader.Components(self) self.components = loader.Components(self)
self.helpers = loader.Helpers(self) self.helpers = loader.Helpers(self)
self.state: CoreState = CoreState.not_running self.state: CoreState = CoreState.not_running
@ -2600,12 +2601,12 @@ class ServiceRegistry:
class Config: class Config:
"""Configuration settings for Home Assistant.""" """Configuration settings for Home Assistant."""
_store: Config._ConfigStore
def __init__(self, hass: HomeAssistant, config_dir: str) -> None: def __init__(self, hass: HomeAssistant, config_dir: str) -> None:
"""Initialize a new config object.""" """Initialize a new config object."""
self.hass = hass self.hass = hass
self._store = self._ConfigStore(self.hass, config_dir)
self.latitude: float = 0 self.latitude: float = 0
self.longitude: float = 0 self.longitude: float = 0
@ -2656,6 +2657,13 @@ class Config:
# If Home Assistant is running in safe mode # If Home Assistant is running in safe mode
self.safe_mode: bool = False self.safe_mode: bool = False
def async_initialize(self) -> None:
"""Finish initializing a config object.
This must be called before the config object is used.
"""
self._store = self._ConfigStore(self.hass)
def distance(self, lat: float, lon: float) -> float | None: def distance(self, lat: float, lon: float) -> float | None:
"""Calculate distance from Home Assistant. """Calculate distance from Home Assistant.
@ -2862,7 +2870,6 @@ class Config:
"country": self.country, "country": self.country,
"language": self.language, "language": self.language,
} }
await self._store.async_save(data) await self._store.async_save(data)
# Circular dependency prevents us from generating the class at top level # Circular dependency prevents us from generating the class at top level
@ -2872,7 +2879,7 @@ class Config:
class _ConfigStore(Store[dict[str, Any]]): class _ConfigStore(Store[dict[str, Any]]):
"""Class to help storing Config data.""" """Class to help storing Config data."""
def __init__(self, hass: HomeAssistant, config_dir: str) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize storage class.""" """Initialize storage class."""
super().__init__( super().__init__(
hass, hass,
@ -2881,7 +2888,6 @@ class Config:
private=True, private=True,
atomic_writes=True, atomic_writes=True,
minor_version=CORE_STORAGE_MINOR_VERSION, minor_version=CORE_STORAGE_MINOR_VERSION,
config_dir=config_dir,
) )
self._original_unit_system: str | None = None # from old store 1.1 self._original_unit_system: str | None = None # from old store 1.1

View File

@ -90,9 +90,7 @@ async def async_migrator(
return config return config
def get_internal_store_manager( def get_internal_store_manager(hass: HomeAssistant) -> _StoreManager:
hass: HomeAssistant, config_dir: str | None = None
) -> _StoreManager:
"""Get the store manager. """Get the store manager.
This function is not part of the API and should only be This function is not part of the API and should only be
@ -100,7 +98,7 @@ def get_internal_store_manager(
guaranteed to be stable. guaranteed to be stable.
""" """
if STORAGE_MANAGER not in hass.data: if STORAGE_MANAGER not in hass.data:
manager = _StoreManager(hass, config_dir or hass.config.config_dir) manager = _StoreManager(hass)
hass.data[STORAGE_MANAGER] = manager hass.data[STORAGE_MANAGER] = manager
return hass.data[STORAGE_MANAGER] return hass.data[STORAGE_MANAGER]
@ -111,13 +109,13 @@ class _StoreManager:
The store manager is used to cache and manage storage files. The store manager is used to cache and manage storage files.
""" """
def __init__(self, hass: HomeAssistant, config_dir: str) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize storage manager class.""" """Initialize storage manager class."""
self._hass = hass self._hass = hass
self._invalidated: set[str] = set() self._invalidated: set[str] = set()
self._files: set[str] | None = None self._files: set[str] | None = None
self._data_preload: dict[str, json_util.JsonValueType] = {} self._data_preload: dict[str, json_util.JsonValueType] = {}
self._storage_path: Path = Path(config_dir).joinpath(STORAGE_DIR) self._storage_path: Path = Path(hass.config.config_dir).joinpath(STORAGE_DIR)
self._cancel_cleanup: asyncio.TimerHandle | None = None self._cancel_cleanup: asyncio.TimerHandle | None = None
async def async_initialize(self) -> None: async def async_initialize(self) -> None:
@ -246,7 +244,6 @@ class Store(Generic[_T]):
encoder: type[JSONEncoder] | None = None, encoder: type[JSONEncoder] | None = None,
minor_version: int = 1, minor_version: int = 1,
read_only: bool = False, read_only: bool = False,
config_dir: str | None = None,
) -> None: ) -> None:
"""Initialize storage class.""" """Initialize storage class."""
self.version = version self.version = version
@ -263,7 +260,7 @@ class Store(Generic[_T]):
self._atomic_writes = atomic_writes self._atomic_writes = atomic_writes
self._read_only = read_only self._read_only = read_only
self._next_write_time = 0.0 self._next_write_time = 0.0
self._manager = get_internal_store_manager(hass, config_dir) self._manager = get_internal_store_manager(hass)
@cached_property @cached_property
def path(self): def path(self):

View File

@ -2288,6 +2288,7 @@ async def test_additional_data_in_core_config(
) -> None: ) -> None:
"""Test that we can handle additional data in core configuration.""" """Test that we can handle additional data in core configuration."""
config = ha.Config(hass, "/test/ha-config") config = ha.Config(hass, "/test/ha-config")
config.async_initialize()
hass_storage[ha.CORE_STORAGE_KEY] = { hass_storage[ha.CORE_STORAGE_KEY] = {
"version": 1, "version": 1,
"data": {"location_name": "Test Name", "additional_valid_key": "value"}, "data": {"location_name": "Test Name", "additional_valid_key": "value"},
@ -2301,6 +2302,7 @@ async def test_incorrect_internal_external_url(
) -> None: ) -> None:
"""Test that we warn when detecting invalid internal/external url.""" """Test that we warn when detecting invalid internal/external url."""
config = ha.Config(hass, "/test/ha-config") config = ha.Config(hass, "/test/ha-config")
config.async_initialize()
hass_storage[ha.CORE_STORAGE_KEY] = { hass_storage[ha.CORE_STORAGE_KEY] = {
"version": 1, "version": 1,
@ -2314,6 +2316,7 @@ async def test_incorrect_internal_external_url(
assert "Invalid internal_url set" not in caplog.text assert "Invalid internal_url set" not in caplog.text
config = ha.Config(hass, "/test/ha-config") config = ha.Config(hass, "/test/ha-config")
config.async_initialize()
hass_storage[ha.CORE_STORAGE_KEY] = { hass_storage[ha.CORE_STORAGE_KEY] = {
"version": 1, "version": 1,