diff --git a/tests/components/google_generative_ai_conversation/conftest.py b/tests/components/google_generative_ai_conversation/conftest.py index 0d222c78472..f499f18bc15 100644 --- a/tests/components/google_generative_ai_conversation/conftest.py +++ b/tests/components/google_generative_ai_conversation/conftest.py @@ -1,6 +1,7 @@ """Tests helpers.""" -from unittest.mock import Mock, patch +from collections.abc import Generator +from unittest.mock import AsyncMock, Mock, patch import pytest @@ -77,3 +78,22 @@ async def mock_init_component( async def setup_ha(hass: HomeAssistant) -> None: """Set up Home Assistant.""" assert await async_setup_component(hass, "homeassistant", {}) + + +@pytest.fixture +def mock_send_message_stream() -> Generator[AsyncMock]: + """Mock stream response.""" + + async def mock_generator(stream): + for value in stream: + yield value + + with patch( + "google.genai.chats.AsyncChat.send_message_stream", + AsyncMock(), + ) as mock_send_message_stream: + mock_send_message_stream.side_effect = lambda **kwargs: mock_generator( + mock_send_message_stream.return_value.pop(0) + ) + + yield mock_send_message_stream diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index a55a86b67c9..92aa6f08d42 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -1,6 +1,5 @@ """Tests for the Google Generative AI Conversation integration conversation platform.""" -from collections.abc import Generator from unittest.mock import AsyncMock, patch from freezegun import freeze_time @@ -41,25 +40,6 @@ def mock_ulid_tools(): yield -@pytest.fixture -def mock_send_message_stream() -> Generator[AsyncMock]: - """Mock stream response.""" - - async def mock_generator(stream): - for value in stream: - yield value - - with patch( - "google.genai.chats.AsyncChat.send_message_stream", - AsyncMock(), - ) as mock_send_message_stream: - mock_send_message_stream.side_effect = lambda **kwargs: mock_generator( - mock_send_message_stream.return_value.pop(0) - ) - - yield mock_send_message_stream - - @pytest.mark.parametrize( ("error"), [