mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 01:38:02 +00:00
Stronger type annotations for conversation content (#140725)
stronger type annotations for conversation content
This commit is contained in:
parent
012b4645f3
commit
056616f9c5
@ -51,8 +51,7 @@ def async_get_chat_log(
|
||||
)
|
||||
if user_input is not None and (
|
||||
(content := chat_log.content[-1]).role != "user"
|
||||
# MyPy doesn't understand that content is a UserContent here
|
||||
or content.content != user_input.text # type: ignore[union-attr]
|
||||
or content.content != user_input.text
|
||||
):
|
||||
chat_log.async_add_user_content(UserContent(content=user_input.text))
|
||||
|
||||
@ -128,7 +127,7 @@ class ConverseError(HomeAssistantError):
|
||||
class SystemContent:
|
||||
"""Base class for chat messages."""
|
||||
|
||||
role: str = field(init=False, default="system")
|
||||
role: Literal["system"] = field(init=False, default="system")
|
||||
content: str
|
||||
|
||||
|
||||
@ -136,7 +135,7 @@ class SystemContent:
|
||||
class UserContent:
|
||||
"""Assistant content."""
|
||||
|
||||
role: str = field(init=False, default="user")
|
||||
role: Literal["user"] = field(init=False, default="user")
|
||||
content: str
|
||||
|
||||
|
||||
@ -144,7 +143,7 @@ class UserContent:
|
||||
class AssistantContent:
|
||||
"""Assistant content."""
|
||||
|
||||
role: str = field(init=False, default="assistant")
|
||||
role: Literal["assistant"] = field(init=False, default="assistant")
|
||||
agent_id: str
|
||||
content: str | None = None
|
||||
tool_calls: list[llm.ToolInput] | None = None
|
||||
@ -154,7 +153,7 @@ class AssistantContent:
|
||||
class ToolResultContent:
|
||||
"""Tool result content."""
|
||||
|
||||
role: str = field(init=False, default="tool_result")
|
||||
role: Literal["tool_result"] = field(init=False, default="tool_result")
|
||||
agent_id: str
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
@ -193,8 +192,8 @@ class ChatLog:
|
||||
|
||||
return (
|
||||
last_msg.role == "assistant"
|
||||
and last_msg.content is not None # type: ignore[union-attr]
|
||||
and last_msg.content.strip().endswith( # type: ignore[union-attr]
|
||||
and last_msg.content is not None
|
||||
and last_msg.content.strip().endswith(
|
||||
(
|
||||
"?",
|
||||
";", # Greek question mark
|
||||
|
@ -188,7 +188,7 @@ def _convert_content(
|
||||
| conversation.SystemContent,
|
||||
) -> Content:
|
||||
"""Convert HA content to Google content."""
|
||||
if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr]
|
||||
if content.role != "assistant" or not content.tool_calls:
|
||||
role = "model" if content.role == "assistant" else content.role
|
||||
return Content(
|
||||
role=role,
|
||||
@ -321,24 +321,14 @@ class GoogleGenerativeAIConversationEntity(
|
||||
|
||||
for chat_content in chat_log.content[1:-1]:
|
||||
if chat_content.role == "tool_result":
|
||||
# mypy doesn't like picking a type based on checking shared property 'role'
|
||||
tool_results.append(cast(conversation.ToolResultContent, chat_content))
|
||||
tool_results.append(chat_content)
|
||||
continue
|
||||
|
||||
if tool_results:
|
||||
messages.append(_create_google_tool_response_content(tool_results))
|
||||
tool_results.clear()
|
||||
|
||||
messages.append(
|
||||
_convert_content(
|
||||
cast(
|
||||
conversation.UserContent
|
||||
| conversation.SystemContent
|
||||
| conversation.AssistantContent,
|
||||
chat_content,
|
||||
)
|
||||
)
|
||||
)
|
||||
messages.append(_convert_content(chat_content))
|
||||
|
||||
if tool_results:
|
||||
messages.append(_create_google_tool_response_content(tool_results))
|
||||
|
@ -82,13 +82,13 @@ def _convert_content_to_param(
|
||||
tool_call_id=content.tool_call_id,
|
||||
content=json.dumps(content.tool_result),
|
||||
)
|
||||
if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr]
|
||||
role = content.role
|
||||
if content.role != "assistant" or not content.tool_calls:
|
||||
role: Literal["system", "user", "assistant", "developer"] = content.role
|
||||
if role == "system":
|
||||
role = "developer"
|
||||
return cast(
|
||||
ChatCompletionMessageParam,
|
||||
{"role": content.role, "content": content.content}, # type: ignore[union-attr]
|
||||
{"role": content.role, "content": content.content},
|
||||
)
|
||||
|
||||
# Handle the Assistant content including tool calls.
|
||||
|
Loading…
x
Reference in New Issue
Block a user