Add support_streaming to ConversationEntity (#144998)

* Add support_streaming to ConversationEntity

* pipeline tests
This commit is contained in:
Paulus Schoutsen 2025-05-19 13:43:06 -04:00 committed by GitHub
parent cff7aa229e
commit 37fe25cfdc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 48 additions and 6 deletions

View File

@ -203,7 +203,11 @@ def async_get_agent_info(
name = agent.name
if not isinstance(name, str):
name = agent.entity_id
return AgentInfo(id=agent.entity_id, name=name)
return AgentInfo(
id=agent.entity_id,
name=name,
supports_streaming=agent.supports_streaming,
)
manager = get_agent_manager(hass)

View File

@ -166,6 +166,7 @@ class AgentManager:
AgentInfo(
id=agent_id,
name=config_entry.title or config_entry.domain,
supports_streaming=False,
)
)
return agents

View File

@ -18,8 +18,14 @@ class ConversationEntity(RestoreEntity):
_attr_should_poll = False
_attr_supported_features = ConversationEntityFeature(0)
_attr_supports_streaming = False
__last_activity: str | None = None
@property
def supports_streaming(self) -> bool:
"""Return if the entity supports streaming responses."""
return self._attr_supports_streaming
@property
@final
def state(self) -> str | None:

View File

@ -16,6 +16,7 @@ class AgentInfo:
id: str
name: str
supports_streaming: bool
@dataclass(slots=True)

View File

@ -1083,7 +1083,11 @@ async def test_sentence_trigger_overrides_conversation_agent(
# Ensure prepare succeeds
with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
return_value=conversation.AgentInfo(
id="test-agent",
name="Test Agent",
supports_streaming=False,
),
):
await pipeline_input.validate()
@ -1161,7 +1165,11 @@ async def test_prefer_local_intents(
# Ensure prepare succeeds
with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
return_value=conversation.AgentInfo(
id="test-agent",
name="Test Agent",
supports_streaming=False,
),
):
await pipeline_input.validate()
@ -1225,7 +1233,11 @@ async def test_intent_continue_conversation(
# Ensure prepare succeeds
with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
return_value=conversation.AgentInfo(
id="test-agent",
name="Test Agent",
supports_streaming=False,
),
):
await pipeline_input.validate()
@ -1295,7 +1307,11 @@ async def test_intent_continue_conversation(
# Ensure prepare succeeds
with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
return_value=conversation.AgentInfo(
id="test-agent",
name="Test Agent",
supports_streaming=False,
),
) as mock_prepare:
await pipeline_input.validate()
@ -1633,7 +1649,11 @@ async def test_chat_log_tts_streaming(
with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
return_value=conversation.AgentInfo(
id="test-agent",
name="Test Agent",
supports_streaming=False,
),
):
await pipeline_input.validate()

View File

@ -29,18 +29,21 @@
dict({
'id': 'conversation.home_assistant',
'name': 'Home Assistant',
'supports_streaming': False,
})
# ---
# name: test_get_agent_info.1
dict({
'id': 'mock-entry',
'name': 'Mock Title',
'supports_streaming': False,
})
# ---
# name: test_get_agent_info.2
dict({
'id': 'conversation.home_assistant',
'name': 'Home Assistant',
'supports_streaming': False,
})
# ---
# name: test_turn_on_intent[None-turn kitchen on-None]

View File

@ -220,6 +220,13 @@ async def test_get_agent_info(
agent_info = conversation.async_get_agent_info(hass)
assert agent_info == snapshot
default_agent = conversation.async_get_agent(hass)
default_agent._attr_supports_streaming = True
assert (
conversation.async_get_agent_info(hass, "homeassistant").supports_streaming
is True
)
@pytest.mark.parametrize("agent_id", AGENT_ID_OPTIONS)
async def test_prepare_agent(