mirror of
https://github.com/home-assistant/core.git
synced 2026-04-27 05:16:51 +00:00
Anthropic web search support (#153753)
This commit is contained in:
@@ -4,12 +4,15 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from functools import partial
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
import anthropic
|
||||
import voluptuous as vol
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components.zone import ENTITY_ID_HOME
|
||||
from homeassistant.config_entries import (
|
||||
ConfigEntry,
|
||||
ConfigEntryState,
|
||||
@@ -18,7 +21,13 @@ from homeassistant.config_entries import (
|
||||
ConfigSubentryFlow,
|
||||
SubentryFlowResult,
|
||||
)
|
||||
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_NAME
|
||||
from homeassistant.const import (
|
||||
ATTR_LATITUDE,
|
||||
ATTR_LONGITUDE,
|
||||
CONF_API_KEY,
|
||||
CONF_LLM_HASS_API,
|
||||
CONF_NAME,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.helpers.selector import (
|
||||
@@ -37,12 +46,23 @@ from .const import (
|
||||
CONF_RECOMMENDED,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_THINKING_BUDGET,
|
||||
CONF_WEB_SEARCH,
|
||||
CONF_WEB_SEARCH_CITY,
|
||||
CONF_WEB_SEARCH_COUNTRY,
|
||||
CONF_WEB_SEARCH_MAX_USES,
|
||||
CONF_WEB_SEARCH_REGION,
|
||||
CONF_WEB_SEARCH_TIMEZONE,
|
||||
CONF_WEB_SEARCH_USER_LOCATION,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DOMAIN,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
RECOMMENDED_MAX_TOKENS,
|
||||
RECOMMENDED_TEMPERATURE,
|
||||
RECOMMENDED_THINKING_BUDGET,
|
||||
RECOMMENDED_WEB_SEARCH,
|
||||
RECOMMENDED_WEB_SEARCH_MAX_USES,
|
||||
RECOMMENDED_WEB_SEARCH_USER_LOCATION,
|
||||
WEB_SEARCH_UNSUPPORTED_MODELS,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -168,6 +188,14 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
CONF_THINKING_BUDGET, RECOMMENDED_THINKING_BUDGET
|
||||
) >= user_input.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS):
|
||||
errors[CONF_THINKING_BUDGET] = "thinking_budget_too_large"
|
||||
if user_input.get(CONF_WEB_SEARCH, RECOMMENDED_WEB_SEARCH):
|
||||
model = user_input.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
|
||||
if model.startswith(tuple(WEB_SEARCH_UNSUPPORTED_MODELS)):
|
||||
errors[CONF_WEB_SEARCH] = "web_search_unsupported_model"
|
||||
elif user_input.get(
|
||||
CONF_WEB_SEARCH_USER_LOCATION, RECOMMENDED_WEB_SEARCH_USER_LOCATION
|
||||
):
|
||||
user_input.update(await self._get_location_data())
|
||||
|
||||
if not errors:
|
||||
if self._is_new:
|
||||
@@ -215,6 +243,68 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
errors=errors or None,
|
||||
)
|
||||
|
||||
async def _get_location_data(self) -> dict[str, str]:
|
||||
"""Get approximate location data of the user."""
|
||||
location_data: dict[str, str] = {}
|
||||
zone_home = self.hass.states.get(ENTITY_ID_HOME)
|
||||
if zone_home is not None:
|
||||
client = await self.hass.async_add_executor_job(
|
||||
partial(
|
||||
anthropic.AsyncAnthropic,
|
||||
api_key=self._get_entry().data[CONF_API_KEY],
|
||||
)
|
||||
)
|
||||
location_schema = vol.Schema(
|
||||
{
|
||||
vol.Optional(
|
||||
CONF_WEB_SEARCH_CITY,
|
||||
description="Free text input for the city, e.g. `San Francisco`",
|
||||
): str,
|
||||
vol.Optional(
|
||||
CONF_WEB_SEARCH_REGION,
|
||||
description="Free text input for the region, e.g. `California`",
|
||||
): str,
|
||||
}
|
||||
)
|
||||
response = await client.messages.create(
|
||||
model=RECOMMENDED_CHAT_MODEL,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Where are the following coordinates located: "
|
||||
f"({zone_home.attributes[ATTR_LATITUDE]},"
|
||||
f" {zone_home.attributes[ATTR_LONGITUDE]})? Please respond "
|
||||
"only with a JSON object using the following schema:\n"
|
||||
f"{convert(location_schema)}",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "{", # hints the model to skip any preamble
|
||||
},
|
||||
],
|
||||
max_tokens=RECOMMENDED_MAX_TOKENS,
|
||||
)
|
||||
_LOGGER.debug("Model response: %s", response.content)
|
||||
location_data = location_schema(
|
||||
json.loads(
|
||||
"{"
|
||||
+ "".join(
|
||||
block.text
|
||||
for block in response.content
|
||||
if isinstance(block, anthropic.types.TextBlock)
|
||||
)
|
||||
)
|
||||
or {}
|
||||
)
|
||||
|
||||
if self.hass.config.country:
|
||||
location_data[CONF_WEB_SEARCH_COUNTRY] = self.hass.config.country
|
||||
location_data[CONF_WEB_SEARCH_TIMEZONE] = self.hass.config.time_zone
|
||||
|
||||
_LOGGER.debug("Location data: %s", location_data)
|
||||
|
||||
return location_data
|
||||
|
||||
async_step_user = async_step_set_options
|
||||
async_step_reconfigure = async_step_set_options
|
||||
|
||||
@@ -273,6 +363,18 @@ def anthropic_config_option_schema(
|
||||
CONF_THINKING_BUDGET,
|
||||
default=RECOMMENDED_THINKING_BUDGET,
|
||||
): int,
|
||||
vol.Optional(
|
||||
CONF_WEB_SEARCH,
|
||||
default=RECOMMENDED_WEB_SEARCH,
|
||||
): bool,
|
||||
vol.Optional(
|
||||
CONF_WEB_SEARCH_MAX_USES,
|
||||
default=RECOMMENDED_WEB_SEARCH_MAX_USES,
|
||||
): int,
|
||||
vol.Optional(
|
||||
CONF_WEB_SEARCH_USER_LOCATION,
|
||||
default=RECOMMENDED_WEB_SEARCH_USER_LOCATION,
|
||||
): bool,
|
||||
}
|
||||
)
|
||||
return schema
|
||||
|
||||
@@ -18,9 +18,26 @@ RECOMMENDED_TEMPERATURE = 1.0
|
||||
CONF_THINKING_BUDGET = "thinking_budget"
|
||||
RECOMMENDED_THINKING_BUDGET = 0
|
||||
MIN_THINKING_BUDGET = 1024
|
||||
CONF_WEB_SEARCH = "web_search"
|
||||
RECOMMENDED_WEB_SEARCH = False
|
||||
CONF_WEB_SEARCH_USER_LOCATION = "user_location"
|
||||
RECOMMENDED_WEB_SEARCH_USER_LOCATION = False
|
||||
CONF_WEB_SEARCH_MAX_USES = "web_search_max_uses"
|
||||
RECOMMENDED_WEB_SEARCH_MAX_USES = 5
|
||||
CONF_WEB_SEARCH_CITY = "city"
|
||||
CONF_WEB_SEARCH_REGION = "region"
|
||||
CONF_WEB_SEARCH_COUNTRY = "country"
|
||||
CONF_WEB_SEARCH_TIMEZONE = "timezone"
|
||||
|
||||
NON_THINKING_MODELS = [
|
||||
"claude-3-5", # Both sonnet and haiku
|
||||
"claude-3-opus",
|
||||
"claude-3-haiku",
|
||||
]
|
||||
|
||||
WEB_SEARCH_UNSUPPORTED_MODELS = [
|
||||
"claude-3-haiku",
|
||||
"claude-3-opus",
|
||||
"claude-3-5-sonnet-20240620",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
]
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
"""Base entity for Anthropic."""
|
||||
|
||||
from collections.abc import AsyncGenerator, Callable, Iterable
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import anthropic
|
||||
from anthropic import AsyncStream
|
||||
from anthropic.types import (
|
||||
CitationsDelta,
|
||||
CitationsWebSearchResultLocation,
|
||||
CitationWebSearchResultLocationParam,
|
||||
ContentBlockParam,
|
||||
InputJSONDelta,
|
||||
MessageDeltaUsage,
|
||||
MessageParam,
|
||||
@@ -16,11 +21,16 @@ from anthropic.types import (
|
||||
RawContentBlockStopEvent,
|
||||
RawMessageDeltaEvent,
|
||||
RawMessageStartEvent,
|
||||
RawMessageStopEvent,
|
||||
RedactedThinkingBlock,
|
||||
RedactedThinkingBlockParam,
|
||||
ServerToolUseBlock,
|
||||
ServerToolUseBlockParam,
|
||||
SignatureDelta,
|
||||
TextBlock,
|
||||
TextBlockParam,
|
||||
TextCitation,
|
||||
TextCitationParam,
|
||||
TextDelta,
|
||||
ThinkingBlock,
|
||||
ThinkingBlockParam,
|
||||
@@ -29,9 +39,15 @@ from anthropic.types import (
|
||||
ThinkingDelta,
|
||||
ToolParam,
|
||||
ToolResultBlockParam,
|
||||
ToolUnionParam,
|
||||
ToolUseBlock,
|
||||
ToolUseBlockParam,
|
||||
Usage,
|
||||
WebSearchTool20250305Param,
|
||||
WebSearchToolRequestErrorParam,
|
||||
WebSearchToolResultBlock,
|
||||
WebSearchToolResultBlockParam,
|
||||
WebSearchToolResultError,
|
||||
)
|
||||
from anthropic.types.message_create_params import MessageCreateParamsStreaming
|
||||
from voluptuous_openapi import convert
|
||||
@@ -48,6 +64,13 @@ from .const import (
|
||||
CONF_MAX_TOKENS,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_THINKING_BUDGET,
|
||||
CONF_WEB_SEARCH,
|
||||
CONF_WEB_SEARCH_CITY,
|
||||
CONF_WEB_SEARCH_COUNTRY,
|
||||
CONF_WEB_SEARCH_MAX_USES,
|
||||
CONF_WEB_SEARCH_REGION,
|
||||
CONF_WEB_SEARCH_TIMEZONE,
|
||||
CONF_WEB_SEARCH_USER_LOCATION,
|
||||
DOMAIN,
|
||||
LOGGER,
|
||||
MIN_THINKING_BUDGET,
|
||||
@@ -73,6 +96,69 @@ def _format_tool(
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class CitationDetails:
|
||||
"""Citation details for a content part."""
|
||||
|
||||
index: int = 0
|
||||
"""Start position of the text."""
|
||||
|
||||
length: int = 0
|
||||
"""Length of the relevant data."""
|
||||
|
||||
citations: list[TextCitationParam] = field(default_factory=list)
|
||||
"""Citations for the content part."""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ContentDetails:
|
||||
"""Native data for AssistantContent."""
|
||||
|
||||
citation_details: list[CitationDetails] = field(default_factory=list)
|
||||
|
||||
def has_content(self) -> bool:
|
||||
"""Check if there is any content."""
|
||||
return any(detail.length > 0 for detail in self.citation_details)
|
||||
|
||||
def has_citations(self) -> bool:
|
||||
"""Check if there are any citations."""
|
||||
return any(detail.citations for detail in self.citation_details)
|
||||
|
||||
def add_citation_detail(self) -> None:
|
||||
"""Add a new citation detail."""
|
||||
if not self.citation_details or self.citation_details[-1].length > 0:
|
||||
self.citation_details.append(
|
||||
CitationDetails(
|
||||
index=self.citation_details[-1].index
|
||||
+ self.citation_details[-1].length
|
||||
if self.citation_details
|
||||
else 0
|
||||
)
|
||||
)
|
||||
|
||||
def add_citation(self, citation: TextCitation) -> None:
|
||||
"""Add a citation to the current detail."""
|
||||
if not self.citation_details:
|
||||
self.citation_details.append(CitationDetails())
|
||||
citation_param: TextCitationParam | None = None
|
||||
if isinstance(citation, CitationsWebSearchResultLocation):
|
||||
citation_param = CitationWebSearchResultLocationParam(
|
||||
type="web_search_result_location",
|
||||
title=citation.title,
|
||||
url=citation.url,
|
||||
cited_text=citation.cited_text,
|
||||
encrypted_index=citation.encrypted_index,
|
||||
)
|
||||
if citation_param:
|
||||
self.citation_details[-1].citations.append(citation_param)
|
||||
|
||||
def delete_empty(self) -> None:
|
||||
"""Delete empty citation details."""
|
||||
self.citation_details = [
|
||||
detail for detail in self.citation_details if detail.citations
|
||||
]
|
||||
|
||||
|
||||
def _convert_content(
|
||||
chat_content: Iterable[conversation.Content],
|
||||
) -> list[MessageParam]:
|
||||
@@ -81,15 +167,31 @@ def _convert_content(
|
||||
|
||||
for content in chat_content:
|
||||
if isinstance(content, conversation.ToolResultContent):
|
||||
tool_result_block = ToolResultBlockParam(
|
||||
type="tool_result",
|
||||
tool_use_id=content.tool_call_id,
|
||||
content=json.dumps(content.tool_result),
|
||||
)
|
||||
if not messages or messages[-1]["role"] != "user":
|
||||
if content.tool_name == "web_search":
|
||||
tool_result_block: ContentBlockParam = WebSearchToolResultBlockParam(
|
||||
type="web_search_tool_result",
|
||||
tool_use_id=content.tool_call_id,
|
||||
content=content.tool_result["content"]
|
||||
if "content" in content.tool_result
|
||||
else WebSearchToolRequestErrorParam(
|
||||
type="web_search_tool_result_error",
|
||||
error_code=content.tool_result.get("error_code", "unavailable"), # type: ignore[typeddict-item]
|
||||
),
|
||||
)
|
||||
external_tool = True
|
||||
else:
|
||||
tool_result_block = ToolResultBlockParam(
|
||||
type="tool_result",
|
||||
tool_use_id=content.tool_call_id,
|
||||
content=json.dumps(content.tool_result),
|
||||
)
|
||||
external_tool = False
|
||||
if not messages or messages[-1]["role"] != (
|
||||
"assistant" if external_tool else "user"
|
||||
):
|
||||
messages.append(
|
||||
MessageParam(
|
||||
role="user",
|
||||
role="assistant" if external_tool else "user",
|
||||
content=[tool_result_block],
|
||||
)
|
||||
)
|
||||
@@ -151,13 +253,56 @@ def _convert_content(
|
||||
redacted_thinking_block
|
||||
)
|
||||
if content.content:
|
||||
messages[-1]["content"].append( # type: ignore[union-attr]
|
||||
TextBlockParam(type="text", text=content.content)
|
||||
)
|
||||
current_index = 0
|
||||
for detail in (
|
||||
content.native.citation_details
|
||||
if isinstance(content.native, ContentDetails)
|
||||
else [CitationDetails(length=len(content.content))]
|
||||
):
|
||||
if detail.index > current_index:
|
||||
# Add text block for any text without citations
|
||||
messages[-1]["content"].append( # type: ignore[union-attr]
|
||||
TextBlockParam(
|
||||
type="text",
|
||||
text=content.content[current_index : detail.index],
|
||||
)
|
||||
)
|
||||
messages[-1]["content"].append( # type: ignore[union-attr]
|
||||
TextBlockParam(
|
||||
type="text",
|
||||
text=content.content[
|
||||
detail.index : detail.index + detail.length
|
||||
],
|
||||
citations=detail.citations,
|
||||
)
|
||||
if detail.citations
|
||||
else TextBlockParam(
|
||||
type="text",
|
||||
text=content.content[
|
||||
detail.index : detail.index + detail.length
|
||||
],
|
||||
)
|
||||
)
|
||||
current_index = detail.index + detail.length
|
||||
if current_index < len(content.content):
|
||||
# Add text block for any remaining text without citations
|
||||
messages[-1]["content"].append( # type: ignore[union-attr]
|
||||
TextBlockParam(
|
||||
type="text",
|
||||
text=content.content[current_index:],
|
||||
)
|
||||
)
|
||||
if content.tool_calls:
|
||||
messages[-1]["content"].extend( # type: ignore[union-attr]
|
||||
[
|
||||
ToolUseBlockParam(
|
||||
ServerToolUseBlockParam(
|
||||
type="server_tool_use",
|
||||
id=tool_call.id,
|
||||
name="web_search",
|
||||
input=tool_call.tool_args,
|
||||
)
|
||||
if tool_call.external and tool_call.tool_name == "web_search"
|
||||
else ToolUseBlockParam(
|
||||
type="tool_use",
|
||||
id=tool_call.id,
|
||||
name=tool_call.tool_name,
|
||||
@@ -173,10 +318,12 @@ def _convert_content(
|
||||
return messages
|
||||
|
||||
|
||||
async def _transform_stream(
|
||||
async def _transform_stream( # noqa: C901 - This is complex, but better to have it in one place
|
||||
chat_log: conversation.ChatLog,
|
||||
stream: AsyncStream[MessageStreamEvent],
|
||||
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
||||
) -> AsyncGenerator[
|
||||
conversation.AssistantContentDeltaDict | conversation.ToolResultContentDeltaDict
|
||||
]:
|
||||
"""Transform the response stream into HA format.
|
||||
|
||||
A typical stream of responses might look something like the following:
|
||||
@@ -209,11 +356,13 @@ async def _transform_stream(
|
||||
if stream is None:
|
||||
raise TypeError("Expected a stream of messages")
|
||||
|
||||
current_tool_block: ToolUseBlockParam | None = None
|
||||
current_tool_block: ToolUseBlockParam | ServerToolUseBlockParam | None = None
|
||||
current_tool_args: str
|
||||
content_details = ContentDetails()
|
||||
content_details.add_citation_detail()
|
||||
input_usage: Usage | None = None
|
||||
has_content = False
|
||||
has_native = False
|
||||
first_block: bool
|
||||
|
||||
async for response in stream:
|
||||
LOGGER.debug("Received response: %s", response)
|
||||
@@ -222,6 +371,7 @@ async def _transform_stream(
|
||||
if response.message.role != "assistant":
|
||||
raise ValueError("Unexpected message role")
|
||||
input_usage = response.message.usage
|
||||
first_block = True
|
||||
elif isinstance(response, RawContentBlockStartEvent):
|
||||
if isinstance(response.content_block, ToolUseBlock):
|
||||
current_tool_block = ToolUseBlockParam(
|
||||
@@ -232,17 +382,37 @@ async def _transform_stream(
|
||||
)
|
||||
current_tool_args = ""
|
||||
elif isinstance(response.content_block, TextBlock):
|
||||
if has_content:
|
||||
if ( # Do not start a new assistant content just for citations, concatenate consecutive blocks with citations instead.
|
||||
first_block
|
||||
or (
|
||||
not content_details.has_citations()
|
||||
and response.content_block.citations is None
|
||||
and content_details.has_content()
|
||||
)
|
||||
):
|
||||
if content_details.has_citations():
|
||||
content_details.delete_empty()
|
||||
yield {"native": content_details}
|
||||
content_details = ContentDetails()
|
||||
yield {"role": "assistant"}
|
||||
has_native = False
|
||||
has_content = True
|
||||
first_block = False
|
||||
content_details.add_citation_detail()
|
||||
if response.content_block.text:
|
||||
content_details.citation_details[-1].length += len(
|
||||
response.content_block.text
|
||||
)
|
||||
yield {"content": response.content_block.text}
|
||||
elif isinstance(response.content_block, ThinkingBlock):
|
||||
if has_native:
|
||||
if first_block or has_native:
|
||||
if content_details.has_citations():
|
||||
content_details.delete_empty()
|
||||
yield {"native": content_details}
|
||||
content_details = ContentDetails()
|
||||
content_details.add_citation_detail()
|
||||
yield {"role": "assistant"}
|
||||
has_native = False
|
||||
has_content = False
|
||||
first_block = False
|
||||
elif isinstance(response.content_block, RedactedThinkingBlock):
|
||||
LOGGER.debug(
|
||||
"Some of Claude’s internal reasoning has been automatically "
|
||||
@@ -250,15 +420,60 @@ async def _transform_stream(
|
||||
"responses"
|
||||
)
|
||||
if has_native:
|
||||
if content_details.has_citations():
|
||||
content_details.delete_empty()
|
||||
yield {"native": content_details}
|
||||
content_details = ContentDetails()
|
||||
content_details.add_citation_detail()
|
||||
yield {"role": "assistant"}
|
||||
has_native = False
|
||||
has_content = False
|
||||
first_block = False
|
||||
yield {"native": response.content_block}
|
||||
has_native = True
|
||||
elif isinstance(response.content_block, ServerToolUseBlock):
|
||||
current_tool_block = ServerToolUseBlockParam(
|
||||
type="server_tool_use",
|
||||
id=response.content_block.id,
|
||||
name=response.content_block.name,
|
||||
input="",
|
||||
)
|
||||
current_tool_args = ""
|
||||
elif isinstance(response.content_block, WebSearchToolResultBlock):
|
||||
if content_details.has_citations():
|
||||
content_details.delete_empty()
|
||||
yield {"native": content_details}
|
||||
content_details = ContentDetails()
|
||||
content_details.add_citation_detail()
|
||||
yield {
|
||||
"role": "tool_result",
|
||||
"tool_call_id": response.content_block.tool_use_id,
|
||||
"tool_name": "web_search",
|
||||
"tool_result": {
|
||||
"type": "web_search_tool_result_error",
|
||||
"error_code": response.content_block.content.error_code,
|
||||
}
|
||||
if isinstance(
|
||||
response.content_block.content, WebSearchToolResultError
|
||||
)
|
||||
else {
|
||||
"content": [
|
||||
{
|
||||
"type": "web_search_result",
|
||||
"encrypted_content": block.encrypted_content,
|
||||
"page_age": block.page_age,
|
||||
"title": block.title,
|
||||
"url": block.url,
|
||||
}
|
||||
for block in response.content_block.content
|
||||
]
|
||||
},
|
||||
}
|
||||
first_block = True
|
||||
elif isinstance(response, RawContentBlockDeltaEvent):
|
||||
if isinstance(response.delta, InputJSONDelta):
|
||||
current_tool_args += response.delta.partial_json
|
||||
elif isinstance(response.delta, TextDelta):
|
||||
content_details.citation_details[-1].length += len(response.delta.text)
|
||||
yield {"content": response.delta.text}
|
||||
elif isinstance(response.delta, ThinkingDelta):
|
||||
yield {"thinking_content": response.delta.thinking}
|
||||
@@ -271,6 +486,8 @@ async def _transform_stream(
|
||||
)
|
||||
}
|
||||
has_native = True
|
||||
elif isinstance(response.delta, CitationsDelta):
|
||||
content_details.add_citation(response.delta.citation)
|
||||
elif isinstance(response, RawContentBlockStopEvent):
|
||||
if current_tool_block is not None:
|
||||
tool_args = json.loads(current_tool_args) if current_tool_args else {}
|
||||
@@ -281,6 +498,7 @@ async def _transform_stream(
|
||||
id=current_tool_block["id"],
|
||||
tool_name=current_tool_block["name"],
|
||||
tool_args=tool_args,
|
||||
external=current_tool_block["type"] == "server_tool_use",
|
||||
)
|
||||
]
|
||||
}
|
||||
@@ -290,6 +508,12 @@ async def _transform_stream(
|
||||
chat_log.async_trace(_create_token_stats(input_usage, usage))
|
||||
if response.delta.stop_reason == "refusal":
|
||||
raise HomeAssistantError("Potential policy violation detected")
|
||||
elif isinstance(response, RawMessageStopEvent):
|
||||
if content_details.has_citations():
|
||||
content_details.delete_empty()
|
||||
yield {"native": content_details}
|
||||
content_details = ContentDetails()
|
||||
content_details.add_citation_detail()
|
||||
|
||||
|
||||
def _create_token_stats(
|
||||
@@ -337,21 +561,11 @@ class AnthropicBaseLLMEntity(Entity):
|
||||
"""Generate an answer for the chat log."""
|
||||
options = self.subentry.data
|
||||
|
||||
tools: list[ToolParam] | None = None
|
||||
if chat_log.llm_api:
|
||||
tools = [
|
||||
_format_tool(tool, chat_log.llm_api.custom_serializer)
|
||||
for tool in chat_log.llm_api.tools
|
||||
]
|
||||
|
||||
system = chat_log.content[0]
|
||||
if not isinstance(system, conversation.SystemContent):
|
||||
raise TypeError("First message must be a system message")
|
||||
messages = _convert_content(chat_log.content[1:])
|
||||
|
||||
client = self.entry.runtime_data
|
||||
|
||||
thinking_budget = options.get(CONF_THINKING_BUDGET, RECOMMENDED_THINKING_BUDGET)
|
||||
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
|
||||
|
||||
model_args = MessageCreateParamsStreaming(
|
||||
@@ -361,8 +575,8 @@ class AnthropicBaseLLMEntity(Entity):
|
||||
system=system.content,
|
||||
stream=True,
|
||||
)
|
||||
if tools:
|
||||
model_args["tools"] = tools
|
||||
|
||||
thinking_budget = options.get(CONF_THINKING_BUDGET, RECOMMENDED_THINKING_BUDGET)
|
||||
if (
|
||||
not model.startswith(tuple(NON_THINKING_MODELS))
|
||||
and thinking_budget >= MIN_THINKING_BUDGET
|
||||
@@ -376,6 +590,34 @@ class AnthropicBaseLLMEntity(Entity):
|
||||
CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
|
||||
)
|
||||
|
||||
tools: list[ToolUnionParam] = []
|
||||
if chat_log.llm_api:
|
||||
tools = [
|
||||
_format_tool(tool, chat_log.llm_api.custom_serializer)
|
||||
for tool in chat_log.llm_api.tools
|
||||
]
|
||||
|
||||
if options.get(CONF_WEB_SEARCH):
|
||||
web_search = WebSearchTool20250305Param(
|
||||
name="web_search",
|
||||
type="web_search_20250305",
|
||||
max_uses=options.get(CONF_WEB_SEARCH_MAX_USES),
|
||||
)
|
||||
if options.get(CONF_WEB_SEARCH_USER_LOCATION):
|
||||
web_search["user_location"] = {
|
||||
"type": "approximate",
|
||||
"city": options.get(CONF_WEB_SEARCH_CITY, ""),
|
||||
"region": options.get(CONF_WEB_SEARCH_REGION, ""),
|
||||
"country": options.get(CONF_WEB_SEARCH_COUNTRY, ""),
|
||||
"timezone": options.get(CONF_WEB_SEARCH_TIMEZONE, ""),
|
||||
}
|
||||
tools.append(web_search)
|
||||
|
||||
if tools:
|
||||
model_args["tools"] = tools
|
||||
|
||||
client = self.entry.runtime_data
|
||||
|
||||
# To prevent infinite loops, we limit the number of iterations
|
||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||
try:
|
||||
|
||||
@@ -35,11 +35,17 @@
|
||||
"temperature": "Temperature",
|
||||
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
|
||||
"recommended": "Recommended model settings",
|
||||
"thinking_budget_tokens": "Thinking budget"
|
||||
"thinking_budget": "Thinking budget",
|
||||
"web_search": "Enable web search",
|
||||
"web_search_max_uses": "Maximum web searches",
|
||||
"user_location": "Include home location"
|
||||
},
|
||||
"data_description": {
|
||||
"prompt": "Instruct how the LLM should respond. This can be a template.",
|
||||
"thinking_budget_tokens": "The number of tokens the model can use to think about the response out of the total maximum number of tokens. Set to 1024 or greater to enable extended thinking."
|
||||
"thinking_budget": "The number of tokens the model can use to think about the response out of the total maximum number of tokens. Set to 1024 or greater to enable extended thinking.",
|
||||
"web_search": "The web search tool gives Claude direct access to real-time web content, allowing it to answer questions with up-to-date information beyond its knowledge cutoff",
|
||||
"web_search_max_uses": "Limit the number of searches performed per response",
|
||||
"user_location": "Localize search results based on home location"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -48,7 +54,8 @@
|
||||
"entry_not_loaded": "Cannot add things while the configuration is disabled."
|
||||
},
|
||||
"error": {
|
||||
"thinking_budget_too_large": "Maximum tokens must be greater than the thinking budget."
|
||||
"thinking_budget_too_large": "Maximum tokens must be greater than the thinking budget.",
|
||||
"web_search_unsupported_model": "Web search is not supported by the selected model. Please choose a compatible model or disable web search."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,16 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.anthropic import CONF_CHAT_MODEL
|
||||
from homeassistant.components.anthropic.const import DEFAULT_CONVERSATION_NAME
|
||||
from homeassistant.components.anthropic.const import (
|
||||
CONF_WEB_SEARCH,
|
||||
CONF_WEB_SEARCH_CITY,
|
||||
CONF_WEB_SEARCH_COUNTRY,
|
||||
CONF_WEB_SEARCH_MAX_USES,
|
||||
CONF_WEB_SEARCH_REGION,
|
||||
CONF_WEB_SEARCH_TIMEZONE,
|
||||
CONF_WEB_SEARCH_USER_LOCATION,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
)
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import llm
|
||||
@@ -55,7 +64,7 @@ def mock_config_entry_with_assist(
|
||||
def mock_config_entry_with_extended_thinking(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
) -> MockConfigEntry:
|
||||
"""Mock a config entry with assist."""
|
||||
"""Mock a config entry with extended thinking."""
|
||||
hass.config_entries.async_update_subentry(
|
||||
mock_config_entry,
|
||||
next(iter(mock_config_entry.subentries.values())),
|
||||
@@ -67,6 +76,29 @@ def mock_config_entry_with_extended_thinking(
|
||||
return mock_config_entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_entry_with_web_search(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
) -> MockConfigEntry:
|
||||
"""Mock a config entry with server tools enabled."""
|
||||
hass.config_entries.async_update_subentry(
|
||||
mock_config_entry,
|
||||
next(iter(mock_config_entry.subentries.values())),
|
||||
data={
|
||||
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
|
||||
CONF_CHAT_MODEL: "claude-sonnet-4-5",
|
||||
CONF_WEB_SEARCH: True,
|
||||
CONF_WEB_SEARCH_MAX_USES: 5,
|
||||
CONF_WEB_SEARCH_USER_LOCATION: True,
|
||||
CONF_WEB_SEARCH_CITY: "San Francisco",
|
||||
CONF_WEB_SEARCH_REGION: "California",
|
||||
CONF_WEB_SEARCH_COUNTRY: "US",
|
||||
CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
|
||||
},
|
||||
)
|
||||
return mock_config_entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_init_component(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
|
||||
@@ -338,6 +338,123 @@
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_history_conversion[content5]
|
||||
list([
|
||||
dict({
|
||||
'content': "What's on the news today?",
|
||||
'role': 'user',
|
||||
}),
|
||||
dict({
|
||||
'content': list([
|
||||
dict({
|
||||
'signature': 'ErU/V+ayA==',
|
||||
'thinking': "The user is asking about today's news, which requires current, real-time information. This is clearly something that requires recent information beyond my knowledge cutoff. I should use the web_search tool to find today's news.",
|
||||
'type': 'thinking',
|
||||
}),
|
||||
dict({
|
||||
'text': "To get today's news, I'll perform a web search",
|
||||
'type': 'text',
|
||||
}),
|
||||
dict({
|
||||
'id': 'srvtoolu_12345ABC',
|
||||
'input': dict({
|
||||
'query': "today's news",
|
||||
}),
|
||||
'name': 'web_search',
|
||||
'type': 'server_tool_use',
|
||||
}),
|
||||
dict({
|
||||
'content': list([
|
||||
dict({
|
||||
'encrypted_content': 'ABCDEFG',
|
||||
'page_age': '2 days ago',
|
||||
'title': "Today's News - Example.com",
|
||||
'type': 'web_search_result',
|
||||
'url': 'https://www.example.com/todays-news',
|
||||
}),
|
||||
dict({
|
||||
'encrypted_content': 'ABCDEFG',
|
||||
'page_age': None,
|
||||
'title': 'Breaking News - NewsSite.com',
|
||||
'type': 'web_search_result',
|
||||
'url': 'https://www.newssite.com/breaking-news',
|
||||
}),
|
||||
]),
|
||||
'tool_use_id': 'srvtoolu_12345ABC',
|
||||
'type': 'web_search_tool_result',
|
||||
}),
|
||||
dict({
|
||||
'text': '''
|
||||
Here's what I found on the web about today's news:
|
||||
1.
|
||||
''',
|
||||
'type': 'text',
|
||||
}),
|
||||
dict({
|
||||
'citations': list([
|
||||
dict({
|
||||
'cited_text': 'This release iterates on some of the features we introduced in the last couple of releases, but also...',
|
||||
'encrypted_index': 'AAA==',
|
||||
'title': 'Home Assistant Release',
|
||||
'type': 'web_search_result_location',
|
||||
'url': 'https://www.example.com/todays-news',
|
||||
}),
|
||||
]),
|
||||
'text': 'New Home Assistant release',
|
||||
'type': 'text',
|
||||
}),
|
||||
dict({
|
||||
'text': '''
|
||||
|
||||
2.
|
||||
''',
|
||||
'type': 'text',
|
||||
}),
|
||||
dict({
|
||||
'citations': list([
|
||||
dict({
|
||||
'cited_text': 'Breaking news from around the world today includes major events in technology, politics, and culture...',
|
||||
'encrypted_index': 'AQE=',
|
||||
'title': 'Breaking News',
|
||||
'type': 'web_search_result_location',
|
||||
'url': 'https://www.newssite.com/breaking-news',
|
||||
}),
|
||||
dict({
|
||||
'cited_text': 'Well, this happened...',
|
||||
'encrypted_index': 'AgI=',
|
||||
'title': 'Breaking News',
|
||||
'type': 'web_search_result_location',
|
||||
'url': 'https://www.newssite.com/breaking-news',
|
||||
}),
|
||||
]),
|
||||
'text': 'Something incredible happened',
|
||||
'type': 'text',
|
||||
}),
|
||||
dict({
|
||||
'text': '''
|
||||
|
||||
Those are the main headlines making news today.
|
||||
''',
|
||||
'type': 'text',
|
||||
}),
|
||||
]),
|
||||
'role': 'assistant',
|
||||
}),
|
||||
dict({
|
||||
'content': 'Are you sure?',
|
||||
'role': 'user',
|
||||
}),
|
||||
dict({
|
||||
'content': list([
|
||||
dict({
|
||||
'text': 'Yes, I am sure!',
|
||||
'type': 'text',
|
||||
}),
|
||||
]),
|
||||
'role': 'assistant',
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_redacted_thinking
|
||||
list([
|
||||
dict({
|
||||
@@ -405,3 +522,102 @@
|
||||
),
|
||||
})
|
||||
# ---
|
||||
# name: test_web_search
|
||||
list([
|
||||
dict({
|
||||
'attachments': None,
|
||||
'content': "What's on the news today?",
|
||||
'role': 'user',
|
||||
}),
|
||||
dict({
|
||||
'agent_id': 'conversation.claude_conversation',
|
||||
'content': "To get today's news, I'll perform a web search",
|
||||
'native': ThinkingBlock(signature='ErUBCkYIARgCIkCYXaVNJShe3A86Hp7XUzh9YsCYBbJTbQsrklTAPtJ2sP/NoB6tSzpK/nTL6CjSo2R6n0KNBIg5MH6asM2R/kmaEgyB/X1FtZq5OQAC7jUaDEPWCdcwGQ4RaBy5wiIwmRxExIlDhoY6tILoVPnOExkC/0igZxHEwxK8RU/fmw0b+o+TwAarzUitwzbo21E5Kh3pa3I6yqVROf1t2F8rFocNUeCegsWV/ytwYV+ayA==', thinking='', type='thinking'),
|
||||
'role': 'assistant',
|
||||
'thinking_content': "The user is asking about today's news, which requires current, real-time information. This is clearly something that requires recent information beyond my knowledge cutoff. I should use the web_search tool to find today's news.",
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'external': True,
|
||||
'id': 'srvtoolu_12345ABC',
|
||||
'tool_args': dict({
|
||||
'query': "today's news",
|
||||
}),
|
||||
'tool_name': 'web_search',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
dict({
|
||||
'agent_id': 'conversation.claude_conversation',
|
||||
'role': 'tool_result',
|
||||
'tool_call_id': 'srvtoolu_12345ABC',
|
||||
'tool_name': 'web_search',
|
||||
'tool_result': dict({
|
||||
'content': list([
|
||||
dict({
|
||||
'encrypted_content': 'ABCDEFG',
|
||||
'page_age': '2 days ago',
|
||||
'title': "Today's News - Example.com",
|
||||
'type': 'web_search_result',
|
||||
'url': 'https://www.example.com/todays-news',
|
||||
}),
|
||||
dict({
|
||||
'encrypted_content': 'ABCDEFG',
|
||||
'page_age': None,
|
||||
'title': 'Breaking News - NewsSite.com',
|
||||
'type': 'web_search_result',
|
||||
'url': 'https://www.newssite.com/breaking-news',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
dict({
|
||||
'agent_id': 'conversation.claude_conversation',
|
||||
'content': '''
|
||||
Here's what I found on the web about today's news:
|
||||
1. New Home Assistant release
|
||||
2. Something incredible happened
|
||||
Those are the main headlines making news today.
|
||||
''',
|
||||
'native': dict({
|
||||
'citation_details': list([
|
||||
dict({
|
||||
'citations': list([
|
||||
dict({
|
||||
'cited_text': 'This release iterates on some of the features we introduced in the last couple of releases, but also...',
|
||||
'encrypted_index': 'AAA==',
|
||||
'title': 'Home Assistant Release',
|
||||
'type': 'web_search_result_location',
|
||||
'url': 'https://www.example.com/todays-news',
|
||||
}),
|
||||
]),
|
||||
'index': 54,
|
||||
'length': 26,
|
||||
}),
|
||||
dict({
|
||||
'citations': list([
|
||||
dict({
|
||||
'cited_text': 'Breaking news from around the world today includes major events in technology, politics, and culture...',
|
||||
'encrypted_index': 'AQE=',
|
||||
'title': 'Breaking News',
|
||||
'type': 'web_search_result_location',
|
||||
'url': 'https://www.newssite.com/breaking-news',
|
||||
}),
|
||||
dict({
|
||||
'cited_text': 'Well, this happened...',
|
||||
'encrypted_index': 'AgI=',
|
||||
'title': 'Breaking News',
|
||||
'type': 'web_search_result_location',
|
||||
'url': 'https://www.newssite.com/breaking-news',
|
||||
}),
|
||||
]),
|
||||
'index': 84,
|
||||
'length': 29,
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': None,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
|
||||
@@ -9,6 +9,7 @@ from anthropic import (
|
||||
AuthenticationError,
|
||||
BadRequestError,
|
||||
InternalServerError,
|
||||
types,
|
||||
)
|
||||
from httpx import URL, Request, Response
|
||||
import pytest
|
||||
@@ -22,6 +23,9 @@ from homeassistant.components.anthropic.const import (
|
||||
CONF_RECOMMENDED,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_THINKING_BUDGET,
|
||||
CONF_WEB_SEARCH,
|
||||
CONF_WEB_SEARCH_MAX_USES,
|
||||
CONF_WEB_SEARCH_USER_LOCATION,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DOMAIN,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
@@ -256,6 +260,103 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
||||
assert result2["errors"] == {"base": error}
|
||||
|
||||
|
||||
async def test_subentry_web_search_unsupported_model(
|
||||
hass: HomeAssistant, mock_config_entry, mock_init_component
|
||||
) -> None:
|
||||
"""Test error when enabling web search with unsupported model."""
|
||||
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||
options_flow = await mock_config_entry.start_subentry_reconfigure_flow(
|
||||
hass, subentry.subentry_id
|
||||
)
|
||||
options = await hass.config_entries.subentries.async_configure(
|
||||
options_flow["flow_id"],
|
||||
{
|
||||
"prompt": "You are a helpful assistant",
|
||||
"max_tokens": 8192,
|
||||
"chat_model": "claude-3-haiku-20240307",
|
||||
"recommended": False,
|
||||
"web_search": True,
|
||||
"web_search_max_uses": 5,
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert options["type"] is FlowResultType.FORM
|
||||
assert options["errors"] == {"web_search": "web_search_unsupported_model"}
|
||||
|
||||
|
||||
async def test_subentry_web_search_user_location(
|
||||
hass: HomeAssistant, mock_config_entry, mock_init_component
|
||||
) -> None:
|
||||
"""Test fetching user location."""
|
||||
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||
options_flow = await mock_config_entry.start_subentry_reconfigure_flow(
|
||||
hass, subentry.subentry_id
|
||||
)
|
||||
|
||||
hass.config.country = "US"
|
||||
hass.config.time_zone = "America/Los_Angeles"
|
||||
hass.states.async_set(
|
||||
"zone.home", "0", {"latitude": 37.7749, "longitude": -122.4194}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"anthropic.resources.messages.AsyncMessages.create",
|
||||
new_callable=AsyncMock,
|
||||
return_value=types.Message(
|
||||
type="message",
|
||||
id="mock_message_id",
|
||||
role="assistant",
|
||||
model="claude-sonnet-4-0",
|
||||
usage=types.Usage(input_tokens=100, output_tokens=100),
|
||||
content=[
|
||||
types.TextBlock(
|
||||
type="text", text='"city": "San Francisco", "region": "California"}'
|
||||
)
|
||||
],
|
||||
),
|
||||
) as mock_create:
|
||||
options = await hass.config_entries.subentries.async_configure(
|
||||
options_flow["flow_id"],
|
||||
{
|
||||
"prompt": "You are a helpful assistant",
|
||||
"max_tokens": 8192,
|
||||
"chat_model": "claude-sonnet-4-5",
|
||||
"recommended": False,
|
||||
"web_search": True,
|
||||
"web_search_max_uses": 5,
|
||||
"user_location": True,
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert (
|
||||
mock_create.call_args.kwargs["messages"][0]["content"] == "Where are the "
|
||||
"following coordinates located: (37.7749, -122.4194)? Please respond only "
|
||||
"with a JSON object using the following schema:\n"
|
||||
"{'type': 'object', 'properties': {'city': {'type': 'string', 'description': "
|
||||
"'Free text input for the city, e.g. `San Francisco`'}, 'region': {'type': "
|
||||
"'string', 'description': 'Free text input for the region, e.g. `California`'"
|
||||
"}}, 'required': []}"
|
||||
)
|
||||
assert options["type"] is FlowResultType.ABORT
|
||||
assert options["reason"] == "reconfigure_successful"
|
||||
assert subentry.data == {
|
||||
"chat_model": "claude-sonnet-4-5",
|
||||
"city": "San Francisco",
|
||||
"country": "US",
|
||||
"max_tokens": 8192,
|
||||
"prompt": "You are a helpful assistant",
|
||||
"recommended": False,
|
||||
"region": "California",
|
||||
"temperature": 1.0,
|
||||
"thinking_budget": 0,
|
||||
"timezone": "America/Los_Angeles",
|
||||
"user_location": True,
|
||||
"web_search": True,
|
||||
"web_search_max_uses": 5,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("current_options", "new_options", "expected_options"),
|
||||
[
|
||||
@@ -277,6 +378,9 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
||||
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
|
||||
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
|
||||
CONF_THINKING_BUDGET: RECOMMENDED_THINKING_BUDGET,
|
||||
CONF_WEB_SEARCH: False,
|
||||
CONF_WEB_SEARCH_MAX_USES: 5,
|
||||
CONF_WEB_SEARCH_USER_LOCATION: False,
|
||||
},
|
||||
),
|
||||
(
|
||||
@@ -287,6 +391,9 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
||||
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
|
||||
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
|
||||
CONF_THINKING_BUDGET: RECOMMENDED_THINKING_BUDGET,
|
||||
CONF_WEB_SEARCH: False,
|
||||
CONF_WEB_SEARCH_MAX_USES: 5,
|
||||
CONF_WEB_SEARCH_USER_LOCATION: False,
|
||||
},
|
||||
{
|
||||
CONF_RECOMMENDED: True,
|
||||
|
||||
@@ -6,6 +6,9 @@ from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from anthropic import RateLimitError
|
||||
from anthropic.types import (
|
||||
CitationsDelta,
|
||||
CitationsWebSearchResultLocation,
|
||||
CitationWebSearchResultLocationParam,
|
||||
InputJSONDelta,
|
||||
Message,
|
||||
MessageDeltaUsage,
|
||||
@@ -17,13 +20,17 @@ from anthropic.types import (
|
||||
RawMessageStopEvent,
|
||||
RawMessageStreamEvent,
|
||||
RedactedThinkingBlock,
|
||||
ServerToolUseBlock,
|
||||
SignatureDelta,
|
||||
TextBlock,
|
||||
TextCitation,
|
||||
TextDelta,
|
||||
ThinkingBlock,
|
||||
ThinkingDelta,
|
||||
ToolUseBlock,
|
||||
Usage,
|
||||
WebSearchResultBlock,
|
||||
WebSearchToolResultBlock,
|
||||
)
|
||||
from anthropic.types.raw_message_delta_event import Delta
|
||||
from freezegun import freeze_time
|
||||
@@ -33,6 +40,7 @@ from syrupy.assertion import SnapshotAssertion
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.anthropic.entity import CitationDetails, ContentDetails
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
@@ -78,15 +86,25 @@ def create_messages(
|
||||
|
||||
|
||||
def create_content_block(
|
||||
index: int, text_parts: list[str]
|
||||
index: int, text_parts: list[str], citations: list[TextCitation] | None = None
|
||||
) -> list[RawMessageStreamEvent]:
|
||||
"""Create a text content block with the specified deltas."""
|
||||
return [
|
||||
RawContentBlockStartEvent(
|
||||
type="content_block_start",
|
||||
content_block=TextBlock(text="", type="text"),
|
||||
content_block=TextBlock(
|
||||
text="", type="text", citations=[] if citations else None
|
||||
),
|
||||
index=index,
|
||||
),
|
||||
*[
|
||||
RawContentBlockDeltaEvent(
|
||||
delta=CitationsDelta(citation=citation, type="citations_delta"),
|
||||
index=index,
|
||||
type="content_block_delta",
|
||||
)
|
||||
for citation in (citations or [])
|
||||
],
|
||||
*[
|
||||
RawContentBlockDeltaEvent(
|
||||
delta=TextDelta(text=text_part, type="text_delta"),
|
||||
@@ -174,6 +192,46 @@ def create_tool_use_block(
|
||||
]
|
||||
|
||||
|
||||
def create_web_search_block(
|
||||
index: int, id: str, query_parts: list[str]
|
||||
) -> list[RawMessageStreamEvent]:
|
||||
"""Create a server tool use block for web search."""
|
||||
return [
|
||||
RawContentBlockStartEvent(
|
||||
type="content_block_start",
|
||||
content_block=ServerToolUseBlock(
|
||||
type="server_tool_use", id=id, input={}, name="web_search"
|
||||
),
|
||||
index=index,
|
||||
),
|
||||
*[
|
||||
RawContentBlockDeltaEvent(
|
||||
delta=InputJSONDelta(type="input_json_delta", partial_json=query_part),
|
||||
index=index,
|
||||
type="content_block_delta",
|
||||
)
|
||||
for query_part in query_parts
|
||||
],
|
||||
RawContentBlockStopEvent(index=index, type="content_block_stop"),
|
||||
]
|
||||
|
||||
|
||||
def create_web_search_result_block(
|
||||
index: int, id: str, results: list[WebSearchResultBlock]
|
||||
) -> list[RawMessageStreamEvent]:
|
||||
"""Create a server tool result block for web search results."""
|
||||
return [
|
||||
RawContentBlockStartEvent(
|
||||
type="content_block_start",
|
||||
content_block=WebSearchToolResultBlock(
|
||||
type="web_search_tool_result", tool_use_id=id, content=results
|
||||
),
|
||||
index=index,
|
||||
),
|
||||
RawContentBlockStopEvent(index=index, type="content_block_stop"),
|
||||
]
|
||||
|
||||
|
||||
async def test_entity(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
@@ -850,6 +908,119 @@ async def test_extended_thinking_tool_call(
|
||||
assert mock_create.mock_calls[1][2]["messages"] == snapshot
|
||||
|
||||
|
||||
async def test_web_search(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry_with_web_search: MockConfigEntry,
|
||||
mock_init_component,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test web search."""
|
||||
web_search_results = [
|
||||
WebSearchResultBlock(
|
||||
type="web_search_result",
|
||||
title="Today's News - Example.com",
|
||||
url="https://www.example.com/todays-news",
|
||||
page_age="2 days ago",
|
||||
encrypted_content="ABCDEFG",
|
||||
),
|
||||
WebSearchResultBlock(
|
||||
type="web_search_result",
|
||||
title="Breaking News - NewsSite.com",
|
||||
url="https://www.newssite.com/breaking-news",
|
||||
page_age=None,
|
||||
encrypted_content="ABCDEFG",
|
||||
),
|
||||
]
|
||||
with patch(
|
||||
"anthropic.resources.messages.AsyncMessages.create",
|
||||
new_callable=AsyncMock,
|
||||
return_value=stream_generator(
|
||||
create_messages(
|
||||
[
|
||||
*create_thinking_block(
|
||||
0,
|
||||
[
|
||||
"The user is",
|
||||
" asking about today's news, which",
|
||||
" requires current, real-time information",
|
||||
". This is clearly something that requires recent",
|
||||
" information beyond my knowledge cutoff.",
|
||||
" I should use the web",
|
||||
"_search tool to fin",
|
||||
"d today's news.",
|
||||
],
|
||||
),
|
||||
*create_content_block(
|
||||
1, ["To get today's news, I'll perform a web search"]
|
||||
),
|
||||
*create_web_search_block(
|
||||
2,
|
||||
"srvtoolu_12345ABC",
|
||||
["", '{"que', 'ry"', ": \"today's", ' news"}'],
|
||||
),
|
||||
*create_web_search_result_block(
|
||||
3, "srvtoolu_12345ABC", web_search_results
|
||||
),
|
||||
*create_content_block(
|
||||
4,
|
||||
["Here's what I found on the web about today's news:\n", "1. "],
|
||||
),
|
||||
*create_content_block(
|
||||
5,
|
||||
["New Home Assistant release"],
|
||||
citations=[
|
||||
CitationsWebSearchResultLocation(
|
||||
type="web_search_result_location",
|
||||
cited_text="This release iterates on some of the features we introduced in the last couple of releases, but also...",
|
||||
encrypted_index="AAA==",
|
||||
title="Home Assistant Release",
|
||||
url="https://www.example.com/todays-news",
|
||||
)
|
||||
],
|
||||
),
|
||||
*create_content_block(6, ["\n2. "]),
|
||||
*create_content_block(
|
||||
7,
|
||||
["Something incredible happened"],
|
||||
citations=[
|
||||
CitationsWebSearchResultLocation(
|
||||
type="web_search_result_location",
|
||||
cited_text="Breaking news from around the world today includes major events in technology, politics, and culture...",
|
||||
encrypted_index="AQE=",
|
||||
title="Breaking News",
|
||||
url="https://www.newssite.com/breaking-news",
|
||||
),
|
||||
CitationsWebSearchResultLocation(
|
||||
type="web_search_result_location",
|
||||
cited_text="Well, this happened...",
|
||||
encrypted_index="AgI=",
|
||||
title="Breaking News",
|
||||
url="https://www.newssite.com/breaking-news",
|
||||
),
|
||||
],
|
||||
),
|
||||
*create_content_block(
|
||||
8, ["\nThose are the main headlines making news today."]
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
):
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"What's on the news today?",
|
||||
None,
|
||||
Context(),
|
||||
agent_id="conversation.claude_conversation",
|
||||
)
|
||||
|
||||
chat_log = hass.data.get(conversation.chat_log.DATA_CHAT_LOGS).get(
|
||||
result.conversation_id
|
||||
)
|
||||
# Don't test the prompt because it's not deterministic
|
||||
assert chat_log.content[1:] == snapshot
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"content",
|
||||
[
|
||||
@@ -929,6 +1100,93 @@ async def test_extended_thinking_tool_call(
|
||||
content="Should I add milk to the shopping list?",
|
||||
),
|
||||
],
|
||||
[
|
||||
conversation.chat_log.SystemContent("You are a helpful assistant."),
|
||||
conversation.chat_log.UserContent("What's on the news today?"),
|
||||
conversation.chat_log.AssistantContent(
|
||||
agent_id="conversation.claude_conversation",
|
||||
content="To get today's news, I'll perform a web search",
|
||||
thinking_content="The user is asking about today's news, which requires current, real-time information. This is clearly something that requires recent information beyond my knowledge cutoff. I should use the web_search tool to find today's news.",
|
||||
native=ThinkingBlock(
|
||||
signature="ErU/V+ayA==", thinking="", type="thinking"
|
||||
),
|
||||
tool_calls=[
|
||||
llm.ToolInput(
|
||||
id="srvtoolu_12345ABC",
|
||||
tool_name="web_search",
|
||||
tool_args={"query": "today's news"},
|
||||
external=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
conversation.chat_log.ToolResultContent(
|
||||
agent_id="conversation.claude_conversation",
|
||||
tool_call_id="srvtoolu_12345ABC",
|
||||
tool_name="web_search",
|
||||
tool_result={
|
||||
"content": [
|
||||
{
|
||||
"type": "web_search_result",
|
||||
"title": "Today's News - Example.com",
|
||||
"url": "https://www.example.com/todays-news",
|
||||
"page_age": "2 days ago",
|
||||
"encrypted_content": "ABCDEFG",
|
||||
},
|
||||
{
|
||||
"type": "web_search_result",
|
||||
"title": "Breaking News - NewsSite.com",
|
||||
"url": "https://www.newssite.com/breaking-news",
|
||||
"page_age": None,
|
||||
"encrypted_content": "ABCDEFG",
|
||||
},
|
||||
]
|
||||
},
|
||||
),
|
||||
conversation.chat_log.AssistantContent(
|
||||
agent_id="conversation.claude_conversation",
|
||||
content="Here's what I found on the web about today's news:\n"
|
||||
"1. New Home Assistant release\n"
|
||||
"2. Something incredible happened\n"
|
||||
"Those are the main headlines making news today.",
|
||||
native=ContentDetails(
|
||||
citation_details=[
|
||||
CitationDetails(
|
||||
index=54,
|
||||
length=26,
|
||||
citations=[
|
||||
CitationWebSearchResultLocationParam(
|
||||
type="web_search_result_location",
|
||||
cited_text="This release iterates on some of the features we introduced in the last couple of releases, but also...",
|
||||
encrypted_index="AAA==",
|
||||
title="Home Assistant Release",
|
||||
url="https://www.example.com/todays-news",
|
||||
),
|
||||
],
|
||||
),
|
||||
CitationDetails(
|
||||
index=84,
|
||||
length=29,
|
||||
citations=[
|
||||
CitationWebSearchResultLocationParam(
|
||||
type="web_search_result_location",
|
||||
cited_text="Breaking news from around the world today includes major events in technology, politics, and culture...",
|
||||
encrypted_index="AQE=",
|
||||
title="Breaking News",
|
||||
url="https://www.newssite.com/breaking-news",
|
||||
),
|
||||
CitationWebSearchResultLocationParam(
|
||||
type="web_search_result_location",
|
||||
cited_text="Well, this happened...",
|
||||
encrypted_index="AgI=",
|
||||
title="Breaking News",
|
||||
url="https://www.newssite.com/breaking-news",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
],
|
||||
],
|
||||
)
|
||||
async def test_history_conversion(
|
||||
|
||||
Reference in New Issue
Block a user