diff --git a/homeassistant/components/zwave_js/__init__.py b/homeassistant/components/zwave_js/__init__.py index 360969e83d4..52a5a1b7388 100644 --- a/homeassistant/components/zwave_js/__init__.py +++ b/homeassistant/components/zwave_js/__init__.py @@ -1074,23 +1074,32 @@ async def client_listen( try: await client.listen(driver_ready) except BaseZwaveJSServerError as err: - if entry.state is not ConfigEntryState.LOADED: + if entry.state is ConfigEntryState.SETUP_IN_PROGRESS: raise LOGGER.error("Client listen failed: %s", err) except Exception as err: # We need to guard against unknown exceptions to not crash this task. LOGGER.exception("Unexpected exception: %s", err) - if entry.state is not ConfigEntryState.LOADED: + if entry.state is ConfigEntryState.SETUP_IN_PROGRESS: raise + if hass.is_stopping or entry.state is ConfigEntryState.UNLOAD_IN_PROGRESS: + return + + if entry.state is ConfigEntryState.SETUP_IN_PROGRESS: + raise HomeAssistantError("Listen task ended unexpectedly") + # The entry needs to be reloaded since a new driver state # will be acquired on reconnect. # All model instances will be replaced when the new state is acquired. - if not hass.is_stopping: - if entry.state is not ConfigEntryState.LOADED: - raise HomeAssistantError("Listen task ended unexpectedly") + if entry.state.recoverable: LOGGER.debug("Disconnected from server. Reloading integration") hass.config_entries.async_schedule_reload(entry.entry_id) + else: + LOGGER.error( + "Disconnected from server. Cannot recover entry %s", + entry.title, + ) async def async_unload_entry(hass: HomeAssistant, entry: ZwaveJSConfigEntry) -> bool: diff --git a/tests/components/zwave_js/conftest.py b/tests/components/zwave_js/conftest.py index 3c07869d5b7..eef92a7eb0a 100644 --- a/tests/components/zwave_js/conftest.py +++ b/tests/components/zwave_js/conftest.py @@ -565,12 +565,6 @@ def mock_listen_block_fixture() -> asyncio.Event: return asyncio.Event() -@pytest.fixture(name="listen_result") -def listen_result_fixture() -> asyncio.Future[None]: - """Mock a listen result.""" - return asyncio.Future() - - @pytest.fixture(name="client") def mock_client_fixture( controller_state: dict[str, Any], @@ -578,7 +572,6 @@ def mock_client_fixture( version_state: dict[str, Any], log_config_state: dict[str, Any], listen_block: asyncio.Event, - listen_result: asyncio.Future[None], ): """Mock a client.""" with patch( @@ -587,15 +580,16 @@ def mock_client_fixture( client = client_class.return_value async def connect(): + listen_block.clear() await asyncio.sleep(0) client.connected = True async def listen(driver_ready: asyncio.Event) -> None: driver_ready.set() await listen_block.wait() - await listen_result async def disconnect(): + listen_block.set() client.connected = False client.connect = AsyncMock(side_effect=connect) diff --git a/tests/components/zwave_js/test_init.py b/tests/components/zwave_js/test_init.py index d9b3f392dd6..4decb061ad0 100644 --- a/tests/components/zwave_js/test_init.py +++ b/tests/components/zwave_js/test_init.py @@ -196,19 +196,24 @@ async def test_listen_done_during_setup_before_forward_entry( hass: HomeAssistant, client: MagicMock, listen_block: asyncio.Event, - listen_result: asyncio.Future[None], core_state: CoreState, listen_future_result_method: str, listen_future_result: Exception | None, ) -> None: """Test listen task finishing during setup before forward entry.""" + listen_result = asyncio.Future[None]() assert hass.state is CoreState.running + async def connect(): + await asyncio.sleep(0) + client.connected = True + async def listen(driver_ready: asyncio.Event) -> None: await listen_block.wait() await listen_result async_fire_time_changed(hass, fire_all=True) + client.connect.side_effect = connect client.listen.side_effect = listen hass.set_state(core_state) listen_block.set() @@ -229,9 +234,9 @@ async def test_not_connected_during_setup_after_forward_entry( hass: HomeAssistant, client: MagicMock, listen_block: asyncio.Event, - listen_result: asyncio.Future[None], ) -> None: """Test we handle not connected client during setup after forward entry.""" + listen_result = asyncio.Future[None]() async def send_command_side_effect(*args: Any, **kwargs: Any) -> None: """Mock send command.""" @@ -277,12 +282,12 @@ async def test_listen_done_during_setup_after_forward_entry( hass: HomeAssistant, client: MagicMock, listen_block: asyncio.Event, - listen_result: asyncio.Future[None], core_state: CoreState, listen_future_result_method: str, listen_future_result: Exception | None, ) -> None: """Test listen task finishing during setup after forward entry.""" + listen_result = asyncio.Future[None]() assert hass.state is CoreState.running original_send_command_side_effect = client.async_send_command.side_effect @@ -320,16 +325,14 @@ async def test_listen_done_during_setup_after_forward_entry( @pytest.mark.parametrize( - ("core_state", "final_config_entry_state", "disconnect_call_count"), + ("core_state", "disconnect_call_count"), [ ( CoreState.running, - ConfigEntryState.SETUP_RETRY, - 2, - ), # the reload will cause a disconnect call too + 1, + ), # the reload will cause a disconnect ( CoreState.stopping, - ConfigEntryState.LOADED, 0, ), # the home assistant stop event will handle the disconnect ], @@ -345,19 +348,33 @@ async def test_listen_done_during_setup_after_forward_entry( async def test_listen_done_after_setup( hass: HomeAssistant, client: MagicMock, - integration: MockConfigEntry, listen_block: asyncio.Event, - listen_result: asyncio.Future[None], core_state: CoreState, listen_future_result_method: str, listen_future_result: Exception | None, - final_config_entry_state: ConfigEntryState, disconnect_call_count: int, ) -> None: """Test listen task finishing after setup.""" - config_entry = integration - assert config_entry.state is ConfigEntryState.LOADED + listen_result = asyncio.Future[None]() + + async def listen(driver_ready: asyncio.Event) -> None: + driver_ready.set() + await listen_block.wait() + await listen_result + + client.listen.side_effect = listen + + config_entry = MockConfigEntry( + domain="zwave_js", + data={"url": "ws://test.org", "data_collection_opted_in": True}, + ) + config_entry.add_to_hass(hass) + + await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + assert hass.state is CoreState.running + assert config_entry.state is ConfigEntryState.LOADED assert client.disconnect.call_count == 0 hass.set_state(core_state) @@ -365,10 +382,51 @@ async def test_listen_done_after_setup( getattr(listen_result, listen_future_result_method)(listen_future_result) await hass.async_block_till_done() - assert config_entry.state is final_config_entry_state + assert config_entry.state is ConfigEntryState.LOADED assert client.disconnect.call_count == disconnect_call_count +async def test_listen_ending_before_cancelling_listen( + hass: HomeAssistant, + integration: MockConfigEntry, + listen_block: asyncio.Event, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test listen ending during unloading before cancelling the listen task.""" + config_entry = integration + + # We can't easily simulate the race condition where the listen task ends + # before getting cancelled by the config entry during unloading. + # Use mock_state to provoke the correct condition. + config_entry.mock_state(hass, ConfigEntryState.UNLOAD_IN_PROGRESS, None) + listen_block.set() + await hass.async_block_till_done() + + assert config_entry.state is ConfigEntryState.UNLOAD_IN_PROGRESS + assert not any(record.levelno == logging.ERROR for record in caplog.records) + + +async def test_listen_ending_unrecoverable_config_entry_state( + hass: HomeAssistant, + integration: MockConfigEntry, + listen_block: asyncio.Event, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test listen ending when the config entry has an unrecoverable state.""" + config_entry = integration + + with patch.object( + hass.config_entries, "async_unload_platforms", return_value=False + ): + await hass.config_entries.async_unload(config_entry.entry_id) + + listen_block.set() + await hass.async_block_till_done() + + assert config_entry.state is ConfigEntryState.FAILED_UNLOAD + assert "Disconnected from server. Cannot recover entry" in caplog.text + + @pytest.mark.usefixtures("client") @pytest.mark.parametrize("platforms", [[Platform.SENSOR]]) async def test_new_entity_on_value_added(