Refactor conversation mock_agent (#114428)

* Refactor conversation mock_agent

* Address review comments
This commit is contained in:
Sid 2024-03-29 14:38:58 +01:00 committed by GitHub
parent 8d6d70d6b5
commit dc557fca1e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 52 additions and 28 deletions

View File

@ -10,6 +10,7 @@ from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant
from tests.common import MockToggleEntity
from tests.components.conversation import MockAgent
if TYPE_CHECKING:
from tests.components.device_tracker.common import MockScanner
@ -104,6 +105,16 @@ def tts_mutagen_mock_fixture():
yield from tts_mutagen_mock_fixture_helper()
@pytest.fixture(name="mock_conversation_agent")
def mock_conversation_agent_fixture(hass: HomeAssistant) -> MockAgent:
"""Mock a conversation agent."""
from tests.components.conversation.common import (
mock_conversation_agent_fixture_helper,
)
return mock_conversation_agent_fixture_helper(hass)
@pytest.fixture(scope="session", autouse=True)
def prevent_ffmpeg_subprocess() -> Generator[None, None, None]:
"""Prevent ffmpeg from creating a subprocess."""

View File

@ -0,0 +1,17 @@
"""Provide common tests tools for conversation."""
from homeassistant.components import conversation
from homeassistant.core import HomeAssistant
from . import MockAgent
from tests.common import MockConfigEntry
def mock_conversation_agent_fixture_helper(hass: HomeAssistant) -> MockAgent:
"""Mock agent."""
entry = MockConfigEntry(entry_id="mock-entry")
entry.add_to_hass(hass)
agent = MockAgent(entry.entry_id, ["smurfish"])
conversation.async_set_agent(hass, entry, agent)
return agent

View File

@ -13,16 +13,6 @@ from . import MockAgent
from tests.common import MockConfigEntry
@pytest.fixture
def mock_agent(hass):
"""Mock agent."""
entry = MockConfigEntry(entry_id="mock-entry")
entry.add_to_hass(hass)
agent = MockAgent(entry.entry_id, ["smurfish"])
conversation.async_set_agent(hass, entry, agent)
return agent
@pytest.fixture
def mock_agent_support_all(hass):
"""Mock agent that supports all languages."""

View File

@ -94,7 +94,7 @@ async def test_http_processing_intent_target_ha_agent(
init_components,
hass_client: ClientSessionGenerator,
hass_admin_user: MockUser,
mock_agent,
mock_conversation_agent,
entity_registry: er.EntityRegistry,
snapshot: SnapshotAssertion,
) -> None:
@ -658,7 +658,7 @@ async def test_custom_agent(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
hass_admin_user: MockUser,
mock_agent,
mock_conversation_agent,
snapshot: SnapshotAssertion,
) -> None:
"""Test a custom conversation agent."""
@ -672,7 +672,7 @@ async def test_custom_agent(
"text": "Test Text",
"conversation_id": "test-conv-id",
"language": "test-language",
"agent_id": mock_agent.agent_id,
"agent_id": mock_conversation_agent.agent_id,
}
resp = await client.post("/api/conversation/process", json=data)
@ -683,14 +683,14 @@ async def test_custom_agent(
assert data["response"]["speech"]["plain"]["speech"] == "Test response"
assert data["conversation_id"] == "test-conv-id"
assert len(mock_agent.calls) == 1
assert mock_agent.calls[0].text == "Test Text"
assert mock_agent.calls[0].context.user_id == hass_admin_user.id
assert mock_agent.calls[0].conversation_id == "test-conv-id"
assert mock_agent.calls[0].language == "test-language"
assert len(mock_conversation_agent.calls) == 1
assert mock_conversation_agent.calls[0].text == "Test Text"
assert mock_conversation_agent.calls[0].context.user_id == hass_admin_user.id
assert mock_conversation_agent.calls[0].conversation_id == "test-conv-id"
assert mock_conversation_agent.calls[0].language == "test-language"
conversation.async_unset_agent(
hass, hass.config_entries.async_get_entry(mock_agent.agent_id)
hass, hass.config_entries.async_get_entry(mock_conversation_agent.agent_id)
)
@ -1072,7 +1072,7 @@ async def test_agent_id_validator_invalid_agent(hass: HomeAssistant) -> None:
async def test_get_agent_list(
hass: HomeAssistant,
init_components,
mock_agent,
mock_conversation_agent,
mock_agent_support_all,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
@ -1128,14 +1128,20 @@ async def test_get_agent_list(
async def test_get_agent_info(
hass: HomeAssistant, init_components, mock_agent, snapshot: SnapshotAssertion
hass: HomeAssistant,
init_components,
mock_conversation_agent,
snapshot: SnapshotAssertion,
) -> None:
"""Test get agent info."""
agent_info = conversation.async_get_agent_info(hass)
# Test it's the default
assert conversation.async_get_agent_info(hass, "homeassistant") == agent_info
assert conversation.async_get_agent_info(hass, "homeassistant") == snapshot
assert conversation.async_get_agent_info(hass, mock_agent.agent_id) == snapshot
assert (
conversation.async_get_agent_info(hass, mock_conversation_agent.agent_id)
== snapshot
)
assert conversation.async_get_agent_info(hass, "not exist") is None
# Test the name when config entry title is empty

View File

@ -24,10 +24,6 @@ from homeassistant.setup import async_setup_component
from .const import CALL_SERVICE, FIRE_EVENT, REGISTER_CLEARTEXT, RENDER_TEMPLATE, UPDATE
from tests.common import async_capture_events, async_mock_service
from tests.components.conversation.conftest import mock_agent
# To avoid autoflake8 removing the import
mock_agent = mock_agent
@pytest.fixture
@ -1027,14 +1023,18 @@ async def test_reregister_sensor(
async def test_webhook_handle_conversation_process(
hass: HomeAssistant, homeassistant, create_registrations, webhook_client, mock_agent
hass: HomeAssistant,
homeassistant,
create_registrations,
webhook_client,
mock_conversation_agent,
) -> None:
"""Test that we can converse."""
webhook_client.server.app.router._frozen = False
with patch(
"homeassistant.components.conversation.AgentManager.async_get_agent",
return_value=mock_agent,
return_value=mock_conversation_agent,
):
resp = await webhook_client.post(
"/api/webhook/{}".format(create_registrations[1]["webhook_id"]),