mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 11:17:21 +00:00
Add support_streaming to ConversationEntity (#144998)
* Add support_streaming to ConversationEntity * pipeline tests
This commit is contained in:
parent
cff7aa229e
commit
37fe25cfdc
@ -203,7 +203,11 @@ def async_get_agent_info(
|
|||||||
name = agent.name
|
name = agent.name
|
||||||
if not isinstance(name, str):
|
if not isinstance(name, str):
|
||||||
name = agent.entity_id
|
name = agent.entity_id
|
||||||
return AgentInfo(id=agent.entity_id, name=name)
|
return AgentInfo(
|
||||||
|
id=agent.entity_id,
|
||||||
|
name=name,
|
||||||
|
supports_streaming=agent.supports_streaming,
|
||||||
|
)
|
||||||
|
|
||||||
manager = get_agent_manager(hass)
|
manager = get_agent_manager(hass)
|
||||||
|
|
||||||
|
@ -166,6 +166,7 @@ class AgentManager:
|
|||||||
AgentInfo(
|
AgentInfo(
|
||||||
id=agent_id,
|
id=agent_id,
|
||||||
name=config_entry.title or config_entry.domain,
|
name=config_entry.title or config_entry.domain,
|
||||||
|
supports_streaming=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return agents
|
return agents
|
||||||
|
@ -18,8 +18,14 @@ class ConversationEntity(RestoreEntity):
|
|||||||
|
|
||||||
_attr_should_poll = False
|
_attr_should_poll = False
|
||||||
_attr_supported_features = ConversationEntityFeature(0)
|
_attr_supported_features = ConversationEntityFeature(0)
|
||||||
|
_attr_supports_streaming = False
|
||||||
__last_activity: str | None = None
|
__last_activity: str | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_streaming(self) -> bool:
|
||||||
|
"""Return if the entity supports streaming responses."""
|
||||||
|
return self._attr_supports_streaming
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@final
|
@final
|
||||||
def state(self) -> str | None:
|
def state(self) -> str | None:
|
||||||
|
@ -16,6 +16,7 @@ class AgentInfo:
|
|||||||
|
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
|
supports_streaming: bool
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
|
@ -1083,7 +1083,11 @@ async def test_sentence_trigger_overrides_conversation_agent(
|
|||||||
# Ensure prepare succeeds
|
# Ensure prepare succeeds
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||||
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
return_value=conversation.AgentInfo(
|
||||||
|
id="test-agent",
|
||||||
|
name="Test Agent",
|
||||||
|
supports_streaming=False,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
await pipeline_input.validate()
|
await pipeline_input.validate()
|
||||||
|
|
||||||
@ -1161,7 +1165,11 @@ async def test_prefer_local_intents(
|
|||||||
# Ensure prepare succeeds
|
# Ensure prepare succeeds
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||||
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
return_value=conversation.AgentInfo(
|
||||||
|
id="test-agent",
|
||||||
|
name="Test Agent",
|
||||||
|
supports_streaming=False,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
await pipeline_input.validate()
|
await pipeline_input.validate()
|
||||||
|
|
||||||
@ -1225,7 +1233,11 @@ async def test_intent_continue_conversation(
|
|||||||
# Ensure prepare succeeds
|
# Ensure prepare succeeds
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||||
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
return_value=conversation.AgentInfo(
|
||||||
|
id="test-agent",
|
||||||
|
name="Test Agent",
|
||||||
|
supports_streaming=False,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
await pipeline_input.validate()
|
await pipeline_input.validate()
|
||||||
|
|
||||||
@ -1295,7 +1307,11 @@ async def test_intent_continue_conversation(
|
|||||||
# Ensure prepare succeeds
|
# Ensure prepare succeeds
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||||
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
return_value=conversation.AgentInfo(
|
||||||
|
id="test-agent",
|
||||||
|
name="Test Agent",
|
||||||
|
supports_streaming=False,
|
||||||
|
),
|
||||||
) as mock_prepare:
|
) as mock_prepare:
|
||||||
await pipeline_input.validate()
|
await pipeline_input.validate()
|
||||||
|
|
||||||
@ -1633,7 +1649,11 @@ async def test_chat_log_tts_streaming(
|
|||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||||
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
return_value=conversation.AgentInfo(
|
||||||
|
id="test-agent",
|
||||||
|
name="Test Agent",
|
||||||
|
supports_streaming=False,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
await pipeline_input.validate()
|
await pipeline_input.validate()
|
||||||
|
|
||||||
|
@ -29,18 +29,21 @@
|
|||||||
dict({
|
dict({
|
||||||
'id': 'conversation.home_assistant',
|
'id': 'conversation.home_assistant',
|
||||||
'name': 'Home Assistant',
|
'name': 'Home Assistant',
|
||||||
|
'supports_streaming': False,
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_get_agent_info.1
|
# name: test_get_agent_info.1
|
||||||
dict({
|
dict({
|
||||||
'id': 'mock-entry',
|
'id': 'mock-entry',
|
||||||
'name': 'Mock Title',
|
'name': 'Mock Title',
|
||||||
|
'supports_streaming': False,
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_get_agent_info.2
|
# name: test_get_agent_info.2
|
||||||
dict({
|
dict({
|
||||||
'id': 'conversation.home_assistant',
|
'id': 'conversation.home_assistant',
|
||||||
'name': 'Home Assistant',
|
'name': 'Home Assistant',
|
||||||
|
'supports_streaming': False,
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_turn_on_intent[None-turn kitchen on-None]
|
# name: test_turn_on_intent[None-turn kitchen on-None]
|
||||||
|
@ -220,6 +220,13 @@ async def test_get_agent_info(
|
|||||||
agent_info = conversation.async_get_agent_info(hass)
|
agent_info = conversation.async_get_agent_info(hass)
|
||||||
assert agent_info == snapshot
|
assert agent_info == snapshot
|
||||||
|
|
||||||
|
default_agent = conversation.async_get_agent(hass)
|
||||||
|
default_agent._attr_supports_streaming = True
|
||||||
|
assert (
|
||||||
|
conversation.async_get_agent_info(hass, "homeassistant").supports_streaming
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("agent_id", AGENT_ID_OPTIONS)
|
@pytest.mark.parametrize("agent_id", AGENT_ID_OPTIONS)
|
||||||
async def test_prepare_agent(
|
async def test_prepare_agent(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user