Detach aiohttp.ClientSession created by config entry setup on unload (#48908)

This commit is contained in:
J. Nick Koston 2021-04-09 07:14:33 -10:00 committed by GitHub
parent 8e2b5b36b5
commit 40450b9cfd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 192 additions and 17 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from contextvars import ContextVar
import functools import functools
import logging import logging
from types import MappingProxyType, MethodType from types import MappingProxyType, MethodType
@ -133,6 +134,7 @@ class ConfigEntry:
"_setup_lock", "_setup_lock",
"update_listeners", "update_listeners",
"_async_cancel_retry_setup", "_async_cancel_retry_setup",
"_on_unload",
) )
def __init__( def __init__(
@ -198,6 +200,9 @@ class ConfigEntry:
# Function to cancel a scheduled retry # Function to cancel a scheduled retry
self._async_cancel_retry_setup: Callable[[], Any] | None = None self._async_cancel_retry_setup: Callable[[], Any] | None = None
# Hold list for functions to call on unload.
self._on_unload: list[CALLBACK_TYPE] | None = None
async def async_setup( async def async_setup(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
@ -206,6 +211,7 @@ class ConfigEntry:
tries: int = 0, tries: int = 0,
) -> None: ) -> None:
"""Set up an entry.""" """Set up an entry."""
current_entry.set(self)
if self.source == SOURCE_IGNORE or self.disabled_by: if self.source == SOURCE_IGNORE or self.disabled_by:
return return
@ -290,6 +296,8 @@ class ConfigEntry:
self._async_cancel_retry_setup = hass.bus.async_listen_once( self._async_cancel_retry_setup = hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STARTED, setup_again EVENT_HOMEASSISTANT_STARTED, setup_again
) )
self._async_process_on_unload()
return return
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
_LOGGER.exception( _LOGGER.exception(
@ -358,6 +366,8 @@ class ConfigEntry:
if result and integration.domain == self.domain: if result and integration.domain == self.domain:
self.state = ENTRY_STATE_NOT_LOADED self.state = ENTRY_STATE_NOT_LOADED
self._async_process_on_unload()
return result return result
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
_LOGGER.exception( _LOGGER.exception(
@ -470,6 +480,25 @@ class ConfigEntry:
"disabled_by": self.disabled_by, "disabled_by": self.disabled_by,
} }
@callback
def async_on_unload(self, func: CALLBACK_TYPE) -> None:
"""Add a function to call when config entry is unloaded."""
if self._on_unload is None:
self._on_unload = []
self._on_unload.append(func)
@callback
def _async_process_on_unload(self) -> None:
"""Process the on_unload callbacks."""
if self._on_unload is not None:
while self._on_unload:
self._on_unload.pop()()
current_entry: ContextVar[ConfigEntry | None] = ContextVar(
"current_entry", default=None
)
class ConfigEntriesFlowManager(data_entry_flow.FlowManager): class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
"""Manage all the config entry flows that are in progress.""" """Manage all the config entry flows that are in progress."""

View File

@ -5,7 +5,7 @@ import asyncio
from contextlib import suppress from contextlib import suppress
from ssl import SSLContext from ssl import SSLContext
import sys import sys
from typing import Any, Awaitable, cast from typing import Any, Awaitable, Callable, cast
import aiohttp import aiohttp
from aiohttp import web from aiohttp import web
@ -13,6 +13,7 @@ from aiohttp.hdrs import CONTENT_TYPE, USER_AGENT
from aiohttp.web_exceptions import HTTPBadGateway, HTTPGatewayTimeout from aiohttp.web_exceptions import HTTPBadGateway, HTTPGatewayTimeout
import async_timeout import async_timeout
from homeassistant import config_entries
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, __version__ from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, __version__
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.helpers.frame import warn_use from homeassistant.helpers.frame import warn_use
@ -27,6 +28,8 @@ SERVER_SOFTWARE = "HomeAssistant/{0} aiohttp/{1} Python/{2[0]}.{2[1]}".format(
__version__, aiohttp.__version__, sys.version_info __version__, aiohttp.__version__, sys.version_info
) )
WARN_CLOSE_MSG = "closes the Home Assistant aiohttp session"
@callback @callback
@bind_hass @bind_hass
@ -37,12 +40,14 @@ def async_get_clientsession(
This method must be run in the event loop. This method must be run in the event loop.
""" """
key = DATA_CLIENTSESSION_NOTVERIFY key = DATA_CLIENTSESSION if verify_ssl else DATA_CLIENTSESSION_NOTVERIFY
if verify_ssl:
key = DATA_CLIENTSESSION
if key not in hass.data: if key not in hass.data:
hass.data[key] = async_create_clientsession(hass, verify_ssl) hass.data[key] = _async_create_clientsession(
hass,
verify_ssl,
auto_cleanup_method=_async_register_default_clientsession_shutdown,
)
return cast(aiohttp.ClientSession, hass.data[key]) return cast(aiohttp.ClientSession, hass.data[key])
@ -59,24 +64,44 @@ def async_create_clientsession(
If auto_cleanup is False, you need to call detach() after the session If auto_cleanup is False, you need to call detach() after the session
returned is no longer used. Default is True, the session will be returned is no longer used. Default is True, the session will be
automatically detached on homeassistant_stop. automatically detached on homeassistant_stop or when being created
in config entry setup, the config entry is unloaded.
This method must be run in the event loop. This method must be run in the event loop.
""" """
connector = _async_get_connector(hass, verify_ssl) auto_cleanup_method = None
if auto_cleanup:
auto_cleanup_method = _async_register_clientsession_shutdown
clientsession = _async_create_clientsession(
hass,
verify_ssl,
auto_cleanup_method=auto_cleanup_method,
**kwargs,
)
return clientsession
@callback
def _async_create_clientsession(
hass: HomeAssistant,
verify_ssl: bool = True,
auto_cleanup_method: Callable[[HomeAssistant, aiohttp.ClientSession], None]
| None = None,
**kwargs: Any,
) -> aiohttp.ClientSession:
"""Create a new ClientSession with kwargs, i.e. for cookies."""
clientsession = aiohttp.ClientSession( clientsession = aiohttp.ClientSession(
connector=connector, connector=_async_get_connector(hass, verify_ssl),
headers={USER_AGENT: SERVER_SOFTWARE}, headers={USER_AGENT: SERVER_SOFTWARE},
**kwargs, **kwargs,
) )
clientsession.close = warn_use( # type: ignore clientsession.close = warn_use(clientsession.close, WARN_CLOSE_MSG) # type: ignore
clientsession.close, "closes the Home Assistant aiohttp session"
)
if auto_cleanup: if auto_cleanup_method:
_async_register_clientsession_shutdown(hass, clientsession) auto_cleanup_method(hass, clientsession)
return clientsession return clientsession
@ -146,7 +171,33 @@ async def async_aiohttp_proxy_stream(
def _async_register_clientsession_shutdown( def _async_register_clientsession_shutdown(
hass: HomeAssistant, clientsession: aiohttp.ClientSession hass: HomeAssistant, clientsession: aiohttp.ClientSession
) -> None: ) -> None:
"""Register ClientSession close on Home Assistant shutdown. """Register ClientSession close on Home Assistant shutdown or config entry unload.
This method must be run in the event loop.
"""
@callback
def _async_close_websession(*_: Any) -> None:
"""Close websession."""
clientsession.detach()
unsub = hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_CLOSE, _async_close_websession
)
config_entry = config_entries.current_entry.get()
if not config_entry:
return
config_entry.async_on_unload(unsub)
config_entry.async_on_unload(_async_close_websession)
@callback
def _async_register_default_clientsession_shutdown(
hass: HomeAssistant, clientsession: aiohttp.ClientSession
) -> None:
"""Register default ClientSession close on Home Assistant shutdown.
This method must be run in the event loop. This method must be run in the event loop.
""" """

View File

@ -1,15 +1,16 @@
"""Test the config manager.""" """Test the config manager."""
import asyncio import asyncio
from datetime import timedelta from datetime import timedelta
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, Mock, patch
import pytest import pytest
from homeassistant import config_entries, data_entry_flow, loader from homeassistant import config_entries, data_entry_flow, loader
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, EVENT_HOMEASSISTANT_STARTED
from homeassistant.core import CoreState, callback from homeassistant.core import CoreState, callback
from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.aiohttp_client import async_create_clientsession
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.util import dt from homeassistant.util import dt
@ -2489,3 +2490,97 @@ async def test_updating_entry_with_and_without_changes(manager):
assert manager.async_update_entry(entry, title="newtitle") is True assert manager.async_update_entry(entry, title="newtitle") is True
assert manager.async_update_entry(entry, unique_id="abc123") is False assert manager.async_update_entry(entry, unique_id="abc123") is False
assert manager.async_update_entry(entry, unique_id="abc1234") is True assert manager.async_update_entry(entry, unique_id="abc1234") is True
async def test_entry_reload_calls_on_unload_listeners(hass, manager):
"""Test reload calls the on unload listeners."""
entry = MockConfigEntry(domain="comp", state=config_entries.ENTRY_STATE_LOADED)
entry.add_to_hass(hass)
async_setup = AsyncMock(return_value=True)
mock_setup_entry = AsyncMock(return_value=True)
async_unload_entry = AsyncMock(return_value=True)
mock_integration(
hass,
MockModule(
"comp",
async_setup=async_setup,
async_setup_entry=mock_setup_entry,
async_unload_entry=async_unload_entry,
),
)
mock_entity_platform(hass, "config_flow.comp", None)
mock_unload_callback = Mock()
entry.async_on_unload(mock_unload_callback)
assert await manager.async_reload(entry.entry_id)
assert len(async_unload_entry.mock_calls) == 1
assert len(mock_setup_entry.mock_calls) == 1
assert len(mock_unload_callback.mock_calls) == 1
assert entry.state == config_entries.ENTRY_STATE_LOADED
assert await manager.async_reload(entry.entry_id)
assert len(async_unload_entry.mock_calls) == 2
assert len(mock_setup_entry.mock_calls) == 2
# Since we did not register another async_on_unload it should
# have only been called once
assert len(mock_unload_callback.mock_calls) == 1
assert entry.state == config_entries.ENTRY_STATE_LOADED
async def test_entry_reload_cleans_up_aiohttp_session(hass, manager):
"""Test reload cleans up aiohttp sessions their close listener created by the config entry."""
entry = MockConfigEntry(domain="comp", state=config_entries.ENTRY_STATE_LOADED)
entry.add_to_hass(hass)
async_setup_calls = 0
async def async_setup_entry(hass, _):
"""Mock setup entry."""
nonlocal async_setup_calls
async_setup_calls += 1
async_create_clientsession(hass)
return True
async_setup = AsyncMock(return_value=True)
async_unload_entry = AsyncMock(return_value=True)
mock_integration(
hass,
MockModule(
"comp",
async_setup=async_setup,
async_setup_entry=async_setup_entry,
async_unload_entry=async_unload_entry,
),
)
mock_entity_platform(hass, "config_flow.comp", None)
assert await manager.async_reload(entry.entry_id)
assert len(async_unload_entry.mock_calls) == 1
assert async_setup_calls == 1
assert entry.state == config_entries.ENTRY_STATE_LOADED
original_close_listeners = hass.bus.async_listeners()[EVENT_HOMEASSISTANT_CLOSE]
assert await manager.async_reload(entry.entry_id)
assert len(async_unload_entry.mock_calls) == 2
assert async_setup_calls == 2
assert entry.state == config_entries.ENTRY_STATE_LOADED
assert (
hass.bus.async_listeners()[EVENT_HOMEASSISTANT_CLOSE]
== original_close_listeners
)
assert await manager.async_reload(entry.entry_id)
assert len(async_unload_entry.mock_calls) == 3
assert async_setup_calls == 3
assert entry.state == config_entries.ENTRY_STATE_LOADED
assert (
hass.bus.async_listeners()[EVENT_HOMEASSISTANT_CLOSE]
== original_close_listeners
)

View File

@ -288,7 +288,7 @@ def mock_aiohttp_client():
return session return session
with mock.patch( with mock.patch(
"homeassistant.helpers.aiohttp_client.async_create_clientsession", "homeassistant.helpers.aiohttp_client._async_create_clientsession",
side_effect=create_session, side_effect=create_session,
): ):
yield mocker yield mocker