Anthropic web search support (#153753)

This commit is contained in:
Denis Shulyaka
2025-10-10 17:21:21 +03:00
committed by GitHub
parent f49299b009
commit 517124dfbe
8 changed files with 1021 additions and 40 deletions

View File

@@ -4,12 +4,15 @@ from __future__ import annotations
from collections.abc import Mapping
from functools import partial
import json
import logging
from typing import Any, cast
import anthropic
import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components.zone import ENTITY_ID_HOME
from homeassistant.config_entries import (
ConfigEntry,
ConfigEntryState,
@@ -18,7 +21,13 @@ from homeassistant.config_entries import (
ConfigSubentryFlow,
SubentryFlowResult,
)
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_NAME
from homeassistant.const import (
ATTR_LATITUDE,
ATTR_LONGITUDE,
CONF_API_KEY,
CONF_LLM_HASS_API,
CONF_NAME,
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import llm
from homeassistant.helpers.selector import (
@@ -37,12 +46,23 @@ from .const import (
CONF_RECOMMENDED,
CONF_TEMPERATURE,
CONF_THINKING_BUDGET,
CONF_WEB_SEARCH,
CONF_WEB_SEARCH_CITY,
CONF_WEB_SEARCH_COUNTRY,
CONF_WEB_SEARCH_MAX_USES,
CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION,
DEFAULT_CONVERSATION_NAME,
DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_THINKING_BUDGET,
RECOMMENDED_WEB_SEARCH,
RECOMMENDED_WEB_SEARCH_MAX_USES,
RECOMMENDED_WEB_SEARCH_USER_LOCATION,
WEB_SEARCH_UNSUPPORTED_MODELS,
)
_LOGGER = logging.getLogger(__name__)
@@ -168,6 +188,14 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
CONF_THINKING_BUDGET, RECOMMENDED_THINKING_BUDGET
) >= user_input.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS):
errors[CONF_THINKING_BUDGET] = "thinking_budget_too_large"
if user_input.get(CONF_WEB_SEARCH, RECOMMENDED_WEB_SEARCH):
model = user_input.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
if model.startswith(tuple(WEB_SEARCH_UNSUPPORTED_MODELS)):
errors[CONF_WEB_SEARCH] = "web_search_unsupported_model"
elif user_input.get(
CONF_WEB_SEARCH_USER_LOCATION, RECOMMENDED_WEB_SEARCH_USER_LOCATION
):
user_input.update(await self._get_location_data())
if not errors:
if self._is_new:
@@ -215,6 +243,68 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
errors=errors or None,
)
async def _get_location_data(self) -> dict[str, str]:
    """Get approximate location data of the user.

    Asks the Anthropic model to reverse-geocode the coordinates of the
    ``zone.home`` entity into a city/region, then adds the country and
    time zone from the Home Assistant core configuration.

    Returns a dict keyed by the ``CONF_WEB_SEARCH_*`` option constants;
    the model-derived city/region keys are omitted when ``zone.home`` is
    missing or the model returns neither field.
    """
    location_data: dict[str, str] = {}
    zone_home = self.hass.states.get(ENTITY_ID_HOME)
    if zone_home is not None:
        # Client construction is offloaded to the executor; presumably the
        # constructor performs blocking I/O -- TODO confirm against the SDK.
        client = await self.hass.async_add_executor_job(
            partial(
                anthropic.AsyncAnthropic,
                api_key=self._get_entry().data[CONF_API_KEY],
            )
        )
        # This schema doubles as the JSON contract embedded in the prompt
        # (via convert() below) and as the validator for the model's reply.
        location_schema = vol.Schema(
            {
                vol.Optional(
                    CONF_WEB_SEARCH_CITY,
                    description="Free text input for the city, e.g. `San Francisco`",
                ): str,
                vol.Optional(
                    CONF_WEB_SEARCH_REGION,
                    description="Free text input for the region, e.g. `California`",
                ): str,
            }
        )
        response = await client.messages.create(
            model=RECOMMENDED_CHAT_MODEL,
            messages=[
                {
                    "role": "user",
                    "content": "Where are the following coordinates located: "
                    f"({zone_home.attributes[ATTR_LATITUDE]},"
                    f" {zone_home.attributes[ATTR_LONGITUDE]})? Please respond "
                    "only with a JSON object using the following schema:\n"
                    f"{convert(location_schema)}",
                },
                {
                    "role": "assistant",
                    "content": "{",  # hints the model to skip any preamble
                },
            ],
            max_tokens=RECOMMENDED_MAX_TOKENS,
        )
        _LOGGER.debug("Model response: %s", response.content)
        # Re-attach the "{" prefill before parsing; non-text blocks in the
        # response are ignored. An empty parse result validates as {}.
        location_data = location_schema(
            json.loads(
                "{"
                + "".join(
                    block.text
                    for block in response.content
                    if isinstance(block, anthropic.types.TextBlock)
                )
            )
            or {}
        )

    # Country and time zone come from the HA core config, not the model.
    if self.hass.config.country:
        location_data[CONF_WEB_SEARCH_COUNTRY] = self.hass.config.country
    location_data[CONF_WEB_SEARCH_TIMEZONE] = self.hass.config.time_zone

    _LOGGER.debug("Location data: %s", location_data)

    return location_data
async_step_user = async_step_set_options
async_step_reconfigure = async_step_set_options
@@ -273,6 +363,18 @@ def anthropic_config_option_schema(
CONF_THINKING_BUDGET,
default=RECOMMENDED_THINKING_BUDGET,
): int,
vol.Optional(
CONF_WEB_SEARCH,
default=RECOMMENDED_WEB_SEARCH,
): bool,
vol.Optional(
CONF_WEB_SEARCH_MAX_USES,
default=RECOMMENDED_WEB_SEARCH_MAX_USES,
): int,
vol.Optional(
CONF_WEB_SEARCH_USER_LOCATION,
default=RECOMMENDED_WEB_SEARCH_USER_LOCATION,
): bool,
}
)
return schema

View File

@@ -18,9 +18,26 @@ RECOMMENDED_TEMPERATURE = 1.0
# Extended-thinking options.
CONF_THINKING_BUDGET = "thinking_budget"
RECOMMENDED_THINKING_BUDGET = 0
# Minimum budget accepted for extended thinking; budgets below this
# leave thinking disabled.
MIN_THINKING_BUDGET = 1024

# Web-search tool options (keys stored in the subentry data).
CONF_WEB_SEARCH = "web_search"
RECOMMENDED_WEB_SEARCH = False
CONF_WEB_SEARCH_USER_LOCATION = "user_location"
RECOMMENDED_WEB_SEARCH_USER_LOCATION = False
CONF_WEB_SEARCH_MAX_USES = "web_search_max_uses"
RECOMMENDED_WEB_SEARCH_MAX_USES = 5
# Approximate user-location fields passed to the web search tool.
CONF_WEB_SEARCH_CITY = "city"
CONF_WEB_SEARCH_REGION = "region"
CONF_WEB_SEARCH_COUNTRY = "country"
CONF_WEB_SEARCH_TIMEZONE = "timezone"

# Model-name prefixes that do not support extended thinking.
NON_THINKING_MODELS = [
    "claude-3-5",  # Both sonnet and haiku
    "claude-3-opus",
    "claude-3-haiku",
]

# Model-name prefixes that do not support the web search tool.
WEB_SEARCH_UNSUPPORTED_MODELS = [
    "claude-3-haiku",
    "claude-3-opus",
    "claude-3-5-sonnet-20240620",
    "claude-3-5-sonnet-20241022",
]

View File

@@ -1,12 +1,17 @@
"""Base entity for Anthropic."""
from collections.abc import AsyncGenerator, Callable, Iterable
from dataclasses import dataclass, field
import json
from typing import Any
import anthropic
from anthropic import AsyncStream
from anthropic.types import (
CitationsDelta,
CitationsWebSearchResultLocation,
CitationWebSearchResultLocationParam,
ContentBlockParam,
InputJSONDelta,
MessageDeltaUsage,
MessageParam,
@@ -16,11 +21,16 @@ from anthropic.types import (
RawContentBlockStopEvent,
RawMessageDeltaEvent,
RawMessageStartEvent,
RawMessageStopEvent,
RedactedThinkingBlock,
RedactedThinkingBlockParam,
ServerToolUseBlock,
ServerToolUseBlockParam,
SignatureDelta,
TextBlock,
TextBlockParam,
TextCitation,
TextCitationParam,
TextDelta,
ThinkingBlock,
ThinkingBlockParam,
@@ -29,9 +39,15 @@ from anthropic.types import (
ThinkingDelta,
ToolParam,
ToolResultBlockParam,
ToolUnionParam,
ToolUseBlock,
ToolUseBlockParam,
Usage,
WebSearchTool20250305Param,
WebSearchToolRequestErrorParam,
WebSearchToolResultBlock,
WebSearchToolResultBlockParam,
WebSearchToolResultError,
)
from anthropic.types.message_create_params import MessageCreateParamsStreaming
from voluptuous_openapi import convert
@@ -48,6 +64,13 @@ from .const import (
CONF_MAX_TOKENS,
CONF_TEMPERATURE,
CONF_THINKING_BUDGET,
CONF_WEB_SEARCH,
CONF_WEB_SEARCH_CITY,
CONF_WEB_SEARCH_COUNTRY,
CONF_WEB_SEARCH_MAX_USES,
CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION,
DOMAIN,
LOGGER,
MIN_THINKING_BUDGET,
@@ -73,6 +96,69 @@ def _format_tool(
)
@dataclass(slots=True)
class CitationDetails:
    """Citation details for a content part.

    Describes one contiguous span of assistant text and the citations
    that apply to it; spans are concatenated to rebuild the full text.
    """

    index: int = 0
    """Start position of the text."""

    length: int = 0
    """Length of the relevant data."""

    citations: list[TextCitationParam] = field(default_factory=list)
    """Citations for the content part."""
@dataclass(slots=True)
class ContentDetails:
    """Native data for AssistantContent.

    Accumulates per-span citation bookkeeping while streaming a response,
    so the citations can be re-attached to the right text slices when the
    chat log is converted back into API messages.
    """

    citation_details: list[CitationDetails] = field(default_factory=list)

    def has_content(self) -> bool:
        """Check if there is any content."""
        return any(detail.length > 0 for detail in self.citation_details)

    def has_citations(self) -> bool:
        """Check if there are any citations."""
        return any(detail.citations for detail in self.citation_details)

    def add_citation_detail(self) -> None:
        """Add a new citation detail.

        Starts a new span only when the current one already covers some
        text; the new span's index continues where the last one ended.
        """
        if not self.citation_details or self.citation_details[-1].length > 0:
            self.citation_details.append(
                CitationDetails(
                    index=self.citation_details[-1].index
                    + self.citation_details[-1].length
                    if self.citation_details
                    else 0
                )
            )

    def add_citation(self, citation: TextCitation) -> None:
        """Add a citation to the current detail."""
        if not self.citation_details:
            self.citation_details.append(CitationDetails())
        citation_param: TextCitationParam | None = None
        if isinstance(citation, CitationsWebSearchResultLocation):
            # Convert the streamed citation into its request-param form so
            # it can be sent back verbatim in a follow-up message.
            citation_param = CitationWebSearchResultLocationParam(
                type="web_search_result_location",
                title=citation.title,
                url=citation.url,
                cited_text=citation.cited_text,
                encrypted_index=citation.encrypted_index,
            )
        if citation_param:
            # Citation types without a conversion above are dropped.
            self.citation_details[-1].citations.append(citation_param)

    def delete_empty(self) -> None:
        """Delete empty citation details."""
        self.citation_details = [
            detail for detail in self.citation_details if detail.citations
        ]
def _convert_content(
chat_content: Iterable[conversation.Content],
) -> list[MessageParam]:
@@ -81,15 +167,31 @@ def _convert_content(
for content in chat_content:
if isinstance(content, conversation.ToolResultContent):
tool_result_block = ToolResultBlockParam(
type="tool_result",
tool_use_id=content.tool_call_id,
content=json.dumps(content.tool_result),
)
if not messages or messages[-1]["role"] != "user":
if content.tool_name == "web_search":
tool_result_block: ContentBlockParam = WebSearchToolResultBlockParam(
type="web_search_tool_result",
tool_use_id=content.tool_call_id,
content=content.tool_result["content"]
if "content" in content.tool_result
else WebSearchToolRequestErrorParam(
type="web_search_tool_result_error",
error_code=content.tool_result.get("error_code", "unavailable"), # type: ignore[typeddict-item]
),
)
external_tool = True
else:
tool_result_block = ToolResultBlockParam(
type="tool_result",
tool_use_id=content.tool_call_id,
content=json.dumps(content.tool_result),
)
external_tool = False
if not messages or messages[-1]["role"] != (
"assistant" if external_tool else "user"
):
messages.append(
MessageParam(
role="user",
role="assistant" if external_tool else "user",
content=[tool_result_block],
)
)
@@ -151,13 +253,56 @@ def _convert_content(
redacted_thinking_block
)
if content.content:
messages[-1]["content"].append( # type: ignore[union-attr]
TextBlockParam(type="text", text=content.content)
)
current_index = 0
for detail in (
content.native.citation_details
if isinstance(content.native, ContentDetails)
else [CitationDetails(length=len(content.content))]
):
if detail.index > current_index:
# Add text block for any text without citations
messages[-1]["content"].append( # type: ignore[union-attr]
TextBlockParam(
type="text",
text=content.content[current_index : detail.index],
)
)
messages[-1]["content"].append( # type: ignore[union-attr]
TextBlockParam(
type="text",
text=content.content[
detail.index : detail.index + detail.length
],
citations=detail.citations,
)
if detail.citations
else TextBlockParam(
type="text",
text=content.content[
detail.index : detail.index + detail.length
],
)
)
current_index = detail.index + detail.length
if current_index < len(content.content):
# Add text block for any remaining text without citations
messages[-1]["content"].append( # type: ignore[union-attr]
TextBlockParam(
type="text",
text=content.content[current_index:],
)
)
if content.tool_calls:
messages[-1]["content"].extend( # type: ignore[union-attr]
[
ToolUseBlockParam(
ServerToolUseBlockParam(
type="server_tool_use",
id=tool_call.id,
name="web_search",
input=tool_call.tool_args,
)
if tool_call.external and tool_call.tool_name == "web_search"
else ToolUseBlockParam(
type="tool_use",
id=tool_call.id,
name=tool_call.tool_name,
@@ -173,10 +318,12 @@ def _convert_content(
return messages
async def _transform_stream(
async def _transform_stream( # noqa: C901 - This is complex, but better to have it in one place
chat_log: conversation.ChatLog,
stream: AsyncStream[MessageStreamEvent],
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
) -> AsyncGenerator[
conversation.AssistantContentDeltaDict | conversation.ToolResultContentDeltaDict
]:
"""Transform the response stream into HA format.
A typical stream of responses might look something like the following:
@@ -209,11 +356,13 @@ async def _transform_stream(
if stream is None:
raise TypeError("Expected a stream of messages")
current_tool_block: ToolUseBlockParam | None = None
current_tool_block: ToolUseBlockParam | ServerToolUseBlockParam | None = None
current_tool_args: str
content_details = ContentDetails()
content_details.add_citation_detail()
input_usage: Usage | None = None
has_content = False
has_native = False
first_block: bool
async for response in stream:
LOGGER.debug("Received response: %s", response)
@@ -222,6 +371,7 @@ async def _transform_stream(
if response.message.role != "assistant":
raise ValueError("Unexpected message role")
input_usage = response.message.usage
first_block = True
elif isinstance(response, RawContentBlockStartEvent):
if isinstance(response.content_block, ToolUseBlock):
current_tool_block = ToolUseBlockParam(
@@ -232,17 +382,37 @@ async def _transform_stream(
)
current_tool_args = ""
elif isinstance(response.content_block, TextBlock):
if has_content:
if ( # Do not start a new assistant content just for citations, concatenate consecutive blocks with citations instead.
first_block
or (
not content_details.has_citations()
and response.content_block.citations is None
and content_details.has_content()
)
):
if content_details.has_citations():
content_details.delete_empty()
yield {"native": content_details}
content_details = ContentDetails()
yield {"role": "assistant"}
has_native = False
has_content = True
first_block = False
content_details.add_citation_detail()
if response.content_block.text:
content_details.citation_details[-1].length += len(
response.content_block.text
)
yield {"content": response.content_block.text}
elif isinstance(response.content_block, ThinkingBlock):
if has_native:
if first_block or has_native:
if content_details.has_citations():
content_details.delete_empty()
yield {"native": content_details}
content_details = ContentDetails()
content_details.add_citation_detail()
yield {"role": "assistant"}
has_native = False
has_content = False
first_block = False
elif isinstance(response.content_block, RedactedThinkingBlock):
LOGGER.debug(
"Some of Claudes internal reasoning has been automatically "
@@ -250,15 +420,60 @@ async def _transform_stream(
"responses"
)
if has_native:
if content_details.has_citations():
content_details.delete_empty()
yield {"native": content_details}
content_details = ContentDetails()
content_details.add_citation_detail()
yield {"role": "assistant"}
has_native = False
has_content = False
first_block = False
yield {"native": response.content_block}
has_native = True
elif isinstance(response.content_block, ServerToolUseBlock):
current_tool_block = ServerToolUseBlockParam(
type="server_tool_use",
id=response.content_block.id,
name=response.content_block.name,
input="",
)
current_tool_args = ""
elif isinstance(response.content_block, WebSearchToolResultBlock):
if content_details.has_citations():
content_details.delete_empty()
yield {"native": content_details}
content_details = ContentDetails()
content_details.add_citation_detail()
yield {
"role": "tool_result",
"tool_call_id": response.content_block.tool_use_id,
"tool_name": "web_search",
"tool_result": {
"type": "web_search_tool_result_error",
"error_code": response.content_block.content.error_code,
}
if isinstance(
response.content_block.content, WebSearchToolResultError
)
else {
"content": [
{
"type": "web_search_result",
"encrypted_content": block.encrypted_content,
"page_age": block.page_age,
"title": block.title,
"url": block.url,
}
for block in response.content_block.content
]
},
}
first_block = True
elif isinstance(response, RawContentBlockDeltaEvent):
if isinstance(response.delta, InputJSONDelta):
current_tool_args += response.delta.partial_json
elif isinstance(response.delta, TextDelta):
content_details.citation_details[-1].length += len(response.delta.text)
yield {"content": response.delta.text}
elif isinstance(response.delta, ThinkingDelta):
yield {"thinking_content": response.delta.thinking}
@@ -271,6 +486,8 @@ async def _transform_stream(
)
}
has_native = True
elif isinstance(response.delta, CitationsDelta):
content_details.add_citation(response.delta.citation)
elif isinstance(response, RawContentBlockStopEvent):
if current_tool_block is not None:
tool_args = json.loads(current_tool_args) if current_tool_args else {}
@@ -281,6 +498,7 @@ async def _transform_stream(
id=current_tool_block["id"],
tool_name=current_tool_block["name"],
tool_args=tool_args,
external=current_tool_block["type"] == "server_tool_use",
)
]
}
@@ -290,6 +508,12 @@ async def _transform_stream(
chat_log.async_trace(_create_token_stats(input_usage, usage))
if response.delta.stop_reason == "refusal":
raise HomeAssistantError("Potential policy violation detected")
elif isinstance(response, RawMessageStopEvent):
if content_details.has_citations():
content_details.delete_empty()
yield {"native": content_details}
content_details = ContentDetails()
content_details.add_citation_detail()
def _create_token_stats(
@@ -337,21 +561,11 @@ class AnthropicBaseLLMEntity(Entity):
"""Generate an answer for the chat log."""
options = self.subentry.data
tools: list[ToolParam] | None = None
if chat_log.llm_api:
tools = [
_format_tool(tool, chat_log.llm_api.custom_serializer)
for tool in chat_log.llm_api.tools
]
system = chat_log.content[0]
if not isinstance(system, conversation.SystemContent):
raise TypeError("First message must be a system message")
messages = _convert_content(chat_log.content[1:])
client = self.entry.runtime_data
thinking_budget = options.get(CONF_THINKING_BUDGET, RECOMMENDED_THINKING_BUDGET)
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
model_args = MessageCreateParamsStreaming(
@@ -361,8 +575,8 @@ class AnthropicBaseLLMEntity(Entity):
system=system.content,
stream=True,
)
if tools:
model_args["tools"] = tools
thinking_budget = options.get(CONF_THINKING_BUDGET, RECOMMENDED_THINKING_BUDGET)
if (
not model.startswith(tuple(NON_THINKING_MODELS))
and thinking_budget >= MIN_THINKING_BUDGET
@@ -376,6 +590,34 @@ class AnthropicBaseLLMEntity(Entity):
CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
)
tools: list[ToolUnionParam] = []
if chat_log.llm_api:
tools = [
_format_tool(tool, chat_log.llm_api.custom_serializer)
for tool in chat_log.llm_api.tools
]
if options.get(CONF_WEB_SEARCH):
web_search = WebSearchTool20250305Param(
name="web_search",
type="web_search_20250305",
max_uses=options.get(CONF_WEB_SEARCH_MAX_USES),
)
if options.get(CONF_WEB_SEARCH_USER_LOCATION):
web_search["user_location"] = {
"type": "approximate",
"city": options.get(CONF_WEB_SEARCH_CITY, ""),
"region": options.get(CONF_WEB_SEARCH_REGION, ""),
"country": options.get(CONF_WEB_SEARCH_COUNTRY, ""),
"timezone": options.get(CONF_WEB_SEARCH_TIMEZONE, ""),
}
tools.append(web_search)
if tools:
model_args["tools"] = tools
client = self.entry.runtime_data
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
try:

View File

@@ -35,11 +35,17 @@
"temperature": "Temperature",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
"recommended": "Recommended model settings",
"thinking_budget_tokens": "Thinking budget"
"thinking_budget": "Thinking budget",
"web_search": "Enable web search",
"web_search_max_uses": "Maximum web searches",
"user_location": "Include home location"
},
"data_description": {
"prompt": "Instruct how the LLM should respond. This can be a template.",
"thinking_budget_tokens": "The number of tokens the model can use to think about the response out of the total maximum number of tokens. Set to 1024 or greater to enable extended thinking."
"thinking_budget": "The number of tokens the model can use to think about the response out of the total maximum number of tokens. Set to 1024 or greater to enable extended thinking.",
"web_search": "The web search tool gives Claude direct access to real-time web content, allowing it to answer questions with up-to-date information beyond its knowledge cutoff",
"web_search_max_uses": "Limit the number of searches performed per response",
"user_location": "Localize search results based on home location"
}
}
},
@@ -48,7 +54,8 @@
"entry_not_loaded": "Cannot add things while the configuration is disabled."
},
"error": {
"thinking_budget_too_large": "Maximum tokens must be greater than the thinking budget."
"thinking_budget_too_large": "Maximum tokens must be greater than the thinking budget.",
"web_search_unsupported_model": "Web search is not supported by the selected model. Please choose a compatible model or disable web search."
}
}
}

View File

@@ -6,7 +6,16 @@ from unittest.mock import patch
import pytest
from homeassistant.components.anthropic import CONF_CHAT_MODEL
from homeassistant.components.anthropic.const import DEFAULT_CONVERSATION_NAME
from homeassistant.components.anthropic.const import (
CONF_WEB_SEARCH,
CONF_WEB_SEARCH_CITY,
CONF_WEB_SEARCH_COUNTRY,
CONF_WEB_SEARCH_MAX_USES,
CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION,
DEFAULT_CONVERSATION_NAME,
)
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
@@ -55,7 +64,7 @@ def mock_config_entry_with_assist(
def mock_config_entry_with_extended_thinking(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> MockConfigEntry:
"""Mock a config entry with assist."""
"""Mock a config entry with extended thinking."""
hass.config_entries.async_update_subentry(
mock_config_entry,
next(iter(mock_config_entry.subentries.values())),
@@ -67,6 +76,29 @@ def mock_config_entry_with_extended_thinking(
return mock_config_entry
@pytest.fixture
def mock_config_entry_with_web_search(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> MockConfigEntry:
    """Mock a config entry with server tools enabled."""
    # Turn on the web search tool with a full approximate user location so
    # tests can exercise server tool use, search results, and citations.
    hass.config_entries.async_update_subentry(
        mock_config_entry,
        next(iter(mock_config_entry.subentries.values())),
        data={
            CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
            CONF_CHAT_MODEL: "claude-sonnet-4-5",
            CONF_WEB_SEARCH: True,
            CONF_WEB_SEARCH_MAX_USES: 5,
            CONF_WEB_SEARCH_USER_LOCATION: True,
            CONF_WEB_SEARCH_CITY: "San Francisco",
            CONF_WEB_SEARCH_REGION: "California",
            CONF_WEB_SEARCH_COUNTRY: "US",
            CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
        },
    )
    return mock_config_entry
@pytest.fixture
async def mock_init_component(
hass: HomeAssistant, mock_config_entry: MockConfigEntry

View File

@@ -338,6 +338,123 @@
}),
])
# ---
# name: test_history_conversion[content5]
list([
dict({
'content': "What's on the news today?",
'role': 'user',
}),
dict({
'content': list([
dict({
'signature': 'ErU/V+ayA==',
'thinking': "The user is asking about today's news, which requires current, real-time information. This is clearly something that requires recent information beyond my knowledge cutoff. I should use the web_search tool to find today's news.",
'type': 'thinking',
}),
dict({
'text': "To get today's news, I'll perform a web search",
'type': 'text',
}),
dict({
'id': 'srvtoolu_12345ABC',
'input': dict({
'query': "today's news",
}),
'name': 'web_search',
'type': 'server_tool_use',
}),
dict({
'content': list([
dict({
'encrypted_content': 'ABCDEFG',
'page_age': '2 days ago',
'title': "Today's News - Example.com",
'type': 'web_search_result',
'url': 'https://www.example.com/todays-news',
}),
dict({
'encrypted_content': 'ABCDEFG',
'page_age': None,
'title': 'Breaking News - NewsSite.com',
'type': 'web_search_result',
'url': 'https://www.newssite.com/breaking-news',
}),
]),
'tool_use_id': 'srvtoolu_12345ABC',
'type': 'web_search_tool_result',
}),
dict({
'text': '''
Here's what I found on the web about today's news:
1.
''',
'type': 'text',
}),
dict({
'citations': list([
dict({
'cited_text': 'This release iterates on some of the features we introduced in the last couple of releases, but also...',
'encrypted_index': 'AAA==',
'title': 'Home Assistant Release',
'type': 'web_search_result_location',
'url': 'https://www.example.com/todays-news',
}),
]),
'text': 'New Home Assistant release',
'type': 'text',
}),
dict({
'text': '''
2.
''',
'type': 'text',
}),
dict({
'citations': list([
dict({
'cited_text': 'Breaking news from around the world today includes major events in technology, politics, and culture...',
'encrypted_index': 'AQE=',
'title': 'Breaking News',
'type': 'web_search_result_location',
'url': 'https://www.newssite.com/breaking-news',
}),
dict({
'cited_text': 'Well, this happened...',
'encrypted_index': 'AgI=',
'title': 'Breaking News',
'type': 'web_search_result_location',
'url': 'https://www.newssite.com/breaking-news',
}),
]),
'text': 'Something incredible happened',
'type': 'text',
}),
dict({
'text': '''
Those are the main headlines making news today.
''',
'type': 'text',
}),
]),
'role': 'assistant',
}),
dict({
'content': 'Are you sure?',
'role': 'user',
}),
dict({
'content': list([
dict({
'text': 'Yes, I am sure!',
'type': 'text',
}),
]),
'role': 'assistant',
}),
])
# ---
# name: test_redacted_thinking
list([
dict({
@@ -405,3 +522,102 @@
),
})
# ---
# name: test_web_search
list([
dict({
'attachments': None,
'content': "What's on the news today?",
'role': 'user',
}),
dict({
'agent_id': 'conversation.claude_conversation',
'content': "To get today's news, I'll perform a web search",
'native': ThinkingBlock(signature='ErUBCkYIARgCIkCYXaVNJShe3A86Hp7XUzh9YsCYBbJTbQsrklTAPtJ2sP/NoB6tSzpK/nTL6CjSo2R6n0KNBIg5MH6asM2R/kmaEgyB/X1FtZq5OQAC7jUaDEPWCdcwGQ4RaBy5wiIwmRxExIlDhoY6tILoVPnOExkC/0igZxHEwxK8RU/fmw0b+o+TwAarzUitwzbo21E5Kh3pa3I6yqVROf1t2F8rFocNUeCegsWV/ytwYV+ayA==', thinking='', type='thinking'),
'role': 'assistant',
'thinking_content': "The user is asking about today's news, which requires current, real-time information. This is clearly something that requires recent information beyond my knowledge cutoff. I should use the web_search tool to find today's news.",
'tool_calls': list([
dict({
'external': True,
'id': 'srvtoolu_12345ABC',
'tool_args': dict({
'query': "today's news",
}),
'tool_name': 'web_search',
}),
]),
}),
dict({
'agent_id': 'conversation.claude_conversation',
'role': 'tool_result',
'tool_call_id': 'srvtoolu_12345ABC',
'tool_name': 'web_search',
'tool_result': dict({
'content': list([
dict({
'encrypted_content': 'ABCDEFG',
'page_age': '2 days ago',
'title': "Today's News - Example.com",
'type': 'web_search_result',
'url': 'https://www.example.com/todays-news',
}),
dict({
'encrypted_content': 'ABCDEFG',
'page_age': None,
'title': 'Breaking News - NewsSite.com',
'type': 'web_search_result',
'url': 'https://www.newssite.com/breaking-news',
}),
]),
}),
}),
dict({
'agent_id': 'conversation.claude_conversation',
'content': '''
Here's what I found on the web about today's news:
1. New Home Assistant release
2. Something incredible happened
Those are the main headlines making news today.
''',
'native': dict({
'citation_details': list([
dict({
'citations': list([
dict({
'cited_text': 'This release iterates on some of the features we introduced in the last couple of releases, but also...',
'encrypted_index': 'AAA==',
'title': 'Home Assistant Release',
'type': 'web_search_result_location',
'url': 'https://www.example.com/todays-news',
}),
]),
'index': 54,
'length': 26,
}),
dict({
'citations': list([
dict({
'cited_text': 'Breaking news from around the world today includes major events in technology, politics, and culture...',
'encrypted_index': 'AQE=',
'title': 'Breaking News',
'type': 'web_search_result_location',
'url': 'https://www.newssite.com/breaking-news',
}),
dict({
'cited_text': 'Well, this happened...',
'encrypted_index': 'AgI=',
'title': 'Breaking News',
'type': 'web_search_result_location',
'url': 'https://www.newssite.com/breaking-news',
}),
]),
'index': 84,
'length': 29,
}),
]),
}),
'role': 'assistant',
'thinking_content': None,
'tool_calls': None,
}),
])
# ---

View File

@@ -9,6 +9,7 @@ from anthropic import (
AuthenticationError,
BadRequestError,
InternalServerError,
types,
)
from httpx import URL, Request, Response
import pytest
@@ -22,6 +23,9 @@ from homeassistant.components.anthropic.const import (
CONF_RECOMMENDED,
CONF_TEMPERATURE,
CONF_THINKING_BUDGET,
CONF_WEB_SEARCH,
CONF_WEB_SEARCH_MAX_USES,
CONF_WEB_SEARCH_USER_LOCATION,
DEFAULT_CONVERSATION_NAME,
DOMAIN,
RECOMMENDED_CHAT_MODEL,
@@ -256,6 +260,103 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
assert result2["errors"] == {"base": error}
async def test_subentry_web_search_unsupported_model(
    hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None:
    """Test error when enabling web search with unsupported model."""
    subentry = next(iter(mock_config_entry.subentries.values()))
    options_flow = await mock_config_entry.start_subentry_reconfigure_flow(
        hass, subentry.subentry_id
    )
    # claude-3-haiku is listed in WEB_SEARCH_UNSUPPORTED_MODELS, so enabling
    # web search must re-show the form with a field-level error.
    options = await hass.config_entries.subentries.async_configure(
        options_flow["flow_id"],
        {
            "prompt": "You are a helpful assistant",
            "max_tokens": 8192,
            "chat_model": "claude-3-haiku-20240307",
            "recommended": False,
            "web_search": True,
            "web_search_max_uses": 5,
        },
    )
    await hass.async_block_till_done()
    assert options["type"] is FlowResultType.FORM
    assert options["errors"] == {"web_search": "web_search_unsupported_model"}
async def test_subentry_web_search_user_location(
    hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None:
    """Test fetching user location."""
    subentry = next(iter(mock_config_entry.subentries.values()))
    options_flow = await mock_config_entry.start_subentry_reconfigure_flow(
        hass, subentry.subentry_id
    )
    # Provide the core-config values and the zone.home coordinates the flow
    # reads when CONF_WEB_SEARCH_USER_LOCATION is enabled.
    hass.config.country = "US"
    hass.config.time_zone = "America/Los_Angeles"
    hass.states.async_set(
        "zone.home", "0", {"latitude": 37.7749, "longitude": -122.4194}
    )
    # Mock the Anthropic reply; the flow prepends "{" (the prefill hint)
    # before JSON-parsing the text block.
    with patch(
        "anthropic.resources.messages.AsyncMessages.create",
        new_callable=AsyncMock,
        return_value=types.Message(
            type="message",
            id="mock_message_id",
            role="assistant",
            model="claude-sonnet-4-0",
            usage=types.Usage(input_tokens=100, output_tokens=100),
            content=[
                types.TextBlock(
                    type="text", text='"city": "San Francisco", "region": "California"}'
                )
            ],
        ),
    ) as mock_create:
        options = await hass.config_entries.subentries.async_configure(
            options_flow["flow_id"],
            {
                "prompt": "You are a helpful assistant",
                "max_tokens": 8192,
                "chat_model": "claude-sonnet-4-5",
                "recommended": False,
                "web_search": True,
                "web_search_max_uses": 5,
                "user_location": True,
            },
        )
        await hass.async_block_till_done()
    # The prompt embeds the coordinates and the converted voluptuous schema.
    assert (
        mock_create.call_args.kwargs["messages"][0]["content"] == "Where are the "
        "following coordinates located: (37.7749, -122.4194)? Please respond only "
        "with a JSON object using the following schema:\n"
        "{'type': 'object', 'properties': {'city': {'type': 'string', 'description': "
        "'Free text input for the city, e.g. `San Francisco`'}, 'region': {'type': "
        "'string', 'description': 'Free text input for the region, e.g. `California`'"
        "}}, 'required': []}"
    )
    assert options["type"] is FlowResultType.ABORT
    assert options["reason"] == "reconfigure_successful"
    # Model-derived city/region plus core-config country/timezone all land
    # in the stored subentry data.
    assert subentry.data == {
        "chat_model": "claude-sonnet-4-5",
        "city": "San Francisco",
        "country": "US",
        "max_tokens": 8192,
        "prompt": "You are a helpful assistant",
        "recommended": False,
        "region": "California",
        "temperature": 1.0,
        "thinking_budget": 0,
        "timezone": "America/Los_Angeles",
        "user_location": True,
        "web_search": True,
        "web_search_max_uses": 5,
    }
@pytest.mark.parametrize(
("current_options", "new_options", "expected_options"),
[
@@ -277,6 +378,9 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
CONF_THINKING_BUDGET: RECOMMENDED_THINKING_BUDGET,
CONF_WEB_SEARCH: False,
CONF_WEB_SEARCH_MAX_USES: 5,
CONF_WEB_SEARCH_USER_LOCATION: False,
},
),
(
@@ -287,6 +391,9 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
CONF_THINKING_BUDGET: RECOMMENDED_THINKING_BUDGET,
CONF_WEB_SEARCH: False,
CONF_WEB_SEARCH_MAX_USES: 5,
CONF_WEB_SEARCH_USER_LOCATION: False,
},
{
CONF_RECOMMENDED: True,

View File

@@ -6,6 +6,9 @@ from unittest.mock import AsyncMock, Mock, patch
from anthropic import RateLimitError
from anthropic.types import (
CitationsDelta,
CitationsWebSearchResultLocation,
CitationWebSearchResultLocationParam,
InputJSONDelta,
Message,
MessageDeltaUsage,
@@ -17,13 +20,17 @@ from anthropic.types import (
RawMessageStopEvent,
RawMessageStreamEvent,
RedactedThinkingBlock,
ServerToolUseBlock,
SignatureDelta,
TextBlock,
TextCitation,
TextDelta,
ThinkingBlock,
ThinkingDelta,
ToolUseBlock,
Usage,
WebSearchResultBlock,
WebSearchToolResultBlock,
)
from anthropic.types.raw_message_delta_event import Delta
from freezegun import freeze_time
@@ -33,6 +40,7 @@ from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components import conversation
from homeassistant.components.anthropic.entity import CitationDetails, ContentDetails
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
@@ -78,15 +86,25 @@ def create_messages(
def create_content_block(
index: int, text_parts: list[str]
index: int, text_parts: list[str], citations: list[TextCitation] | None = None
) -> list[RawMessageStreamEvent]:
"""Create a text content block with the specified deltas."""
return [
RawContentBlockStartEvent(
type="content_block_start",
content_block=TextBlock(text="", type="text"),
content_block=TextBlock(
text="", type="text", citations=[] if citations else None
),
index=index,
),
*[
RawContentBlockDeltaEvent(
delta=CitationsDelta(citation=citation, type="citations_delta"),
index=index,
type="content_block_delta",
)
for citation in (citations or [])
],
*[
RawContentBlockDeltaEvent(
delta=TextDelta(text=text_part, type="text_delta"),
@@ -174,6 +192,46 @@ def create_tool_use_block(
]
def create_web_search_block(
    index: int, id: str, query_parts: list[str]
) -> list[RawMessageStreamEvent]:
    """Build the stream events for a server-side web_search tool invocation.

    Emits a content_block_start carrying the ServerToolUseBlock, one
    input_json_delta per query fragment, and the closing content_block_stop.
    """
    events: list[RawMessageStreamEvent] = [
        RawContentBlockStartEvent(
            type="content_block_start",
            content_block=ServerToolUseBlock(
                type="server_tool_use", id=id, input={}, name="web_search"
            ),
            index=index,
        )
    ]
    # The tool input (the search query) arrives as streamed JSON fragments.
    for query_part in query_parts:
        events.append(
            RawContentBlockDeltaEvent(
                delta=InputJSONDelta(type="input_json_delta", partial_json=query_part),
                index=index,
                type="content_block_delta",
            )
        )
    events.append(RawContentBlockStopEvent(index=index, type="content_block_stop"))
    return events
def create_web_search_result_block(
    index: int, id: str, results: list[WebSearchResultBlock]
) -> list[RawMessageStreamEvent]:
    """Build the stream events for a server-side web search result.

    The result block has no deltas: it is a single content_block_start
    carrying the WebSearchToolResultBlock, followed by content_block_stop.
    """
    start_event = RawContentBlockStartEvent(
        type="content_block_start",
        content_block=WebSearchToolResultBlock(
            type="web_search_tool_result", tool_use_id=id, content=results
        ),
        index=index,
    )
    stop_event = RawContentBlockStopEvent(index=index, type="content_block_stop")
    return [start_event, stop_event]
async def test_entity(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
@@ -850,6 +908,119 @@ async def test_extended_thinking_tool_call(
assert mock_create.mock_calls[1][2]["messages"] == snapshot
async def test_web_search(
    hass: HomeAssistant,
    mock_config_entry_with_web_search: MockConfigEntry,
    mock_init_component,
    snapshot: SnapshotAssertion,
) -> None:
    """Test web search.

    Replays a full server-side web search round trip — extended thinking,
    server tool use, tool result, and a cited answer spread over several
    content blocks — and verifies the resulting chat log against the snapshot.
    """
    # Raw results the mocked server-side web_search tool "returns"; one entry
    # exercises page_age set, the other page_age=None.
    web_search_results = [
        WebSearchResultBlock(
            type="web_search_result",
            title="Today's News - Example.com",
            url="https://www.example.com/todays-news",
            page_age="2 days ago",
            encrypted_content="ABCDEFG",
        ),
        WebSearchResultBlock(
            type="web_search_result",
            title="Breaking News - NewsSite.com",
            url="https://www.newssite.com/breaking-news",
            page_age=None,
            encrypted_content="ABCDEFG",
        ),
    ]
    # Patch the Anthropic client to replay a canned event stream instead of
    # performing a real API call.
    with patch(
        "anthropic.resources.messages.AsyncMessages.create",
        new_callable=AsyncMock,
        return_value=stream_generator(
            create_messages(
                [
                    # Extended thinking emitted before the model decides to search.
                    *create_thinking_block(
                        0,
                        [
                            "The user is",
                            " asking about today's news, which",
                            " requires current, real-time information",
                            ". This is clearly something that requires recent",
                            " information beyond my knowledge cutoff.",
                            " I should use the web",
                            "_search tool to fin",
                            "d today's news.",
                        ],
                    ),
                    *create_content_block(
                        1, ["To get today's news, I'll perform a web search"]
                    ),
                    # Server tool use: the query arrives as partial-JSON deltas.
                    *create_web_search_block(
                        2,
                        "srvtoolu_12345ABC",
                        ["", '{"que', 'ry"', ": \"today's", ' news"}'],
                    ),
                    # Tool result referencing the same server tool use id.
                    *create_web_search_result_block(
                        3, "srvtoolu_12345ABC", web_search_results
                    ),
                    *create_content_block(
                        4,
                        ["Here's what I found on the web about today's news:\n", "1. "],
                    ),
                    # Cited span backed by a single web search result location.
                    *create_content_block(
                        5,
                        ["New Home Assistant release"],
                        citations=[
                            CitationsWebSearchResultLocation(
                                type="web_search_result_location",
                                cited_text="This release iterates on some of the features we introduced in the last couple of releases, but also...",
                                encrypted_index="AAA==",
                                title="Home Assistant Release",
                                url="https://www.example.com/todays-news",
                            )
                        ],
                    ),
                    *create_content_block(6, ["\n2. "]),
                    # Cited span backed by two citations for the same source.
                    *create_content_block(
                        7,
                        ["Something incredible happened"],
                        citations=[
                            CitationsWebSearchResultLocation(
                                type="web_search_result_location",
                                cited_text="Breaking news from around the world today includes major events in technology, politics, and culture...",
                                encrypted_index="AQE=",
                                title="Breaking News",
                                url="https://www.newssite.com/breaking-news",
                            ),
                            CitationsWebSearchResultLocation(
                                type="web_search_result_location",
                                cited_text="Well, this happened...",
                                encrypted_index="AgI=",
                                title="Breaking News",
                                url="https://www.newssite.com/breaking-news",
                            ),
                        ],
                    ),
                    *create_content_block(
                        8, ["\nThose are the main headlines making news today."]
                    ),
                ]
            ),
        ),
    ):
        result = await conversation.async_converse(
            hass,
            "What's on the news today?",
            None,
            Context(),
            agent_id="conversation.claude_conversation",
        )
    # Fetch the chat log produced for this conversation id.
    chat_log = hass.data.get(conversation.chat_log.DATA_CHAT_LOGS).get(
        result.conversation_id
    )
    # Don't test the prompt because it's not deterministic
    assert chat_log.content[1:] == snapshot
@pytest.mark.parametrize(
"content",
[
@@ -929,6 +1100,93 @@ async def test_extended_thinking_tool_call(
content="Should I add milk to the shopping list?",
),
],
[
conversation.chat_log.SystemContent("You are a helpful assistant."),
conversation.chat_log.UserContent("What's on the news today?"),
conversation.chat_log.AssistantContent(
agent_id="conversation.claude_conversation",
content="To get today's news, I'll perform a web search",
thinking_content="The user is asking about today's news, which requires current, real-time information. This is clearly something that requires recent information beyond my knowledge cutoff. I should use the web_search tool to find today's news.",
native=ThinkingBlock(
signature="ErU/V+ayA==", thinking="", type="thinking"
),
tool_calls=[
llm.ToolInput(
id="srvtoolu_12345ABC",
tool_name="web_search",
tool_args={"query": "today's news"},
external=True,
),
],
),
conversation.chat_log.ToolResultContent(
agent_id="conversation.claude_conversation",
tool_call_id="srvtoolu_12345ABC",
tool_name="web_search",
tool_result={
"content": [
{
"type": "web_search_result",
"title": "Today's News - Example.com",
"url": "https://www.example.com/todays-news",
"page_age": "2 days ago",
"encrypted_content": "ABCDEFG",
},
{
"type": "web_search_result",
"title": "Breaking News - NewsSite.com",
"url": "https://www.newssite.com/breaking-news",
"page_age": None,
"encrypted_content": "ABCDEFG",
},
]
},
),
conversation.chat_log.AssistantContent(
agent_id="conversation.claude_conversation",
content="Here's what I found on the web about today's news:\n"
"1. New Home Assistant release\n"
"2. Something incredible happened\n"
"Those are the main headlines making news today.",
native=ContentDetails(
citation_details=[
CitationDetails(
index=54,
length=26,
citations=[
CitationWebSearchResultLocationParam(
type="web_search_result_location",
cited_text="This release iterates on some of the features we introduced in the last couple of releases, but also...",
encrypted_index="AAA==",
title="Home Assistant Release",
url="https://www.example.com/todays-news",
),
],
),
CitationDetails(
index=84,
length=29,
citations=[
CitationWebSearchResultLocationParam(
type="web_search_result_location",
cited_text="Breaking news from around the world today includes major events in technology, politics, and culture...",
encrypted_index="AQE=",
title="Breaking News",
url="https://www.newssite.com/breaking-news",
),
CitationWebSearchResultLocationParam(
type="web_search_result_location",
cited_text="Well, this happened...",
encrypted_index="AgI=",
title="Breaking News",
url="https://www.newssite.com/breaking-news",
),
],
),
],
),
),
],
],
)
async def test_history_conversion(