Use runtime_data in geo_json_events (#138366)

* Use runtime_data in geo_json_events

* Update __init__.py
This commit is contained in:
epenet 2025-02-12 12:42:22 +01:00 committed by GitHub
parent f1471f143c
commit 2bb582f8e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 21 additions and 28 deletions

View File

@ -4,25 +4,27 @@ from __future__ import annotations
import logging import logging
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
from .const import DOMAIN, PLATFORMS from .const import PLATFORMS
from .manager import GeoJsonFeedEntityManager from .manager import GeoJsonConfigEntry, GeoJsonFeedEntityManager
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: async def async_setup_entry(
hass: HomeAssistant, config_entry: GeoJsonConfigEntry
) -> bool:
"""Set up the GeoJSON events component as config entry.""" """Set up the GeoJSON events component as config entry."""
feeds = hass.data.setdefault(DOMAIN, {})
# Create feed entity manager for all platforms. # Create feed entity manager for all platforms.
manager = GeoJsonFeedEntityManager(hass, config_entry) manager = GeoJsonFeedEntityManager(hass, config_entry)
feeds[config_entry.entry_id] = manager
_LOGGER.debug("Feed entity manager added for %s", config_entry.entry_id) _LOGGER.debug("Feed entity manager added for %s", config_entry.entry_id)
await remove_orphaned_entities(hass, config_entry.entry_id) await remove_orphaned_entities(hass, config_entry.entry_id)
config_entry.runtime_data = manager
config_entry.async_on_unload(manager.async_stop)
await hass.config_entries.async_forward_entry_setups(config_entry, PLATFORMS) await hass.config_entries.async_forward_entry_setups(config_entry, PLATFORMS)
await manager.async_init() await manager.async_init()
return True return True
@ -46,10 +48,6 @@ async def remove_orphaned_entities(hass: HomeAssistant, entry_id: str) -> None:
entity_registry.async_remove(entry.entity_id) entity_registry.async_remove(entry.entity_id)
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: GeoJsonConfigEntry) -> bool:
"""Unload the GeoJSON events config entry.""" """Unload the GeoJSON events config entry."""
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
if unload_ok:
manager: GeoJsonFeedEntityManager = hass.data[DOMAIN].pop(entry.entry_id)
await manager.async_stop()
return unload_ok

View File

@ -9,31 +9,24 @@ from typing import Any
from aio_geojson_generic_client.feed_entry import GenericFeedEntry from aio_geojson_generic_client.feed_entry import GenericFeedEntry
from homeassistant.components.geo_location import GeolocationEvent from homeassistant.components.geo_location import GeolocationEvent
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import UnitOfLength from homeassistant.const import UnitOfLength
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from . import GeoJsonFeedEntityManager from .const import ATTR_EXTERNAL_ID, SIGNAL_DELETE_ENTITY, SIGNAL_UPDATE_ENTITY, SOURCE
from .const import ( from .manager import GeoJsonConfigEntry, GeoJsonFeedEntityManager
ATTR_EXTERNAL_ID,
DOMAIN,
SIGNAL_DELETE_ENTITY,
SIGNAL_UPDATE_ENTITY,
SOURCE,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
async def async_setup_entry( async def async_setup_entry(
hass: HomeAssistant, hass: HomeAssistant,
entry: ConfigEntry, entry: GeoJsonConfigEntry,
async_add_entities: AddConfigEntryEntitiesCallback, async_add_entities: AddConfigEntryEntitiesCallback,
) -> None: ) -> None:
"""Set up the GeoJSON Events platform.""" """Set up the GeoJSON Events platform."""
manager: GeoJsonFeedEntityManager = hass.data[DOMAIN][entry.entry_id] manager = entry.runtime_data
@callback @callback
def async_add_geolocation( def async_add_geolocation(

View File

@ -25,6 +25,8 @@ from .const import (
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
type GeoJsonConfigEntry = ConfigEntry[GeoJsonFeedEntityManager]
class GeoJsonFeedEntityManager: class GeoJsonFeedEntityManager:
"""Feed Entity Manager for GeoJSON feeds.""" """Feed Entity Manager for GeoJSON feeds."""

View File

@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, patch
import pytest import pytest
from homeassistant.components.geo_json_events import DOMAIN from homeassistant.components.geo_json_events.const import DOMAIN
from homeassistant.const import CONF_LATITUDE, CONF_LONGITUDE, CONF_RADIUS, CONF_URL from homeassistant.const import CONF_LATITUDE, CONF_LONGITUDE, CONF_RADIUS, CONF_URL
from tests.common import MockConfigEntry from tests.common import MockConfigEntry

View File

@ -3,7 +3,7 @@
import pytest import pytest
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.geo_json_events import DOMAIN from homeassistant.components.geo_json_events.const import DOMAIN
from homeassistant.const import ( from homeassistant.const import (
CONF_LATITUDE, CONF_LATITUDE,
CONF_LOCATION, CONF_LOCATION,

View File

@ -2,8 +2,8 @@
from unittest.mock import patch from unittest.mock import patch
from homeassistant.components.geo_json_events.const import DOMAIN
from homeassistant.components.geo_location import DOMAIN as GEO_LOCATION_DOMAIN from homeassistant.components.geo_location import DOMAIN as GEO_LOCATION_DOMAIN
from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
@ -24,11 +24,11 @@ async def test_component_unload_config_entry(
assert await hass.config_entries.async_setup(config_entry.entry_id) assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
assert mock_feed_manager_update.call_count == 1 assert mock_feed_manager_update.call_count == 1
assert hass.data[DOMAIN][config_entry.entry_id] is not None assert config_entry.state is ConfigEntryState.LOADED
# Unload config entry. # Unload config entry.
assert await hass.config_entries.async_unload(config_entry.entry_id) assert await hass.config_entries.async_unload(config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
assert hass.data[DOMAIN].get(config_entry.entry_id) is None assert config_entry.state is ConfigEntryState.NOT_LOADED
async def test_remove_orphaned_entities( async def test_remove_orphaned_entities(