diff --git a/homeassistant/components/assist_satellite/entity.py b/homeassistant/components/assist_satellite/entity.py index e43abb4539c..8c63525294c 100644 --- a/homeassistant/components/assist_satellite/entity.py +++ b/homeassistant/components/assist_satellite/entity.py @@ -405,7 +405,10 @@ class AssistSatelliteEntity(entity.Entity): def _internal_on_pipeline_event(self, event: PipelineEvent) -> None: """Set state based on pipeline stage.""" if event.type is PipelineEventType.WAKE_WORD_START: - self._set_state(AssistSatelliteState.IDLE) + # Only return to idle if we're not currently responding. + # The state will return to idle in tts_response_finished. + if self.state != AssistSatelliteState.RESPONDING: + self._set_state(AssistSatelliteState.IDLE) elif event.type is PipelineEventType.STT_START: self._set_state(AssistSatelliteState.LISTENING) elif event.type is PipelineEventType.INTENT_START: diff --git a/tests/components/assist_satellite/test_entity.py b/tests/components/assist_satellite/test_entity.py index b3437bf5c5d..42b4adf742c 100644 --- a/tests/components/assist_satellite/test_entity.py +++ b/tests/components/assist_satellite/test_entity.py @@ -590,3 +590,54 @@ async def test_start_conversation_reject_builtin_agent( target={"entity_id": "assist_satellite.test_entity"}, blocking=True, ) + + +async def test_wake_word_start_keeps_responding( + hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite +) -> None: + """Test entity state stays responding on wake word start event.""" + + state = hass.states.get(ENTITY_ID) + assert state is not None + assert state.state == AssistSatelliteState.IDLE + + # Get into responding state + audio_stream = object() + + with patch( + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream" + ) as mock_start_pipeline: + await entity.async_accept_pipeline_from_satellite( + audio_stream, start_stage=PipelineStage.TTS + ) + + assert mock_start_pipeline.called + kwargs = mock_start_pipeline.call_args[1] + event_callback = kwargs["event_callback"] + event_callback(PipelineEvent(PipelineEventType.TTS_START, {})) + + state = hass.states.get(ENTITY_ID) + assert state.state == AssistSatelliteState.RESPONDING + + # Verify that starting a new wake word stream keeps the state + audio_stream = object() + + with patch( + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream" + ) as mock_start_pipeline: + await entity.async_accept_pipeline_from_satellite( + audio_stream, start_stage=PipelineStage.WAKE_WORD + ) + + assert mock_start_pipeline.called + kwargs = mock_start_pipeline.call_args[1] + event_callback = kwargs["event_callback"] + event_callback(PipelineEvent(PipelineEventType.WAKE_WORD_START, {})) + + state = hass.states.get(ENTITY_ID) + assert state.state == AssistSatelliteState.RESPONDING + + # Only return to idle once TTS is finished + entity.tts_response_finished() + state = hass.states.get(ENTITY_ID) + assert state.state == AssistSatelliteState.IDLE