mirror of
https://github.com/home-assistant/core.git
synced 2025-07-24 21:57:51 +00:00
Detach aiohttp.ClientSession created by config entry setup on unload (#48908)
This commit is contained in:
parent
8e2b5b36b5
commit
40450b9cfd
@ -2,6 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from contextvars import ContextVar
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
from types import MappingProxyType, MethodType
|
from types import MappingProxyType, MethodType
|
||||||
@ -133,6 +134,7 @@ class ConfigEntry:
|
|||||||
"_setup_lock",
|
"_setup_lock",
|
||||||
"update_listeners",
|
"update_listeners",
|
||||||
"_async_cancel_retry_setup",
|
"_async_cancel_retry_setup",
|
||||||
|
"_on_unload",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -198,6 +200,9 @@ class ConfigEntry:
|
|||||||
# Function to cancel a scheduled retry
|
# Function to cancel a scheduled retry
|
||||||
self._async_cancel_retry_setup: Callable[[], Any] | None = None
|
self._async_cancel_retry_setup: Callable[[], Any] | None = None
|
||||||
|
|
||||||
|
# Hold list for functions to call on unload.
|
||||||
|
self._on_unload: list[CALLBACK_TYPE] | None = None
|
||||||
|
|
||||||
async def async_setup(
|
async def async_setup(
|
||||||
self,
|
self,
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
@ -206,6 +211,7 @@ class ConfigEntry:
|
|||||||
tries: int = 0,
|
tries: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up an entry."""
|
"""Set up an entry."""
|
||||||
|
current_entry.set(self)
|
||||||
if self.source == SOURCE_IGNORE or self.disabled_by:
|
if self.source == SOURCE_IGNORE or self.disabled_by:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -290,6 +296,8 @@ class ConfigEntry:
|
|||||||
self._async_cancel_retry_setup = hass.bus.async_listen_once(
|
self._async_cancel_retry_setup = hass.bus.async_listen_once(
|
||||||
EVENT_HOMEASSISTANT_STARTED, setup_again
|
EVENT_HOMEASSISTANT_STARTED, setup_again
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._async_process_on_unload()
|
||||||
return
|
return
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
_LOGGER.exception(
|
_LOGGER.exception(
|
||||||
@ -358,6 +366,8 @@ class ConfigEntry:
|
|||||||
if result and integration.domain == self.domain:
|
if result and integration.domain == self.domain:
|
||||||
self.state = ENTRY_STATE_NOT_LOADED
|
self.state = ENTRY_STATE_NOT_LOADED
|
||||||
|
|
||||||
|
self._async_process_on_unload()
|
||||||
|
|
||||||
return result
|
return result
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
_LOGGER.exception(
|
_LOGGER.exception(
|
||||||
@ -470,6 +480,25 @@ class ConfigEntry:
|
|||||||
"disabled_by": self.disabled_by,
|
"disabled_by": self.disabled_by,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_on_unload(self, func: CALLBACK_TYPE) -> None:
|
||||||
|
"""Add a function to call when config entry is unloaded."""
|
||||||
|
if self._on_unload is None:
|
||||||
|
self._on_unload = []
|
||||||
|
self._on_unload.append(func)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_process_on_unload(self) -> None:
|
||||||
|
"""Process the on_unload callbacks."""
|
||||||
|
if self._on_unload is not None:
|
||||||
|
while self._on_unload:
|
||||||
|
self._on_unload.pop()()
|
||||||
|
|
||||||
|
|
||||||
|
current_entry: ContextVar[ConfigEntry | None] = ContextVar(
|
||||||
|
"current_entry", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
|
class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
|
||||||
"""Manage all the config entry flows that are in progress."""
|
"""Manage all the config entry flows that are in progress."""
|
||||||
|
@ -5,7 +5,7 @@ import asyncio
|
|||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from ssl import SSLContext
|
from ssl import SSLContext
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, Awaitable, cast
|
from typing import Any, Awaitable, Callable, cast
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
@ -13,6 +13,7 @@ from aiohttp.hdrs import CONTENT_TYPE, USER_AGENT
|
|||||||
from aiohttp.web_exceptions import HTTPBadGateway, HTTPGatewayTimeout
|
from aiohttp.web_exceptions import HTTPBadGateway, HTTPGatewayTimeout
|
||||||
import async_timeout
|
import async_timeout
|
||||||
|
|
||||||
|
from homeassistant import config_entries
|
||||||
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, __version__
|
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, __version__
|
||||||
from homeassistant.core import Event, HomeAssistant, callback
|
from homeassistant.core import Event, HomeAssistant, callback
|
||||||
from homeassistant.helpers.frame import warn_use
|
from homeassistant.helpers.frame import warn_use
|
||||||
@ -27,6 +28,8 @@ SERVER_SOFTWARE = "HomeAssistant/{0} aiohttp/{1} Python/{2[0]}.{2[1]}".format(
|
|||||||
__version__, aiohttp.__version__, sys.version_info
|
__version__, aiohttp.__version__, sys.version_info
|
||||||
)
|
)
|
||||||
|
|
||||||
|
WARN_CLOSE_MSG = "closes the Home Assistant aiohttp session"
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@bind_hass
|
@bind_hass
|
||||||
@ -37,12 +40,14 @@ def async_get_clientsession(
|
|||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
"""
|
"""
|
||||||
key = DATA_CLIENTSESSION_NOTVERIFY
|
key = DATA_CLIENTSESSION if verify_ssl else DATA_CLIENTSESSION_NOTVERIFY
|
||||||
if verify_ssl:
|
|
||||||
key = DATA_CLIENTSESSION
|
|
||||||
|
|
||||||
if key not in hass.data:
|
if key not in hass.data:
|
||||||
hass.data[key] = async_create_clientsession(hass, verify_ssl)
|
hass.data[key] = _async_create_clientsession(
|
||||||
|
hass,
|
||||||
|
verify_ssl,
|
||||||
|
auto_cleanup_method=_async_register_default_clientsession_shutdown,
|
||||||
|
)
|
||||||
|
|
||||||
return cast(aiohttp.ClientSession, hass.data[key])
|
return cast(aiohttp.ClientSession, hass.data[key])
|
||||||
|
|
||||||
@ -59,24 +64,44 @@ def async_create_clientsession(
|
|||||||
|
|
||||||
If auto_cleanup is False, you need to call detach() after the session
|
If auto_cleanup is False, you need to call detach() after the session
|
||||||
returned is no longer used. Default is True, the session will be
|
returned is no longer used. Default is True, the session will be
|
||||||
automatically detached on homeassistant_stop.
|
automatically detached on homeassistant_stop or when being created
|
||||||
|
in config entry setup, the config entry is unloaded.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
"""
|
"""
|
||||||
connector = _async_get_connector(hass, verify_ssl)
|
auto_cleanup_method = None
|
||||||
|
if auto_cleanup:
|
||||||
|
auto_cleanup_method = _async_register_clientsession_shutdown
|
||||||
|
|
||||||
|
clientsession = _async_create_clientsession(
|
||||||
|
hass,
|
||||||
|
verify_ssl,
|
||||||
|
auto_cleanup_method=auto_cleanup_method,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return clientsession
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_create_clientsession(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
verify_ssl: bool = True,
|
||||||
|
auto_cleanup_method: Callable[[HomeAssistant, aiohttp.ClientSession], None]
|
||||||
|
| None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> aiohttp.ClientSession:
|
||||||
|
"""Create a new ClientSession with kwargs, i.e. for cookies."""
|
||||||
clientsession = aiohttp.ClientSession(
|
clientsession = aiohttp.ClientSession(
|
||||||
connector=connector,
|
connector=_async_get_connector(hass, verify_ssl),
|
||||||
headers={USER_AGENT: SERVER_SOFTWARE},
|
headers={USER_AGENT: SERVER_SOFTWARE},
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
clientsession.close = warn_use( # type: ignore
|
clientsession.close = warn_use(clientsession.close, WARN_CLOSE_MSG) # type: ignore
|
||||||
clientsession.close, "closes the Home Assistant aiohttp session"
|
|
||||||
)
|
|
||||||
|
|
||||||
if auto_cleanup:
|
if auto_cleanup_method:
|
||||||
_async_register_clientsession_shutdown(hass, clientsession)
|
auto_cleanup_method(hass, clientsession)
|
||||||
|
|
||||||
return clientsession
|
return clientsession
|
||||||
|
|
||||||
@ -146,7 +171,33 @@ async def async_aiohttp_proxy_stream(
|
|||||||
def _async_register_clientsession_shutdown(
|
def _async_register_clientsession_shutdown(
|
||||||
hass: HomeAssistant, clientsession: aiohttp.ClientSession
|
hass: HomeAssistant, clientsession: aiohttp.ClientSession
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Register ClientSession close on Home Assistant shutdown.
|
"""Register ClientSession close on Home Assistant shutdown or config entry unload.
|
||||||
|
|
||||||
|
This method must be run in the event loop.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_close_websession(*_: Any) -> None:
|
||||||
|
"""Close websession."""
|
||||||
|
clientsession.detach()
|
||||||
|
|
||||||
|
unsub = hass.bus.async_listen_once(
|
||||||
|
EVENT_HOMEASSISTANT_CLOSE, _async_close_websession
|
||||||
|
)
|
||||||
|
|
||||||
|
config_entry = config_entries.current_entry.get()
|
||||||
|
if not config_entry:
|
||||||
|
return
|
||||||
|
|
||||||
|
config_entry.async_on_unload(unsub)
|
||||||
|
config_entry.async_on_unload(_async_close_websession)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_register_default_clientsession_shutdown(
|
||||||
|
hass: HomeAssistant, clientsession: aiohttp.ClientSession
|
||||||
|
) -> None:
|
||||||
|
"""Register default ClientSession close on Home Assistant shutdown.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
"""
|
"""
|
||||||
|
@ -1,15 +1,16 @@
|
|||||||
"""Test the config manager."""
|
"""Test the config manager."""
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant import config_entries, data_entry_flow, loader
|
from homeassistant import config_entries, data_entry_flow, loader
|
||||||
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
|
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, EVENT_HOMEASSISTANT_STARTED
|
||||||
from homeassistant.core import CoreState, callback
|
from homeassistant.core import CoreState, callback
|
||||||
from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
|
from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
|
||||||
from homeassistant.helpers import entity_registry as er
|
from homeassistant.helpers import entity_registry as er
|
||||||
|
from homeassistant.helpers.aiohttp_client import async_create_clientsession
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
from homeassistant.util import dt
|
from homeassistant.util import dt
|
||||||
|
|
||||||
@ -2489,3 +2490,97 @@ async def test_updating_entry_with_and_without_changes(manager):
|
|||||||
assert manager.async_update_entry(entry, title="newtitle") is True
|
assert manager.async_update_entry(entry, title="newtitle") is True
|
||||||
assert manager.async_update_entry(entry, unique_id="abc123") is False
|
assert manager.async_update_entry(entry, unique_id="abc123") is False
|
||||||
assert manager.async_update_entry(entry, unique_id="abc1234") is True
|
assert manager.async_update_entry(entry, unique_id="abc1234") is True
|
||||||
|
|
||||||
|
|
||||||
|
async def test_entry_reload_calls_on_unload_listeners(hass, manager):
|
||||||
|
"""Test reload calls the on unload listeners."""
|
||||||
|
entry = MockConfigEntry(domain="comp", state=config_entries.ENTRY_STATE_LOADED)
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
|
||||||
|
async_setup = AsyncMock(return_value=True)
|
||||||
|
mock_setup_entry = AsyncMock(return_value=True)
|
||||||
|
async_unload_entry = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
mock_integration(
|
||||||
|
hass,
|
||||||
|
MockModule(
|
||||||
|
"comp",
|
||||||
|
async_setup=async_setup,
|
||||||
|
async_setup_entry=mock_setup_entry,
|
||||||
|
async_unload_entry=async_unload_entry,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
mock_entity_platform(hass, "config_flow.comp", None)
|
||||||
|
|
||||||
|
mock_unload_callback = Mock()
|
||||||
|
|
||||||
|
entry.async_on_unload(mock_unload_callback)
|
||||||
|
|
||||||
|
assert await manager.async_reload(entry.entry_id)
|
||||||
|
assert len(async_unload_entry.mock_calls) == 1
|
||||||
|
assert len(mock_setup_entry.mock_calls) == 1
|
||||||
|
assert len(mock_unload_callback.mock_calls) == 1
|
||||||
|
assert entry.state == config_entries.ENTRY_STATE_LOADED
|
||||||
|
|
||||||
|
assert await manager.async_reload(entry.entry_id)
|
||||||
|
assert len(async_unload_entry.mock_calls) == 2
|
||||||
|
assert len(mock_setup_entry.mock_calls) == 2
|
||||||
|
# Since we did not register another async_on_unload it should
|
||||||
|
# have only been called once
|
||||||
|
assert len(mock_unload_callback.mock_calls) == 1
|
||||||
|
assert entry.state == config_entries.ENTRY_STATE_LOADED
|
||||||
|
|
||||||
|
|
||||||
|
async def test_entry_reload_cleans_up_aiohttp_session(hass, manager):
|
||||||
|
"""Test reload cleans up aiohttp sessions their close listener created by the config entry."""
|
||||||
|
entry = MockConfigEntry(domain="comp", state=config_entries.ENTRY_STATE_LOADED)
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
async_setup_calls = 0
|
||||||
|
|
||||||
|
async def async_setup_entry(hass, _):
|
||||||
|
"""Mock setup entry."""
|
||||||
|
nonlocal async_setup_calls
|
||||||
|
async_setup_calls += 1
|
||||||
|
async_create_clientsession(hass)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async_setup = AsyncMock(return_value=True)
|
||||||
|
async_unload_entry = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
mock_integration(
|
||||||
|
hass,
|
||||||
|
MockModule(
|
||||||
|
"comp",
|
||||||
|
async_setup=async_setup,
|
||||||
|
async_setup_entry=async_setup_entry,
|
||||||
|
async_unload_entry=async_unload_entry,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
mock_entity_platform(hass, "config_flow.comp", None)
|
||||||
|
|
||||||
|
assert await manager.async_reload(entry.entry_id)
|
||||||
|
assert len(async_unload_entry.mock_calls) == 1
|
||||||
|
assert async_setup_calls == 1
|
||||||
|
assert entry.state == config_entries.ENTRY_STATE_LOADED
|
||||||
|
|
||||||
|
original_close_listeners = hass.bus.async_listeners()[EVENT_HOMEASSISTANT_CLOSE]
|
||||||
|
|
||||||
|
assert await manager.async_reload(entry.entry_id)
|
||||||
|
assert len(async_unload_entry.mock_calls) == 2
|
||||||
|
assert async_setup_calls == 2
|
||||||
|
assert entry.state == config_entries.ENTRY_STATE_LOADED
|
||||||
|
|
||||||
|
assert (
|
||||||
|
hass.bus.async_listeners()[EVENT_HOMEASSISTANT_CLOSE]
|
||||||
|
== original_close_listeners
|
||||||
|
)
|
||||||
|
|
||||||
|
assert await manager.async_reload(entry.entry_id)
|
||||||
|
assert len(async_unload_entry.mock_calls) == 3
|
||||||
|
assert async_setup_calls == 3
|
||||||
|
assert entry.state == config_entries.ENTRY_STATE_LOADED
|
||||||
|
|
||||||
|
assert (
|
||||||
|
hass.bus.async_listeners()[EVENT_HOMEASSISTANT_CLOSE]
|
||||||
|
== original_close_listeners
|
||||||
|
)
|
||||||
|
@ -288,7 +288,7 @@ def mock_aiohttp_client():
|
|||||||
return session
|
return session
|
||||||
|
|
||||||
with mock.patch(
|
with mock.patch(
|
||||||
"homeassistant.helpers.aiohttp_client.async_create_clientsession",
|
"homeassistant.helpers.aiohttp_client._async_create_clientsession",
|
||||||
side_effect=create_session,
|
side_effect=create_session,
|
||||||
):
|
):
|
||||||
yield mocker
|
yield mocker
|
||||||
|
Loading…
x
Reference in New Issue
Block a user