diff --git a/homeassistant/components/wiz/__init__.py b/homeassistant/components/wiz/__init__.py index 1bed875d02f..7bea86d323c 100644 --- a/homeassistant/components/wiz/__init__.py +++ b/homeassistant/components/wiz/__init__.py @@ -8,8 +8,8 @@ from pywizlight import PilotParser, wizlight from pywizlight.bulb import PIR_SOURCE from homeassistant.config_entries import ConfigEntry -from homeassistant.const import CONF_HOST, Platform -from homeassistant.core import HomeAssistant, callback +from homeassistant.const import CONF_HOST, EVENT_HOMEASSISTANT_STOP, Platform +from homeassistant.core import Event, HomeAssistant, callback from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.helpers.debounce import Debouncer from homeassistant.helpers.dispatcher import async_dispatcher_send @@ -57,6 +57,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: scenes = await bulb.getSupportedScenes() await bulb.getMac() except WIZ_CONNECT_EXCEPTIONS as err: + await bulb.async_close() raise ConfigEntryNotReady(f"{ip_address}: {err}") from err async def _async_update() -> None: @@ -79,6 +80,19 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ), ) + try: + await coordinator.async_config_entry_first_refresh() + except ConfigEntryNotReady as err: + await bulb.async_close() + raise err + + async def _async_shutdown_on_stop(event: Event) -> None: + await bulb.async_close() + + entry.async_on_unload( + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_shutdown_on_stop) + ) + @callback def _async_push_update(state: PilotParser) -> None: """Receive a push update.""" @@ -89,7 +103,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: await bulb.start_push(_async_push_update) bulb.set_discovery_callback(lambda bulb: async_trigger_discovery(hass, [bulb])) - await coordinator.async_config_entry_first_refresh() hass.data.setdefault(DOMAIN, {})[entry.entry_id] = WizData( coordinator=coordinator, bulb=bulb, scenes=scenes diff --git a/tests/components/wiz/test_init.py b/tests/components/wiz/test_init.py index 6411146d162..fb21e930efd 100644 --- a/tests/components/wiz/test_init.py +++ b/tests/components/wiz/test_init.py @@ -3,6 +3,7 @@ import datetime from unittest.mock import AsyncMock from homeassistant import config_entries +from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.core import HomeAssistant from homeassistant.util.dt import utcnow @@ -30,3 +31,22 @@ async def test_setup_retry(hass: HomeAssistant) -> None: async_fire_time_changed(hass, utcnow() + datetime.timedelta(minutes=15)) await hass.async_block_till_done() assert entry.state == config_entries.ConfigEntryState.LOADED + + +async def test_cleanup_on_shutdown(hass: HomeAssistant) -> None: + """Test the socket is cleaned up on shutdown.""" + bulb = _mocked_wizlight(None, None, FAKE_SOCKET) + _, entry = await async_setup_integration(hass, wizlight=bulb) + assert entry.state == config_entries.ConfigEntryState.LOADED + hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) + await hass.async_block_till_done() + bulb.async_close.assert_called_once() + + +async def test_cleanup_on_failed_first_update(hass: HomeAssistant) -> None: + """Test the socket is cleaned up on failed first update.""" + bulb = _mocked_wizlight(None, None, FAKE_SOCKET) + bulb.updateState = AsyncMock(side_effect=OSError) + _, entry = await async_setup_integration(hass, wizlight=bulb) + assert entry.state == config_entries.ConfigEntryState.SETUP_RETRY + bulb.async_close.assert_called_once()