Use model list to check anthropic API key (#139307)

Anthropic model list
This commit is contained in:
Denis Shulyaka 2025-03-02 00:28:48 +03:00 committed by GitHub
parent 3588784f1e
commit 1786bb9903
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 20 additions and 31 deletions

View File

@ -12,7 +12,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from .const import DOMAIN, LOGGER from .const import CONF_CHAT_MODEL, DOMAIN, LOGGER, RECOMMENDED_CHAT_MODEL
PLATFORMS = (Platform.CONVERSATION,) PLATFORMS = (Platform.CONVERSATION,)
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
@ -26,12 +26,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: AnthropicConfigEntry) ->
partial(anthropic.AsyncAnthropic, api_key=entry.data[CONF_API_KEY]) partial(anthropic.AsyncAnthropic, api_key=entry.data[CONF_API_KEY])
) )
try: try:
await client.messages.create( model_id = entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
model="claude-3-haiku-20240307", model = await client.models.retrieve(model_id=model_id, timeout=10.0)
max_tokens=1, LOGGER.debug("Anthropic model: %s", model.display_name)
messages=[{"role": "user", "content": "Hi"}],
timeout=10.0,
)
except anthropic.AuthenticationError as err: except anthropic.AuthenticationError as err:
LOGGER.error("Invalid API key: %s", err) LOGGER.error("Invalid API key: %s", err)
return False return False

View File

@ -63,12 +63,7 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
client = await hass.async_add_executor_job( client = await hass.async_add_executor_job(
partial(anthropic.AsyncAnthropic, api_key=data[CONF_API_KEY]) partial(anthropic.AsyncAnthropic, api_key=data[CONF_API_KEY])
) )
await client.messages.create( await client.models.list(timeout=10.0)
model="claude-3-haiku-20240307",
max_tokens=1,
messages=[{"role": "user", "content": "Hi"}],
timeout=10.0,
)
class AnthropicConfigFlow(ConfigFlow, domain=DOMAIN): class AnthropicConfigFlow(ConfigFlow, domain=DOMAIN):

View File

@ -1,7 +1,7 @@
"""Tests helpers.""" """Tests helpers."""
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from unittest.mock import AsyncMock, patch from unittest.mock import patch
import pytest import pytest
@ -43,9 +43,7 @@ async def mock_init_component(
hass: HomeAssistant, mock_config_entry: MockConfigEntry hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> AsyncGenerator[None]: ) -> AsyncGenerator[None]:
"""Initialize integration.""" """Initialize integration."""
with patch( with patch("anthropic.resources.models.AsyncModels.retrieve"):
"anthropic.resources.messages.AsyncMessages.create", new_callable=AsyncMock
):
assert await async_setup_component(hass, "anthropic", {}) assert await async_setup_component(hass, "anthropic", {})
await hass.async_block_till_done() await hass.async_block_till_done()
yield yield

View File

@ -49,7 +49,7 @@ async def test_form(hass: HomeAssistant) -> None:
with ( with (
patch( patch(
"homeassistant.components.anthropic.config_flow.anthropic.resources.messages.AsyncMessages.create", "homeassistant.components.anthropic.config_flow.anthropic.resources.models.AsyncModels.list",
new_callable=AsyncMock, new_callable=AsyncMock,
), ),
patch( patch(
@ -151,7 +151,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
) )
with patch( with patch(
"homeassistant.components.anthropic.config_flow.anthropic.resources.messages.AsyncMessages.create", "homeassistant.components.anthropic.config_flow.anthropic.resources.models.AsyncModels.list",
new_callable=AsyncMock, new_callable=AsyncMock,
side_effect=side_effect, side_effect=side_effect,
): ):

View File

@ -127,9 +127,7 @@ async def test_entity(
CONF_LLM_HASS_API: "assist", CONF_LLM_HASS_API: "assist",
}, },
) )
with patch( with patch("anthropic.resources.models.AsyncModels.retrieve"):
"anthropic.resources.messages.AsyncMessages.create", new_callable=AsyncMock
):
await hass.config_entries.async_reload(mock_config_entry.entry_id) await hass.config_entries.async_reload(mock_config_entry.entry_id)
state = hass.states.get("conversation.claude") state = hass.states.get("conversation.claude")
@ -173,8 +171,11 @@ async def test_template_error(
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
}, },
) )
with patch( with (
patch("anthropic.resources.models.AsyncModels.retrieve"),
patch(
"anthropic.resources.messages.AsyncMessages.create", new_callable=AsyncMock "anthropic.resources.messages.AsyncMessages.create", new_callable=AsyncMock
),
): ):
await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
@ -205,6 +206,7 @@ async def test_template_variables(
}, },
) )
with ( with (
patch("anthropic.resources.models.AsyncModels.retrieve"),
patch( patch(
"anthropic.resources.messages.AsyncMessages.create", new_callable=AsyncMock "anthropic.resources.messages.AsyncMessages.create", new_callable=AsyncMock
) as mock_create, ) as mock_create,
@ -230,8 +232,8 @@ async def test_template_variables(
result.response.speech["plain"]["speech"] result.response.speech["plain"]["speech"]
== "Okay, let me take care of that for you." == "Okay, let me take care of that for you."
) )
assert "The user name is Test User." in mock_create.mock_calls[1][2]["system"] assert "The user name is Test User." in mock_create.call_args.kwargs["system"]
assert "The user id is 12345." in mock_create.mock_calls[1][2]["system"] assert "The user id is 12345." in mock_create.call_args.kwargs["system"]
async def test_conversation_agent( async def test_conversation_agent(
@ -497,9 +499,7 @@ async def test_unknown_hass_api(
assert result == snapshot assert result == snapshot
@patch("anthropic.resources.messages.AsyncMessages.create", new_callable=AsyncMock)
async def test_conversation_id( async def test_conversation_id(
mock_create,
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
mock_init_component, mock_init_component,

View File

@ -1,6 +1,6 @@
"""Tests for the Anthropic integration.""" """Tests for the Anthropic integration."""
from unittest.mock import AsyncMock, patch from unittest.mock import patch
from anthropic import ( from anthropic import (
APIConnectionError, APIConnectionError,
@ -55,8 +55,7 @@ async def test_init_error(
) -> None: ) -> None:
"""Test initialization errors.""" """Test initialization errors."""
with patch( with patch(
"anthropic.resources.messages.AsyncMessages.create", "anthropic.resources.models.AsyncModels.retrieve",
new_callable=AsyncMock,
side_effect=side_effect, side_effect=side_effect,
): ):
assert await async_setup_component(hass, "anthropic", {}) assert await async_setup_component(hass, "anthropic", {})