From 37fe25cfdc32ea6b29277b4db8785223caf59295 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 19 May 2025 13:43:06 -0400 Subject: [PATCH] Add support_streaming to ConversationEntity (#144998) * Add support_streaming to ConversationEntity * pipeline tests --- .../components/conversation/__init__.py | 6 +++- .../components/conversation/agent_manager.py | 1 + .../components/conversation/entity.py | 6 ++++ .../components/conversation/models.py | 1 + .../assist_pipeline/test_pipeline.py | 30 +++++++++++++++---- .../conversation/snapshots/test_init.ambr | 3 ++ tests/components/conversation/test_init.py | 7 +++++ 7 files changed, 48 insertions(+), 6 deletions(-) diff --git a/homeassistant/components/conversation/__init__.py b/homeassistant/components/conversation/__init__.py index 25aaf6df290..fff2c00641f 100644 --- a/homeassistant/components/conversation/__init__.py +++ b/homeassistant/components/conversation/__init__.py @@ -203,7 +203,11 @@ def async_get_agent_info( name = agent.name if not isinstance(name, str): name = agent.entity_id - return AgentInfo(id=agent.entity_id, name=name) + return AgentInfo( + id=agent.entity_id, + name=name, + supports_streaming=agent.supports_streaming, + ) manager = get_agent_manager(hass) diff --git a/homeassistant/components/conversation/agent_manager.py b/homeassistant/components/conversation/agent_manager.py index 5ff47977d88..38c0ca8db6b 100644 --- a/homeassistant/components/conversation/agent_manager.py +++ b/homeassistant/components/conversation/agent_manager.py @@ -166,6 +166,7 @@ class AgentManager: AgentInfo( id=agent_id, name=config_entry.title or config_entry.domain, + supports_streaming=False, ) ) return agents diff --git a/homeassistant/components/conversation/entity.py b/homeassistant/components/conversation/entity.py index ca4d18ab9f5..60cf24dbf96 100644 --- a/homeassistant/components/conversation/entity.py +++ b/homeassistant/components/conversation/entity.py @@ -18,8 +18,14 @@ class ConversationEntity(RestoreEntity): _attr_should_poll = False _attr_supported_features = ConversationEntityFeature(0) + _attr_supports_streaming = False __last_activity: str | None = None + @property + def supports_streaming(self) -> bool: + """Return if the entity supports streaming responses.""" + return self._attr_supports_streaming + @property @final def state(self) -> str | None: diff --git a/homeassistant/components/conversation/models.py b/homeassistant/components/conversation/models.py index 7bdd13afc01..00097f5b4d3 100644 --- a/homeassistant/components/conversation/models.py +++ b/homeassistant/components/conversation/models.py @@ -16,6 +16,7 @@ class AgentInfo: id: str name: str + supports_streaming: bool @dataclass(slots=True) diff --git a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py index abf6572afc9..f4e7c886d40 100644 --- a/tests/components/assist_pipeline/test_pipeline.py +++ b/tests/components/assist_pipeline/test_pipeline.py @@ -1083,7 +1083,11 @@ async def test_sentence_trigger_overrides_conversation_agent( # Ensure prepare succeeds with patch( "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", - return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"), + return_value=conversation.AgentInfo( + id="test-agent", + name="Test Agent", + supports_streaming=False, + ), ): await pipeline_input.validate() @@ -1161,7 +1165,11 @@ async def test_prefer_local_intents( # Ensure prepare succeeds with patch( "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", - return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"), + return_value=conversation.AgentInfo( + id="test-agent", + name="Test Agent", + supports_streaming=False, + ), ): await pipeline_input.validate() @@ -1225,7 +1233,11 @@ async def test_intent_continue_conversation( # Ensure prepare succeeds with patch( "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", - return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"), + return_value=conversation.AgentInfo( + id="test-agent", + name="Test Agent", + supports_streaming=False, + ), ): await pipeline_input.validate() @@ -1295,7 +1307,11 @@ async def test_intent_continue_conversation( # Ensure prepare succeeds with patch( "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", - return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"), + return_value=conversation.AgentInfo( + id="test-agent", + name="Test Agent", + supports_streaming=False, + ), ) as mock_prepare: await pipeline_input.validate() @@ -1633,7 +1649,11 @@ async def test_chat_log_tts_streaming( with patch( "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info", - return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"), + return_value=conversation.AgentInfo( + id="test-agent", + name="Test Agent", + supports_streaming=False, + ), ): await pipeline_input.validate() diff --git a/tests/components/conversation/snapshots/test_init.ambr b/tests/components/conversation/snapshots/test_init.ambr index 3d843d4e32a..a853faa7a3d 100644 --- a/tests/components/conversation/snapshots/test_init.ambr +++ b/tests/components/conversation/snapshots/test_init.ambr @@ -29,18 +29,21 @@ dict({ 'id': 'conversation.home_assistant', 'name': 'Home Assistant', + 'supports_streaming': False, }) # --- # name: test_get_agent_info.1 dict({ 'id': 'mock-entry', 'name': 'Mock Title', + 'supports_streaming': False, }) # --- # name: test_get_agent_info.2 dict({ 'id': 'conversation.home_assistant', 'name': 'Home Assistant', + 'supports_streaming': False, }) # --- # name: test_turn_on_intent[None-turn kitchen on-None] diff --git a/tests/components/conversation/test_init.py b/tests/components/conversation/test_init.py index 9ac5c7d16a4..c3de5f1127c 100644 --- a/tests/components/conversation/test_init.py +++ b/tests/components/conversation/test_init.py @@ -220,6 +220,13 @@ async def test_get_agent_info( agent_info = conversation.async_get_agent_info(hass) assert agent_info == snapshot + default_agent = conversation.async_get_agent(hass) + default_agent._attr_supports_streaming = True + assert ( + conversation.async_get_agent_info(hass, "homeassistant").supports_streaming + is True + ) + @pytest.mark.parametrize("agent_id", AGENT_ID_OPTIONS) async def test_prepare_agent(