Set right model in OpenAI conversation (#147575)

This commit is contained in:
Joost Lekkerkerker 2025-06-26 12:49:33 +02:00 committed by GitHub
parent a73dafe097
commit 4244d2f66f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 95 additions and 9 deletions

View File

@ -247,7 +247,7 @@ class OpenAIConversationEntity(
identifiers={(DOMAIN, subentry.subentry_id)},
name=subentry.title,
manufacturer="OpenAI",
model=entry.data.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
model=subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
entry_type=dr.DeviceEntryType.SERVICE,
)
if self.subentry.data.get(CONF_LLM_HASS_API):

View File

@ -1,10 +1,12 @@
"""Tests helpers."""
from typing import Any
from unittest.mock import patch
import pytest
from homeassistant.components.openai_conversation.const import DEFAULT_CONVERSATION_NAME
from homeassistant.config_entries import ConfigSubentryData
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
@ -14,7 +16,15 @@ from tests.common import MockConfigEntry
@pytest.fixture
def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
def mock_subentry_data() -> dict[str, Any]:
"""Mock subentry data."""
return {}
@pytest.fixture
def mock_config_entry(
hass: HomeAssistant, mock_subentry_data: dict[str, Any]
) -> MockConfigEntry:
"""Mock a config entry."""
entry = MockConfigEntry(
title="OpenAI",
@ -24,12 +34,12 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
},
version=2,
subentries_data=[
{
"data": {},
"subentry_type": "conversation",
"title": DEFAULT_CONVERSATION_NAME,
"unique_id": None,
}
ConfigSubentryData(
data=mock_subentry_data,
subentry_type="conversation",
title=DEFAULT_CONVERSATION_NAME,
unique_id=None,
)
],
)
entry.add_to_hass(hass)

View File

@ -0,0 +1,55 @@
# serializer version: 1
# name: test_devices[mock_subentry_data0]
DeviceRegistryEntrySnapshot({
'area_id': None,
'config_entries': <ANY>,
'config_entries_subentries': <ANY>,
'configuration_url': None,
'connections': set({
}),
'disabled_by': None,
'entry_type': <DeviceEntryType.SERVICE: 'service'>,
'hw_version': None,
'id': <ANY>,
'is_new': False,
'labels': set({
}),
'manufacturer': 'OpenAI',
'model': 'gpt-4o-mini',
'model_id': None,
'name': 'OpenAI Conversation',
'name_by_user': None,
'primary_config_entry': <ANY>,
'serial_number': None,
'suggested_area': None,
'sw_version': None,
'via_device_id': None,
})
# ---
# name: test_devices[mock_subentry_data1]
DeviceRegistryEntrySnapshot({
'area_id': None,
'config_entries': <ANY>,
'config_entries_subentries': <ANY>,
'configuration_url': None,
'connections': set({
}),
'disabled_by': None,
'entry_type': <DeviceEntryType.SERVICE: 'service'>,
'hw_version': None,
'id': <ANY>,
'is_new': False,
'labels': set({
}),
'manufacturer': 'OpenAI',
'model': 'gpt-1o',
'model_id': None,
'name': 'OpenAI Conversation',
'name_by_user': None,
'primary_config_entry': <ANY>,
'serial_number': None,
'suggested_area': None,
'sw_version': None,
'via_device_id': None,
})
# ---

View File

@ -13,8 +13,10 @@ from openai.types.image import Image
from openai.types.images_response import ImagesResponse
from openai.types.responses import Response, ResponseOutputMessage, ResponseOutputText
import pytest
from syrupy.assertion import SnapshotAssertion
from syrupy.filters import props
from homeassistant.components.openai_conversation import CONF_FILENAMES
from homeassistant.components.openai_conversation import CONF_CHAT_MODEL, CONF_FILENAMES
from homeassistant.components.openai_conversation.const import DOMAIN
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
@ -806,3 +808,22 @@ async def test_migration_from_v1_to_v2_with_same_keys(
identifiers={(DOMAIN, subentry.subentry_id)}
)
assert dev is not None
@pytest.mark.parametrize("mock_subentry_data", [{}, {CONF_CHAT_MODEL: "gpt-1o"}])
async def test_devices(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
device_registry: dr.DeviceRegistry,
snapshot: SnapshotAssertion,
) -> None:
"""Assert exception when invalid config entry is provided."""
devices = dr.async_entries_for_config_entry(
device_registry, mock_config_entry.entry_id
)
assert len(devices) == 1
device = devices[0]
assert device == snapshot(exclude=props("identifiers"))
subentry = next(iter(mock_config_entry.subentries.values()))
assert device.identifiers == {(DOMAIN, subentry.subentry_id)}