Refactor conversation mock_agent (#114428)

* Refactor conversation mock_agent

* Address review comments
This commit is contained in:
Sid 2024-03-29 14:38:58 +01:00 committed by GitHub
parent 8d6d70d6b5
commit dc557fca1e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 52 additions and 28 deletions

View File

@ -10,6 +10,7 @@ from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from tests.common import MockToggleEntity from tests.common import MockToggleEntity
from tests.components.conversation import MockAgent
if TYPE_CHECKING: if TYPE_CHECKING:
from tests.components.device_tracker.common import MockScanner from tests.components.device_tracker.common import MockScanner
@ -104,6 +105,16 @@ def tts_mutagen_mock_fixture():
yield from tts_mutagen_mock_fixture_helper() yield from tts_mutagen_mock_fixture_helper()
@pytest.fixture(name="mock_conversation_agent")
def mock_conversation_agent_fixture(hass: HomeAssistant) -> MockAgent:
"""Mock a conversation agent."""
from tests.components.conversation.common import (
mock_conversation_agent_fixture_helper,
)
return mock_conversation_agent_fixture_helper(hass)
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def prevent_ffmpeg_subprocess() -> Generator[None, None, None]: def prevent_ffmpeg_subprocess() -> Generator[None, None, None]:
"""Prevent ffmpeg from creating a subprocess.""" """Prevent ffmpeg from creating a subprocess."""

View File

@ -0,0 +1,17 @@
"""Provide common tests tools for conversation."""
from homeassistant.components import conversation
from homeassistant.core import HomeAssistant
from . import MockAgent
from tests.common import MockConfigEntry
def mock_conversation_agent_fixture_helper(hass: HomeAssistant) -> MockAgent:
"""Mock agent."""
entry = MockConfigEntry(entry_id="mock-entry")
entry.add_to_hass(hass)
agent = MockAgent(entry.entry_id, ["smurfish"])
conversation.async_set_agent(hass, entry, agent)
return agent

View File

@ -13,16 +13,6 @@ from . import MockAgent
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@pytest.fixture
def mock_agent(hass):
"""Mock agent."""
entry = MockConfigEntry(entry_id="mock-entry")
entry.add_to_hass(hass)
agent = MockAgent(entry.entry_id, ["smurfish"])
conversation.async_set_agent(hass, entry, agent)
return agent
@pytest.fixture @pytest.fixture
def mock_agent_support_all(hass): def mock_agent_support_all(hass):
"""Mock agent that supports all languages.""" """Mock agent that supports all languages."""

View File

@ -94,7 +94,7 @@ async def test_http_processing_intent_target_ha_agent(
init_components, init_components,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
hass_admin_user: MockUser, hass_admin_user: MockUser,
mock_agent, mock_conversation_agent,
entity_registry: er.EntityRegistry, entity_registry: er.EntityRegistry,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
@ -658,7 +658,7 @@ async def test_custom_agent(
hass: HomeAssistant, hass: HomeAssistant,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
hass_admin_user: MockUser, hass_admin_user: MockUser,
mock_agent, mock_conversation_agent,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test a custom conversation agent.""" """Test a custom conversation agent."""
@ -672,7 +672,7 @@ async def test_custom_agent(
"text": "Test Text", "text": "Test Text",
"conversation_id": "test-conv-id", "conversation_id": "test-conv-id",
"language": "test-language", "language": "test-language",
"agent_id": mock_agent.agent_id, "agent_id": mock_conversation_agent.agent_id,
} }
resp = await client.post("/api/conversation/process", json=data) resp = await client.post("/api/conversation/process", json=data)
@ -683,14 +683,14 @@ async def test_custom_agent(
assert data["response"]["speech"]["plain"]["speech"] == "Test response" assert data["response"]["speech"]["plain"]["speech"] == "Test response"
assert data["conversation_id"] == "test-conv-id" assert data["conversation_id"] == "test-conv-id"
assert len(mock_agent.calls) == 1 assert len(mock_conversation_agent.calls) == 1
assert mock_agent.calls[0].text == "Test Text" assert mock_conversation_agent.calls[0].text == "Test Text"
assert mock_agent.calls[0].context.user_id == hass_admin_user.id assert mock_conversation_agent.calls[0].context.user_id == hass_admin_user.id
assert mock_agent.calls[0].conversation_id == "test-conv-id" assert mock_conversation_agent.calls[0].conversation_id == "test-conv-id"
assert mock_agent.calls[0].language == "test-language" assert mock_conversation_agent.calls[0].language == "test-language"
conversation.async_unset_agent( conversation.async_unset_agent(
hass, hass.config_entries.async_get_entry(mock_agent.agent_id) hass, hass.config_entries.async_get_entry(mock_conversation_agent.agent_id)
) )
@ -1072,7 +1072,7 @@ async def test_agent_id_validator_invalid_agent(hass: HomeAssistant) -> None:
async def test_get_agent_list( async def test_get_agent_list(
hass: HomeAssistant, hass: HomeAssistant,
init_components, init_components,
mock_agent, mock_conversation_agent,
mock_agent_support_all, mock_agent_support_all,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
@ -1128,14 +1128,20 @@ async def test_get_agent_list(
async def test_get_agent_info( async def test_get_agent_info(
hass: HomeAssistant, init_components, mock_agent, snapshot: SnapshotAssertion hass: HomeAssistant,
init_components,
mock_conversation_agent,
snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test get agent info.""" """Test get agent info."""
agent_info = conversation.async_get_agent_info(hass) agent_info = conversation.async_get_agent_info(hass)
# Test it's the default # Test it's the default
assert conversation.async_get_agent_info(hass, "homeassistant") == agent_info assert conversation.async_get_agent_info(hass, "homeassistant") == agent_info
assert conversation.async_get_agent_info(hass, "homeassistant") == snapshot assert conversation.async_get_agent_info(hass, "homeassistant") == snapshot
assert conversation.async_get_agent_info(hass, mock_agent.agent_id) == snapshot assert (
conversation.async_get_agent_info(hass, mock_conversation_agent.agent_id)
== snapshot
)
assert conversation.async_get_agent_info(hass, "not exist") is None assert conversation.async_get_agent_info(hass, "not exist") is None
# Test the name when config entry title is empty # Test the name when config entry title is empty

View File

@ -24,10 +24,6 @@ from homeassistant.setup import async_setup_component
from .const import CALL_SERVICE, FIRE_EVENT, REGISTER_CLEARTEXT, RENDER_TEMPLATE, UPDATE from .const import CALL_SERVICE, FIRE_EVENT, REGISTER_CLEARTEXT, RENDER_TEMPLATE, UPDATE
from tests.common import async_capture_events, async_mock_service from tests.common import async_capture_events, async_mock_service
from tests.components.conversation.conftest import mock_agent
# To avoid autoflake8 removing the import
mock_agent = mock_agent
@pytest.fixture @pytest.fixture
@ -1027,14 +1023,18 @@ async def test_reregister_sensor(
async def test_webhook_handle_conversation_process( async def test_webhook_handle_conversation_process(
hass: HomeAssistant, homeassistant, create_registrations, webhook_client, mock_agent hass: HomeAssistant,
homeassistant,
create_registrations,
webhook_client,
mock_conversation_agent,
) -> None: ) -> None:
"""Test that we can converse.""" """Test that we can converse."""
webhook_client.server.app.router._frozen = False webhook_client.server.app.router._frozen = False
with patch( with patch(
"homeassistant.components.conversation.AgentManager.async_get_agent", "homeassistant.components.conversation.AgentManager.async_get_agent",
return_value=mock_agent, return_value=mock_conversation_agent,
): ):
resp = await webhook_client.post( resp = await webhook_client.post(
"/api/webhook/{}".format(create_registrations[1]["webhook_id"]), "/api/webhook/{}".format(create_registrations[1]["webhook_id"]),