From 40450b9cfdfeaa177b1580327526302b996babc0 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 9 Apr 2021 07:14:33 -1000 Subject: [PATCH] Detach aiohttp.ClientSession created by config entry setup on unload (#48908) --- homeassistant/config_entries.py | 29 ++++++++ homeassistant/helpers/aiohttp_client.py | 79 ++++++++++++++++---- tests/test_config_entries.py | 99 ++++++++++++++++++++++++- tests/test_util/aiohttp.py | 2 +- 4 files changed, 192 insertions(+), 17 deletions(-) diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 23758cf88f2..6ef14afb6a6 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio +from contextvars import ContextVar import functools import logging from types import MappingProxyType, MethodType @@ -133,6 +134,7 @@ class ConfigEntry: "_setup_lock", "update_listeners", "_async_cancel_retry_setup", + "_on_unload", ) def __init__( @@ -198,6 +200,9 @@ class ConfigEntry: # Function to cancel a scheduled retry self._async_cancel_retry_setup: Callable[[], Any] | None = None + # Hold list for functions to call on unload. + self._on_unload: list[CALLBACK_TYPE] | None = None + async def async_setup( self, hass: HomeAssistant, @@ -206,6 +211,7 @@ class ConfigEntry: tries: int = 0, ) -> None: """Set up an entry.""" + current_entry.set(self) if self.source == SOURCE_IGNORE or self.disabled_by: return @@ -290,6 +296,8 @@ class ConfigEntry: self._async_cancel_retry_setup = hass.bus.async_listen_once( EVENT_HOMEASSISTANT_STARTED, setup_again ) + + self._async_process_on_unload() return except Exception: # pylint: disable=broad-except _LOGGER.exception( @@ -358,6 +366,8 @@ class ConfigEntry: if result and integration.domain == self.domain: self.state = ENTRY_STATE_NOT_LOADED + self._async_process_on_unload() + return result except Exception: # pylint: disable=broad-except _LOGGER.exception( @@ -470,6 +480,25 @@ class ConfigEntry: "disabled_by": self.disabled_by, } + @callback + def async_on_unload(self, func: CALLBACK_TYPE) -> None: + """Add a function to call when config entry is unloaded.""" + if self._on_unload is None: + self._on_unload = [] + self._on_unload.append(func) + + @callback + def _async_process_on_unload(self) -> None: + """Process the on_unload callbacks.""" + if self._on_unload is not None: + while self._on_unload: + self._on_unload.pop()() + + +current_entry: ContextVar[ConfigEntry | None] = ContextVar( + "current_entry", default=None +) + class ConfigEntriesFlowManager(data_entry_flow.FlowManager): """Manage all the config entry flows that are in progress.""" diff --git a/homeassistant/helpers/aiohttp_client.py b/homeassistant/helpers/aiohttp_client.py index f3ded75062e..53b906efd35 100644 --- a/homeassistant/helpers/aiohttp_client.py +++ b/homeassistant/helpers/aiohttp_client.py @@ -5,7 +5,7 @@ import asyncio from contextlib import suppress from ssl import SSLContext import sys -from typing import Any, Awaitable, cast +from typing import Any, Awaitable, Callable, cast import aiohttp from aiohttp import web @@ -13,6 +13,7 @@ from aiohttp.hdrs import CONTENT_TYPE, USER_AGENT from aiohttp.web_exceptions import HTTPBadGateway, HTTPGatewayTimeout import async_timeout +from homeassistant import config_entries from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, __version__ from homeassistant.core import Event, HomeAssistant, callback from homeassistant.helpers.frame import warn_use @@ -27,6 +28,8 @@ SERVER_SOFTWARE = "HomeAssistant/{0} aiohttp/{1} Python/{2[0]}.{2[1]}".format( __version__, aiohttp.__version__, sys.version_info ) +WARN_CLOSE_MSG = "closes the Home Assistant aiohttp session" + @callback @bind_hass @@ -37,12 +40,14 @@ def async_get_clientsession( This method must be run in the event loop. """ - key = DATA_CLIENTSESSION_NOTVERIFY - if verify_ssl: - key = DATA_CLIENTSESSION + key = DATA_CLIENTSESSION if verify_ssl else DATA_CLIENTSESSION_NOTVERIFY if key not in hass.data: - hass.data[key] = async_create_clientsession(hass, verify_ssl) + hass.data[key] = _async_create_clientsession( + hass, + verify_ssl, + auto_cleanup_method=_async_register_default_clientsession_shutdown, + ) return cast(aiohttp.ClientSession, hass.data[key]) @@ -59,24 +64,44 @@ def async_create_clientsession( If auto_cleanup is False, you need to call detach() after the session returned is no longer used. Default is True, the session will be - automatically detached on homeassistant_stop. + automatically detached on homeassistant_stop or when being created + in config entry setup, the config entry is unloaded. This method must be run in the event loop. """ - connector = _async_get_connector(hass, verify_ssl) + auto_cleanup_method = None + if auto_cleanup: + auto_cleanup_method = _async_register_clientsession_shutdown + clientsession = _async_create_clientsession( + hass, + verify_ssl, + auto_cleanup_method=auto_cleanup_method, + **kwargs, + ) + + return clientsession + + +@callback +def _async_create_clientsession( + hass: HomeAssistant, + verify_ssl: bool = True, + auto_cleanup_method: Callable[[HomeAssistant, aiohttp.ClientSession], None] + | None = None, + **kwargs: Any, +) -> aiohttp.ClientSession: + """Create a new ClientSession with kwargs, i.e. for cookies.""" clientsession = aiohttp.ClientSession( - connector=connector, + connector=_async_get_connector(hass, verify_ssl), headers={USER_AGENT: SERVER_SOFTWARE}, **kwargs, ) - clientsession.close = warn_use( # type: ignore - clientsession.close, "closes the Home Assistant aiohttp session" - ) + clientsession.close = warn_use(clientsession.close, WARN_CLOSE_MSG) # type: ignore - if auto_cleanup: - _async_register_clientsession_shutdown(hass, clientsession) + if auto_cleanup_method: + auto_cleanup_method(hass, clientsession) return clientsession @@ -146,7 +171,33 @@ async def async_aiohttp_proxy_stream( def _async_register_clientsession_shutdown( hass: HomeAssistant, clientsession: aiohttp.ClientSession ) -> None: - """Register ClientSession close on Home Assistant shutdown. + """Register ClientSession close on Home Assistant shutdown or config entry unload. + + This method must be run in the event loop. + """ + + @callback + def _async_close_websession(*_: Any) -> None: + """Close websession.""" + clientsession.detach() + + unsub = hass.bus.async_listen_once( + EVENT_HOMEASSISTANT_CLOSE, _async_close_websession + ) + + config_entry = config_entries.current_entry.get() + if not config_entry: + return + + config_entry.async_on_unload(unsub) + config_entry.async_on_unload(_async_close_websession) + + +@callback +def _async_register_default_clientsession_shutdown( + hass: HomeAssistant, clientsession: aiohttp.ClientSession +) -> None: + """Register default ClientSession close on Home Assistant shutdown. This method must be run in the event loop. """ diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index c35ba61a767..24d635d52a3 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -1,15 +1,16 @@ """Test the config manager.""" import asyncio from datetime import timedelta -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest from homeassistant import config_entries, data_entry_flow, loader -from homeassistant.const import EVENT_HOMEASSISTANT_STARTED +from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, EVENT_HOMEASSISTANT_STARTED from homeassistant.core import CoreState, callback from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError from homeassistant.helpers import entity_registry as er +from homeassistant.helpers.aiohttp_client import async_create_clientsession from homeassistant.setup import async_setup_component from homeassistant.util import dt @@ -2489,3 +2490,97 @@ async def test_updating_entry_with_and_without_changes(manager): assert manager.async_update_entry(entry, title="newtitle") is True assert manager.async_update_entry(entry, unique_id="abc123") is False assert manager.async_update_entry(entry, unique_id="abc1234") is True + + +async def test_entry_reload_calls_on_unload_listeners(hass, manager): + """Test reload calls the on unload listeners.""" + entry = MockConfigEntry(domain="comp", state=config_entries.ENTRY_STATE_LOADED) + entry.add_to_hass(hass) + + async_setup = AsyncMock(return_value=True) + mock_setup_entry = AsyncMock(return_value=True) + async_unload_entry = AsyncMock(return_value=True) + + mock_integration( + hass, + MockModule( + "comp", + async_setup=async_setup, + async_setup_entry=mock_setup_entry, + async_unload_entry=async_unload_entry, + ), + ) + mock_entity_platform(hass, "config_flow.comp", None) + + mock_unload_callback = Mock() + + entry.async_on_unload(mock_unload_callback) + + assert await manager.async_reload(entry.entry_id) + assert len(async_unload_entry.mock_calls) == 1 + assert len(mock_setup_entry.mock_calls) == 1 + assert len(mock_unload_callback.mock_calls) == 1 + assert entry.state == config_entries.ENTRY_STATE_LOADED + + assert await manager.async_reload(entry.entry_id) + assert len(async_unload_entry.mock_calls) == 2 + assert len(mock_setup_entry.mock_calls) == 2 + # Since we did not register another async_on_unload it should + # have only been called once + assert len(mock_unload_callback.mock_calls) == 1 + assert entry.state == config_entries.ENTRY_STATE_LOADED + + +async def test_entry_reload_cleans_up_aiohttp_session(hass, manager): + """Test reload cleans up aiohttp sessions their close listener created by the config entry.""" + entry = MockConfigEntry(domain="comp", state=config_entries.ENTRY_STATE_LOADED) + entry.add_to_hass(hass) + async_setup_calls = 0 + + async def async_setup_entry(hass, _): + """Mock setup entry.""" + nonlocal async_setup_calls + async_setup_calls += 1 + async_create_clientsession(hass) + return True + + async_setup = AsyncMock(return_value=True) + async_unload_entry = AsyncMock(return_value=True) + + mock_integration( + hass, + MockModule( + "comp", + async_setup=async_setup, + async_setup_entry=async_setup_entry, + async_unload_entry=async_unload_entry, + ), + ) + mock_entity_platform(hass, "config_flow.comp", None) + + assert await manager.async_reload(entry.entry_id) + assert len(async_unload_entry.mock_calls) == 1 + assert async_setup_calls == 1 + assert entry.state == config_entries.ENTRY_STATE_LOADED + + original_close_listeners = hass.bus.async_listeners()[EVENT_HOMEASSISTANT_CLOSE] + + assert await manager.async_reload(entry.entry_id) + assert len(async_unload_entry.mock_calls) == 2 + assert async_setup_calls == 2 + assert entry.state == config_entries.ENTRY_STATE_LOADED + + assert ( + hass.bus.async_listeners()[EVENT_HOMEASSISTANT_CLOSE] + == original_close_listeners + ) + + assert await manager.async_reload(entry.entry_id) + assert len(async_unload_entry.mock_calls) == 3 + assert async_setup_calls == 3 + assert entry.state == config_entries.ENTRY_STATE_LOADED + + assert ( + hass.bus.async_listeners()[EVENT_HOMEASSISTANT_CLOSE] + == original_close_listeners + ) diff --git a/tests/test_util/aiohttp.py b/tests/test_util/aiohttp.py index 5219212f1cf..58e4c6a2275 100644 --- a/tests/test_util/aiohttp.py +++ b/tests/test_util/aiohttp.py @@ -288,7 +288,7 @@ def mock_aiohttp_client(): return session with mock.patch( - "homeassistant.helpers.aiohttp_client.async_create_clientsession", + "homeassistant.helpers.aiohttp_client._async_create_clientsession", side_effect=create_session, ): yield mocker