Improve conversation typing (#136084)

This commit is contained in:
Marc Mueller 2025-01-20 15:32:18 +01:00 committed by GitHub
parent a7d5e52ffe
commit 29b7d5c2e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,11 +1,13 @@
"""Conversation history."""
from __future__ import annotations
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field, replace
from datetime import datetime, timedelta
import logging
from typing import Generic, Literal, TypeVar
from typing import Literal
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import (
@ -25,16 +27,15 @@ from homeassistant.util.hass_dict import HassKey
from .const import DOMAIN
from .models import ConversationInput, ConversationResult
DATA_CHAT_HISTORY: HassKey["dict[str, ChatSession]"] = HassKey(
DATA_CHAT_HISTORY: HassKey[dict[str, ChatSession]] = HassKey(
"conversation_chat_session"
)
DATA_CHAT_HISTORY_CLEANUP: HassKey["SessionCleanup"] = HassKey(
DATA_CHAT_HISTORY_CLEANUP: HassKey[SessionCleanup] = HassKey(
"conversation_chat_session_cleanup"
)
LOGGER = logging.getLogger(__name__)
CONVERSATION_TIMEOUT = timedelta(minutes=5)
_NativeT = TypeVar("_NativeT")
class SessionCleanup:
@ -89,7 +90,7 @@ class SessionCleanup:
async def async_get_chat_session(
hass: HomeAssistant,
user_input: ConversationInput,
) -> AsyncGenerator["ChatSession"]:
) -> AsyncGenerator[ChatSession]:
"""Return chat session."""
all_history = hass.data.get(DATA_CHAT_HISTORY)
if all_history is None:
@ -164,7 +165,7 @@ class ConverseError(HomeAssistantError):
@dataclass
class ChatMessage(Generic[_NativeT]):
class ChatMessage[_NativeT]:
"""Base class for chat messages.
When role is native, the content is to be ignored and message
@ -184,7 +185,7 @@ class ChatMessage(Generic[_NativeT]):
@dataclass
class ChatSession(Generic[_NativeT]):
class ChatSession[_NativeT]:
"""Class holding all information for a specific conversation."""
hass: HomeAssistant