Merge branch 'dev' into dev

Jonathan Sady do Nascimento 2025-02-03 12:54:57 -03:00 committed by GitHub
commit 3ceb25adae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
84 changed files with 1802 additions and 1296 deletions

View File

@@ -7,6 +7,6 @@
   "integration_type": "service",
   "iot_class": "cloud_polling",
   "loggers": ["python_homeassistant_analytics"],
-  "requirements": ["python-homeassistant-analytics==0.8.1"],
+  "requirements": ["python-homeassistant-analytics==0.9.0"],
   "single_config_entry": true
 }

View File

@@ -272,6 +272,7 @@ class AnthropicConversationEntity(
                     continue
                 tool_input = llm.ToolInput(
+                    id=tool_call.id,
                     tool_name=tool_call.name,
                     tool_args=cast(dict[str, Any], tool_call.input),
                 )

View File

@@ -9,6 +9,7 @@ import voluptuous as vol

 from homeassistant.components import stt
 from homeassistant.core import Context, HomeAssistant
+from homeassistant.helpers import chat_session
 from homeassistant.helpers.typing import ConfigType

 from .const import (
@@ -114,8 +115,9 @@ async def async_pipeline_from_audio_stream(

     Raises PipelineNotFound if no pipeline is found.
     """
-    pipeline_input = PipelineInput(
-        conversation_id=conversation_id,
-        device_id=device_id,
-        stt_metadata=stt_metadata,
-        stt_stream=stt_stream,
+    with chat_session.async_get_chat_session(hass, conversation_id) as session:
+        pipeline_input = PipelineInput(
+            conversation_id=session.conversation_id,
+            device_id=device_id,
+            stt_metadata=stt_metadata,
+            stt_stream=stt_stream,

View File

@@ -624,7 +624,7 @@ class PipelineRun:
             return
         pipeline_data.pipeline_debug[self.pipeline.id][self.id].events.append(event)

-    def start(self, device_id: str | None) -> None:
+    def start(self, conversation_id: str, device_id: str | None) -> None:
         """Emit run start event."""
         self._device_id = device_id
         self._start_debug_recording_thread()
@@ -632,6 +632,7 @@ class PipelineRun:
         data = {
             "pipeline": self.pipeline.id,
             "language": self.language,
+            "conversation_id": conversation_id,
         }
         if self.runner_data is not None:
             data["runner_data"] = self.runner_data
@@ -1015,7 +1016,7 @@ class PipelineRun:
     async def recognize_intent(
         self,
         intent_input: str,
-        conversation_id: str | None,
+        conversation_id: str,
         device_id: str | None,
         conversation_extra_system_prompt: str | None,
     ) -> str:
@@ -1063,11 +1064,11 @@ class PipelineRun:
                 agent_id=self.intent_agent,
                 extra_system_prompt=conversation_extra_system_prompt,
             )

-            processed_locally = self.intent_agent == conversation.HOME_ASSISTANT_AGENT
-            agent_id = user_input.agent_id
+            agent_id = self.intent_agent
+            processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT

             intent_response: intent.IntentResponse | None = None
-            if user_input.agent_id != conversation.HOME_ASSISTANT_AGENT:
+            if not processed_locally:
                 # Sentence triggers override conversation agent
                 if (
                     trigger_response_text
@@ -1105,13 +1106,13 @@ class PipelineRun:
                     speech: str = intent_response.speech.get("plain", {}).get(
                         "speech", ""
                     )
-                    chat_log.async_add_message(
-                        conversation.Content(
-                            role="assistant",
+                    async for _ in chat_log.async_add_assistant_content(
+                        conversation.AssistantContent(
                             agent_id=agent_id,
                             content=speech,
                         )
-                    )
+                    ):
+                        pass

                 conversation_result = conversation.ConversationResult(
                     response=intent_response,
                     conversation_id=session.conversation_id,
@@ -1409,12 +1410,15 @@ def _pipeline_debug_recording_thread_proc(
     wav_writer.close()


-@dataclass
+@dataclass(kw_only=True)
 class PipelineInput:
     """Input to a pipeline run."""

     run: PipelineRun

+    conversation_id: str
+    """Identifier for the conversation."""
+
     stt_metadata: stt.SpeechMetadata | None = None
     """Metadata of stt input audio. Required when start_stage = stt."""
@@ -1430,9 +1434,6 @@ class PipelineInput:
     tts_input: str | None = None
     """Input for text-to-speech. Required when start_stage = tts."""

-    conversation_id: str | None = None
-    """Identifier for the conversation."""
-
     conversation_extra_system_prompt: str | None = None
     """Extra prompt information for the conversation agent."""
@@ -1441,7 +1442,7 @@ class PipelineInput:
     async def execute(self) -> None:
         """Run pipeline."""
-        self.run.start(device_id=self.device_id)
+        self.run.start(conversation_id=self.conversation_id, device_id=self.device_id)
         current_stage: PipelineStage | None = self.run.start_stage
         stt_audio_buffer: list[EnhancedAudioChunk] = []
         stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None

View File

@@ -14,7 +14,11 @@ import voluptuous as vol
 from homeassistant.components import conversation, stt, tts, websocket_api
 from homeassistant.const import ATTR_DEVICE_ID, ATTR_SECONDS, MATCH_ALL
 from homeassistant.core import HomeAssistant, callback
-from homeassistant.helpers import config_validation as cv, entity_registry as er
+from homeassistant.helpers import (
+    chat_session,
+    config_validation as cv,
+    entity_registry as er,
+)
 from homeassistant.util import language as language_util

 from .const import (
@@ -145,7 +149,6 @@ async def websocket_run(
     # Arguments to PipelineInput
     input_args: dict[str, Any] = {
-        "conversation_id": msg.get("conversation_id"),
         "device_id": msg.get("device_id"),
     }
@@ -233,6 +236,10 @@ async def websocket_run(
         audio_settings=audio_settings or AudioSettings(),
     )

+    with chat_session.async_get_chat_session(
+        hass, msg.get("conversation_id")
+    ) as session:
+        input_args["conversation_id"] = session.conversation_id
     pipeline_input = PipelineInput(**input_args)

     try:

View File

@@ -8,7 +8,7 @@ from dataclasses import dataclass
 from enum import StrEnum
 import logging
 import time
-from typing import Any, Final, Literal, final
+from typing import Any, Literal, final

 from homeassistant.components import conversation, media_source, stt, tts
 from homeassistant.components.assist_pipeline import (
@@ -28,14 +28,12 @@ from homeassistant.components.tts import (
 )
 from homeassistant.core import Context, callback
 from homeassistant.exceptions import HomeAssistantError
-from homeassistant.helpers import entity
+from homeassistant.helpers import chat_session, entity
 from homeassistant.helpers.entity import EntityDescription

 from .const import AssistSatelliteEntityFeature
 from .errors import AssistSatelliteError, SatelliteBusyError

-_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60  # 5 minutes
-
 _LOGGER = logging.getLogger(__name__)
@@ -114,7 +112,6 @@ class AssistSatelliteEntity(entity.Entity):
     _attr_vad_sensitivity_entity_id: str | None = None

     _conversation_id: str | None = None
-    _conversation_id_time: float | None = None

     _run_has_tts: bool = False
     _is_announcing = False
@@ -260,6 +257,21 @@ class AssistSatelliteEntity(entity.Entity):
         else:
             self._extra_system_prompt = start_message or None

+        with (
+            # Not passing in a conversation ID will force a new one to be created
+            chat_session.async_get_chat_session(self.hass) as session,
+            conversation.async_get_chat_log(self.hass, session) as chat_log,
+        ):
+            self._conversation_id = session.conversation_id
+
+            if start_message:
+                async for _tool_response in chat_log.async_add_assistant_content(
+                    conversation.AssistantContent(
+                        agent_id=self.entity_id, content=start_message
+                    )
+                ):
+                    pass  # no tool responses.
+
         try:
             await self.async_start_conversation(announcement)
         finally:
@@ -325,18 +337,18 @@ class AssistSatelliteEntity(entity.Entity):
         assert self._context is not None

-        # Reset conversation id if necessary
-        if self._conversation_id_time and (
-            (time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
-        ):
-            self._conversation_id = None
-            self._conversation_id_time = None
-
         # Set entity state based on pipeline events
         self._run_has_tts = False

         assert self.platform.config_entry is not None
-        self._pipeline_task = self.platform.config_entry.async_create_background_task(
+
+        with chat_session.async_get_chat_session(
+            self.hass, self._conversation_id
+        ) as session:
+            # Store the conversation ID. If it is no longer valid, get_chat_session will reset it
+            self._conversation_id = session.conversation_id
+
+            self._pipeline_task = (
+                self.platform.config_entry.async_create_background_task(
                     self.hass,
                     async_pipeline_from_audio_stream(
                         self.hass,
@@ -352,7 +364,7 @@ class AssistSatelliteEntity(entity.Entity):
                         ),
                         stt_stream=audio_stream,
                         pipeline_id=self._resolve_pipeline(),
-                        conversation_id=self._conversation_id,
+                        conversation_id=session.conversation_id,
                         device_id=device_id,
                         tts_audio_output=self.tts_options,
                         wake_word_phrase=wake_word_phrase,
@@ -365,6 +377,7 @@ class AssistSatelliteEntity(entity.Entity):
                     ),
                     f"{self.entity_id}_pipeline",
                 )
+            )

         try:
             await self._pipeline_task
@@ -393,11 +406,6 @@ class AssistSatelliteEntity(entity.Entity):
             self._set_state(AssistSatelliteState.LISTENING)
         elif event.type is PipelineEventType.INTENT_START:
             self._set_state(AssistSatelliteState.PROCESSING)
-        elif event.type is PipelineEventType.INTENT_END:
-            assert event.data is not None
-            # Update timeout
-            self._conversation_id_time = time.monotonic()
-            self._conversation_id = event.data["intent_output"]["conversation_id"]
         elif event.type is PipelineEventType.TTS_START:
             # Wait until tts_response_finished is called to return to waiting state
             self._run_has_tts = True
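
Note: with this change the satellite no longer tracks its own 5-minute conversation timeout; the chat_session helper owns the conversation ID lifecycle. A minimal sketch of the pattern used above, for readers of this commit (the helper function and its name are illustrative assumptions, not part of the commit; it must run inside the event loop):

    from homeassistant.core import HomeAssistant, callback
    from homeassistant.helpers import chat_session

    @callback
    def resolve_conversation_id(hass: HomeAssistant, stored_id: str | None) -> str:
        """Reuse stored_id while its session is alive; otherwise a new ID is created."""
        with chat_session.async_get_chat_session(hass, stored_id) as session:
            # session.conversation_id is either the still-valid stored ID or a fresh one.
            return session.conversation_id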

View File

@@ -19,6 +19,8 @@ from .const import (
 )
 from .entity import BangOlufsenEntity

+PARALLEL_UPDATES = 0
+

 async def async_setup_entry(
     hass: HomeAssistant,

View File

@@ -20,7 +20,7 @@
     "bluetooth-adapters==0.21.1",
     "bluetooth-auto-recovery==1.4.2",
     "bluetooth-data-tools==1.23.3",
-    "dbus-fast==2.31.0",
+    "dbus-fast==2.32.0",
     "habluetooth==3.21.0"
   ]
 }

View File

@@ -30,6 +30,16 @@ from .agent_manager import (
     async_get_agent,
     get_agent_manager,
 )
+from .chat_log import (
+    AssistantContent,
+    ChatLog,
+    Content,
+    ConverseError,
+    SystemContent,
+    ToolResultContent,
+    UserContent,
+    async_get_chat_log,
+)
 from .const import (
     ATTR_AGENT_ID,
     ATTR_CONVERSATION_ID,
@@ -48,13 +58,13 @@ from .default_agent import DefaultAgent, async_setup_default_agent
 from .entity import ConversationEntity
 from .http import async_setup as async_setup_conversation_http
 from .models import AbstractConversationAgent, ConversationInput, ConversationResult
-from .session import ChatLog, Content, ConverseError, NativeContent, async_get_chat_log
 from .trace import ConversationTraceEventType, async_conversation_trace_append

 __all__ = [
     "DOMAIN",
     "HOME_ASSISTANT_AGENT",
     "OLD_HOME_ASSISTANT_AGENT",
+    "AssistantContent",
     "ChatLog",
     "Content",
     "ConversationEntity",
@@ -63,7 +73,9 @@ __all__ = [
     "ConversationResult",
     "ConversationTraceEventType",
     "ConverseError",
-    "NativeContent",
+    "SystemContent",
+    "ToolResultContent",
+    "UserContent",
     "async_conversation_trace_append",
     "async_converse",
     "async_get_agent_info",

View File

@@ -2,19 +2,16 @@

 from __future__ import annotations

-from collections.abc import Generator
+from collections.abc import AsyncGenerator, Generator
 from contextlib import contextmanager
 from dataclasses import dataclass, field, replace
-from datetime import datetime
 import logging
-from typing import Literal

 import voluptuous as vol

 from homeassistant.core import HomeAssistant, callback
 from homeassistant.exceptions import HomeAssistantError, TemplateError
 from homeassistant.helpers import chat_session, intent, llm, template
-from homeassistant.util import dt as dt_util
 from homeassistant.util.hass_dict import HassKey
 from homeassistant.util.json import JsonObjectType
@@ -31,7 +28,7 @@ LOGGER = logging.getLogger(__name__)
 def async_get_chat_log(
     hass: HomeAssistant,
     session: chat_session.ChatSession,
-    user_input: ConversationInput,
+    user_input: ConversationInput | None = None,
 ) -> Generator[ChatLog]:
     """Return chat log for a specific chat session."""
     all_history = hass.data.get(DATA_CHAT_HISTORY)
@@ -42,9 +39,9 @@ def async_get_chat_log(
     history = all_history.get(session.conversation_id)

     if history:
-        history = replace(history, messages=history.messages.copy())
+        history = replace(history, content=history.content.copy())
     else:
-        history = ChatLog(hass, session.conversation_id, user_input.agent_id)
+        history = ChatLog(hass, session.conversation_id)

     @callback
     def do_cleanup() -> None:
@@ -53,22 +50,19 @@ def async_get_chat_log(

     session.async_on_cleanup(do_cleanup)

-    message: Content = Content(
-        role="user",
-        agent_id=user_input.agent_id,
-        content=user_input.text,
-    )
-    history.async_add_message(message)
+    if user_input is not None:
+        history.async_add_user_content(UserContent(content=user_input.text))
+
+    last_message = history.content[-1]

     yield history

-    if history.messages[-1] is message:
+    if history.content[-1] is last_message:
         LOGGER.debug(
             "History opened but no assistant message was added, ignoring update"
         )
         return

-    history.last_updated = dt_util.utcnow()
     all_history[session.conversation_id] = history
@@ -94,63 +88,94 @@ class ConverseError(HomeAssistantError):
         )


-@dataclass
-class Content:
+@dataclass(frozen=True)
+class SystemContent:
     """Base class for chat messages."""

-    role: Literal["system", "assistant", "user"]
-    agent_id: str | None
+    role: str = field(init=False, default="system")
     content: str


 @dataclass(frozen=True)
-class NativeContent[_NativeT]:
-    """Native content."""
+class UserContent:
+    """Assistant content."""

-    role: str = field(init=False, default="native")
-    agent_id: str
-    content: _NativeT
+    role: str = field(init=False, default="user")
+    content: str
+
+
+@dataclass(frozen=True)
+class AssistantContent:
+    """Assistant content."""
+
+    role: str = field(init=False, default="assistant")
+    agent_id: str
+    content: str
+    tool_calls: list[llm.ToolInput] | None = None
+
+
+@dataclass(frozen=True)
+class ToolResultContent:
+    """Tool result content."""
+
+    role: str = field(init=False, default="tool_result")
+    agent_id: str
+    tool_call_id: str
+    tool_name: str
+    tool_result: JsonObjectType
+
+
+Content = SystemContent | UserContent | AssistantContent | ToolResultContent


 @dataclass
-class ChatLog[_NativeT]:
+class ChatLog:
     """Class holding the chat history of a specific conversation."""

     hass: HomeAssistant
     conversation_id: str
-    agent_id: str | None
-    user_name: str | None = None
-    messages: list[Content | NativeContent[_NativeT]] = field(
-        default_factory=lambda: [Content(role="system", agent_id=None, content="")]
-    )
+    content: list[Content] = field(default_factory=lambda: [SystemContent(content="")])
     extra_system_prompt: str | None = None
     llm_api: llm.APIInstance | None = None
-    last_updated: datetime = field(default_factory=dt_util.utcnow)

     @callback
-    def async_add_message(self, message: Content | NativeContent[_NativeT]) -> None:
-        """Process intent."""
-        if message.role == "system":
-            raise ValueError("Cannot add system messages to history")
-        if message.role != "native" and self.messages[-1].role == message.role:
-            raise ValueError("Cannot add two assistant or user messages in a row")
-        self.messages.append(message)
+    def async_add_user_content(self, content: UserContent) -> None:
+        """Add user content to the log."""
+        self.content.append(content)

-    @callback
-    def async_get_messages(
-        self, agent_id: str | None = None
-    ) -> list[Content | NativeContent[_NativeT]]:
-        """Get messages for a specific agent ID.
-
-        This will filter out any native message tied to other agent IDs.
-        It can still include assistant/user messages generated by other agents.
-        """
-        return [
-            message
-            for message in self.messages
-            if message.role != "native" or message.agent_id == agent_id
-        ]
+    async def async_add_assistant_content(
+        self, content: AssistantContent
+    ) -> AsyncGenerator[ToolResultContent]:
+        """Add assistant content."""
+        self.content.append(content)
+
+        if content.tool_calls is None:
+            return
+
+        if self.llm_api is None:
+            raise ValueError("No LLM API configured")
+
+        for tool_input in content.tool_calls:
+            LOGGER.debug(
+                "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
+            )
+
+            try:
+                tool_result = await self.llm_api.async_call_tool(tool_input)
+            except (HomeAssistantError, vol.Invalid) as e:
+                tool_result = {"error": type(e).__name__}
+                if str(e):
+                    tool_result["error_text"] = str(e)
+            LOGGER.debug("Tool response: %s", tool_result)
+
+            response_content = ToolResultContent(
+                agent_id=content.agent_id,
+                tool_call_id=tool_input.id,
+                tool_name=tool_input.tool_name,
+                tool_result=tool_result,
+            )
+            self.content.append(response_content)
+            yield response_content

     async def async_update_llm_data(
         self,
@@ -250,36 +275,16 @@ class ChatLog[_NativeT]:
         prompt = "\n".join(prompt_parts)

         self.llm_api = llm_api
-        self.user_name = user_name
         self.extra_system_prompt = extra_system_prompt
-        self.messages[0] = Content(
-            role="system",
-            agent_id=user_input.agent_id,
-            content=prompt,
-        )
+        self.content[0] = SystemContent(content=prompt)

-        LOGGER.debug("Prompt: %s", self.messages)
+        LOGGER.debug("Prompt: %s", self.content)
         LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None)

         trace.async_conversation_trace_append(
             trace.ConversationTraceEventType.AGENT_DETAIL,
             {
-                "messages": self.messages,
+                "messages": self.content,
                 "tools": self.llm_api.tools if self.llm_api else None,
             },
         )
-
-    async def async_call_tool(self, tool_input: llm.ToolInput) -> JsonObjectType:
-        """Invoke LLM tool for the configured LLM API."""
-        if not self.llm_api:
-            raise ValueError("No LLM API configured")
-
-        LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args)
-
-        try:
-            tool_response = await self.llm_api.async_call_tool(tool_input)
-        except (HomeAssistantError, vol.Invalid) as e:
-            tool_response = {"error": type(e).__name__}
-            if str(e):
-                tool_response["error_text"] = str(e)
-
-        LOGGER.debug("Tool response: %s", tool_response)
-        return tool_response

View File

@@ -55,6 +55,7 @@ from homeassistant.helpers.entity_component import EntityComponent
 from homeassistant.helpers.event import async_track_state_added_domain
 from homeassistant.util.json import JsonObjectType, json_loads_object

+from .chat_log import AssistantContent, async_get_chat_log
 from .const import (
     DATA_DEFAULT_ENTITY,
     DEFAULT_EXPOSED_ATTRIBUTES,
@@ -63,7 +64,6 @@ from .const import (
 )
 from .entity import ConversationEntity
 from .models import ConversationInput, ConversationResult
-from .session import Content, async_get_chat_log
 from .trace import ConversationTraceEventType, async_conversation_trace_append

 _LOGGER = logging.getLogger(__name__)
@@ -379,13 +379,13 @@ class DefaultAgent(ConversationEntity):
         )
         speech: str = response.speech.get("plain", {}).get("speech", "")

-        chat_log.async_add_message(
-            Content(
-                role="assistant",
-                agent_id=user_input.agent_id,
+        async for _tool_result in chat_log.async_add_assistant_content(
+            AssistantContent(
+                agent_id=user_input.agent_id,  # type: ignore[arg-type]
                 content=speech,
             )
-        )
+        ):
+            pass

         return ConversationResult(
             response=response, conversation_id=session.conversation_id

View File

@@ -22,5 +22,5 @@
   "integration_type": "device",
   "iot_class": "local_polling",
   "loggers": ["eq3btsmart"],
-  "requirements": ["eq3btsmart==1.4.1", "bleak-esphome==2.6.0"]
+  "requirements": ["eq3btsmart==1.4.1", "bleak-esphome==2.7.0"]
 }

View File

@@ -18,7 +18,7 @@
   "requirements": [
     "aioesphomeapi==29.0.0",
     "esphome-dashboard-api==1.2.3",
-    "bleak-esphome==2.6.0"
+    "bleak-esphome==2.7.0"
   ],
   "zeroconf": ["_esphomelib._tcp.local."]
 }

View File

@@ -28,14 +28,14 @@
       "user": {
         "description": "Enter the settings to connect to the camera.",
         "data": {
-          "still_image_url": "Still Image URL (e.g. http://...)",
-          "stream_source": "Stream Source URL (e.g. rtsp://...)",
+          "still_image_url": "Still image URL (e.g. http://...)",
+          "stream_source": "Stream source URL (e.g. rtsp://...)",
           "rtsp_transport": "RTSP transport protocol",
           "authentication": "Authentication",
-          "limit_refetch_to_url_change": "Limit refetch to url change",
+          "limit_refetch_to_url_change": "Limit refetch to URL change",
           "password": "[%key:common::config_flow::data::password%]",
           "username": "[%key:common::config_flow::data::username%]",
-          "framerate": "Frame Rate (Hz)",
+          "framerate": "Frame rate (Hz)",
           "verify_ssl": "[%key:common::config_flow::data::verify_ssl%]"
         }
       },

View File

@@ -4,7 +4,7 @@ from __future__ import annotations

 import codecs
 from collections.abc import Callable
-from typing import Any, Literal
+from typing import Any, Literal, cast

 from google.api_core.exceptions import GoogleAPIError
 import google.generativeai as genai
@@ -149,15 +149,53 @@ def _escape_decode(value: Any) -> Any:
     return value


-def _chat_message_convert(
-    message: conversation.Content | conversation.NativeContent[genai_types.ContentDict],
-) -> genai_types.ContentDict:
-    """Convert any native chat message for this agent to the native format."""
-    if message.role == "native":
-        return message.content
-
-    role = "model" if message.role == "assistant" else message.role
-    return {"role": role, "parts": message.content}
+def _create_google_tool_response_content(
+    content: list[conversation.ToolResultContent],
+) -> protos.Content:
+    """Create a Google tool response content."""
+    return protos.Content(
+        parts=[
+            protos.Part(
+                function_response=protos.FunctionResponse(
+                    name=tool_result.tool_name, response=tool_result.tool_result
+                )
+            )
+            for tool_result in content
+        ]
+    )
+
+
+def _convert_content(
+    content: conversation.UserContent
+    | conversation.AssistantContent
+    | conversation.SystemContent,
+) -> genai_types.ContentDict:
+    """Convert HA content to Google content."""
+    if content.role != "assistant" or not content.tool_calls:  # type: ignore[union-attr]
+        role = "model" if content.role == "assistant" else content.role
+        return {"role": role, "parts": content.content}
+
+    # Handle the Assistant content with tool calls.
+    assert type(content) is conversation.AssistantContent
+    parts = []
+
+    if content.content:
+        parts.append(protos.Part(text=content.content))
+
+    if content.tool_calls:
+        parts.extend(
+            [
+                protos.Part(
+                    function_call=protos.FunctionCall(
+                        name=tool_call.tool_name,
+                        args=_escape_decode(tool_call.tool_args),
+                    )
+                )
+                for tool_call in content.tool_calls
+            ]
+        )
+
+    return protos.Content({"role": "model", "parts": parts})


 class GoogleGenerativeAIConversationEntity(
@@ -220,7 +258,7 @@
     async def _async_handle_message(
         self,
         user_input: conversation.ConversationInput,
-        session: conversation.ChatLog[genai_types.ContentDict],
+        chat_log: conversation.ChatLog,
     ) -> conversation.ConversationResult:
         """Call the API."""
@@ -228,7 +266,7 @@
         options = self.entry.options

         try:
-            await session.async_update_llm_data(
+            await chat_log.async_update_llm_data(
                 DOMAIN,
                 user_input,
                 options.get(CONF_LLM_HASS_API),
@@ -238,10 +276,10 @@
             return err.as_conversation_result()

         tools: list[dict[str, Any]] | None = None
-        if session.llm_api:
+        if chat_log.llm_api:
             tools = [
-                _format_tool(tool, session.llm_api.custom_serializer)
-                for tool in session.llm_api.tools
+                _format_tool(tool, chat_log.llm_api.custom_serializer)
+                for tool in chat_log.llm_api.tools
             ]

         model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
@@ -252,9 +290,36 @@
             "gemini-1.0" not in model_name and "gemini-pro" not in model_name
         )

-        prompt, *messages = [
-            _chat_message_convert(message) for message in session.async_get_messages()
-        ]
+        prompt = chat_log.content[0].content  # type: ignore[union-attr]
+        messages: list[genai_types.ContentDict] = []
+
+        # Google groups tool results, we do not. Group them before sending.
+        tool_results: list[conversation.ToolResultContent] = []
+
+        for chat_content in chat_log.content[1:]:
+            if chat_content.role == "tool_result":
+                # mypy doesn't like picking a type based on checking shared property 'role'
+                tool_results.append(cast(conversation.ToolResultContent, chat_content))
+                continue
+
+            if tool_results:
+                messages.append(_create_google_tool_response_content(tool_results))
+                tool_results.clear()
+
+            messages.append(
+                _convert_content(
+                    cast(
+                        conversation.UserContent
+                        | conversation.SystemContent
+                        | conversation.AssistantContent,
+                        chat_content,
+                    )
+                )
+            )
+
+        if tool_results:
+            messages.append(_create_google_tool_response_content(tool_results))
+
         model = genai.GenerativeModel(
             model_name=model_name,
             generation_config={
@@ -282,12 +347,12 @@
                 ),
             },
             tools=tools or None,
-            system_instruction=prompt["parts"] if supports_system_instruction else None,
+            system_instruction=prompt if supports_system_instruction else None,
         )

         if not supports_system_instruction:
             messages = [
-                {"role": "user", "parts": prompt["parts"]},
+                {"role": "user", "parts": prompt},
                 {"role": "model", "parts": "Ok"},
                 *messages,
             ]
@@ -325,50 +390,40 @@
             content = " ".join(
                 [part.text.strip() for part in chat_response.parts if part.text]
             )

-            if content:
-                session.async_add_message(
-                    conversation.Content(
-                        role="assistant",
-                        agent_id=user_input.agent_id,
-                        content=content,
-                    )
-                )
-
-            function_calls = [
-                part.function_call for part in chat_response.parts if part.function_call
-            ]
-
-            if not function_calls or not session.llm_api:
-                break
-
-            tool_responses = []
-            for function_call in function_calls:
-                tool_call = MessageToDict(function_call._pb)  # noqa: SLF001
+            tool_calls = []
+            for part in chat_response.parts:
+                if not part.function_call:
+                    continue
+                tool_call = MessageToDict(part.function_call._pb)  # noqa: SLF001
                 tool_name = tool_call["name"]
                 tool_args = _escape_decode(tool_call["args"])
-                tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
-                function_response = await session.async_call_tool(tool_input)
-                tool_responses.append(
-                    protos.Part(
-                        function_response=protos.FunctionResponse(
-                            name=tool_name, response=function_response
-                        )
-                    )
-                )
-            chat_request = protos.Content(parts=tool_responses)
-            session.async_add_message(
-                conversation.NativeContent(
-                    agent_id=user_input.agent_id,
-                    content=chat_request,
-                )
-            )
+                tool_calls.append(
+                    llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
+                )
+
+            chat_request = _create_google_tool_response_content(
+                [
+                    tool_response
+                    async for tool_response in chat_log.async_add_assistant_content(
+                        conversation.AssistantContent(
+                            agent_id=user_input.agent_id,
+                            content=content,
+                            tool_calls=tool_calls or None,
+                        )
+                    )
+                ]
+            )
+
+            if not tool_calls:
+                break

         response = intent.IntentResponse(language=user_input.language)
         response.async_set_speech(
             " ".join([part.text.strip() for part in chat_response.parts if part.text])
         )
         return conversation.ConversationResult(
-            response=response, conversation_id=session.conversation_id
+            response=response, conversation_id=chat_log.conversation_id
         )

     async def _async_entry_update_listener(

View File

@@ -517,17 +517,22 @@ class SupervisorBackupReaderWriter(BackupReaderWriter):
             raise HomeAssistantError(message) from err

         restore_complete = asyncio.Event()
+        restore_errors: list[dict[str, str]] = []

         @callback
         def on_job_progress(data: Mapping[str, Any]) -> None:
             """Handle backup restore progress."""
             if data.get("done") is True:
                 restore_complete.set()
+                restore_errors.extend(data.get("errors", []))

         unsub = self._async_listen_job_events(job.job_id, on_job_progress)
         try:
             await self._get_job_state(job.job_id, on_job_progress)
             await restore_complete.wait()
+            if restore_errors:
+                # We should add more specific error handling here in the future
+                raise BackupReaderWriterError(f"Restore failed: {restore_errors}")
         finally:
             unsub()
@@ -554,6 +559,18 @@ class SupervisorBackupReaderWriter(BackupReaderWriter):
             )
             return

+        restore_errors = data.get("errors", [])
+        if restore_errors:
+            _LOGGER.warning("Restore backup failed: %s", restore_errors)
+            # We should add more specific error handling here in the future
+            on_progress(
+                RestoreBackupEvent(
+                    reason="unknown_error",
+                    stage=None,
+                    state=RestoreBackupState.FAILED,
+                )
+            )
+        else:
             on_progress(
                 RestoreBackupEvent(
                     reason="", stage=None, state=RestoreBackupState.COMPLETED

View File

@@ -1,6 +1,7 @@
 """Constants for the homee integration."""

 from homeassistant.const import (
+    DEGREE,
     LIGHT_LUX,
     PERCENTAGE,
     REVOLUTIONS_PER_MINUTE,
@@ -32,6 +33,7 @@ HOMEE_UNIT_TO_HA_UNIT = {
     "W": UnitOfPower.WATT,
     "m/s": UnitOfSpeed.METERS_PER_SECOND,
     "km/h": UnitOfSpeed.KILOMETERS_PER_HOUR,
+    "°": DEGREE,
     "°F": UnitOfTemperature.FAHRENHEIT,
     "°C": UnitOfTemperature.CELSIUS,
     "K": UnitOfTemperature.KELVIN,
@@ -51,7 +53,7 @@ OPEN_CLOSE_MAP_REVERSED = {
     0.0: "closed",
     1.0: "open",
     2.0: "partial",
-    3.0: "cosing",
+    3.0: "closing",
     4.0: "opening",
 }
 WINDOW_MAP = {

View File

@@ -8,5 +8,5 @@
   "documentation": "https://www.home-assistant.io/integrations/lcn",
   "iot_class": "local_push",
   "loggers": ["pypck"],
-  "requirements": ["pypck==0.8.3", "lcn-frontend==0.2.3"]
+  "requirements": ["pypck==0.8.5", "lcn-frontend==0.2.3"]
 }

View File

@@ -30,7 +30,7 @@ class MotionMountEntity(Entity):
         self.config_entry = config_entry

         # We store the pin, as we might need it during reconnect
-        self.pin = config_entry.data[CONF_PIN]
+        self.pin = config_entry.data.get(CONF_PIN)

         mac = format_mac(mm.mac.hex())

View File

@@ -5,34 +5,33 @@ from __future__ import annotations
 from dataclasses import dataclass
 import logging

-from kiota_abstractions.api_error import APIError
-from kiota_abstractions.authentication import BaseBearerTokenAuthenticationProvider
-from msgraph import GraphRequestAdapter, GraphServiceClient
-from msgraph.generated.drives.item.items.items_request_builder import (
-    ItemsRequestBuilder,
-)
-from msgraph.generated.models.drive_item import DriveItem
-from msgraph.generated.models.folder import Folder
+from onedrive_personal_sdk import OneDriveClient
+from onedrive_personal_sdk.exceptions import (
+    AuthenticationError,
+    HttpRequestException,
+    OneDriveException,
+)

 from homeassistant.config_entries import ConfigEntry
 from homeassistant.core import HomeAssistant, callback
 from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
+from homeassistant.helpers.aiohttp_client import async_get_clientsession
 from homeassistant.helpers.config_entry_oauth2_flow import (
     OAuth2Session,
     async_get_config_entry_implementation,
 )
-from homeassistant.helpers.httpx_client import create_async_httpx_client
 from homeassistant.helpers.instance_id import async_get as async_get_instance_id

 from .api import OneDriveConfigEntryAccessTokenProvider
-from .const import DATA_BACKUP_AGENT_LISTENERS, DOMAIN, OAUTH_SCOPES
+from .const import DATA_BACKUP_AGENT_LISTENERS, DOMAIN


 @dataclass
 class OneDriveRuntimeData:
     """Runtime data for the OneDrive integration."""

-    items: ItemsRequestBuilder
+    client: OneDriveClient
+    token_provider: OneDriveConfigEntryAccessTokenProvider
     backup_folder_id: str
@@ -47,29 +46,18 @@ async def async_setup_entry(hass: HomeAssistant, entry: OneDriveConfigEntry) ->
     session = OAuth2Session(hass, entry, implementation)

-    auth_provider = BaseBearerTokenAuthenticationProvider(
-        access_token_provider=OneDriveConfigEntryAccessTokenProvider(session)
-    )
-    adapter = GraphRequestAdapter(
-        auth_provider=auth_provider,
-        client=create_async_httpx_client(hass, follow_redirects=True),
-    )
+    token_provider = OneDriveConfigEntryAccessTokenProvider(session)

-    graph_client = GraphServiceClient(
-        request_adapter=adapter,
-        scopes=OAUTH_SCOPES,
-    )
-    assert entry.unique_id
-    drive_item = graph_client.drives.by_drive_id(entry.unique_id)
+    client = OneDriveClient(token_provider, async_get_clientsession(hass))

     # get approot, will be created automatically if it does not exist
     try:
-        approot = await drive_item.special.by_drive_item_id("approot").get()
-    except APIError as err:
-        if err.response_status_code == 403:
-            raise ConfigEntryAuthFailed(
-                translation_domain=DOMAIN, translation_key="authentication_failed"
-            ) from err
+        approot = await client.get_approot()
+    except AuthenticationError as err:
+        raise ConfigEntryAuthFailed(
+            translation_domain=DOMAIN, translation_key="authentication_failed"
+        ) from err
+    except (HttpRequestException, OneDriveException, TimeoutError) as err:
         _LOGGER.debug("Failed to get approot", exc_info=True)
         raise ConfigEntryNotReady(
             translation_domain=DOMAIN,
@@ -77,24 +65,24 @@ async def async_setup_entry(hass: HomeAssistant, entry: OneDriveConfigEntry) ->
             translation_placeholders={"folder": "approot"},
         ) from err

-    if approot is None or not approot.id:
-        _LOGGER.debug("Failed to get approot, was None")
+    instance_id = await async_get_instance_id(hass)
+    backup_folder_name = f"backups_{instance_id[:8]}"
+    try:
+        backup_folder = await client.create_folder(
+            parent_id=approot.id, name=backup_folder_name
+        )
+    except (HttpRequestException, OneDriveException, TimeoutError) as err:
+        _LOGGER.debug("Failed to create backup folder", exc_info=True)
         raise ConfigEntryNotReady(
             translation_domain=DOMAIN,
             translation_key="failed_to_get_folder",
-            translation_placeholders={"folder": "approot"},
-        )
-
-    instance_id = await async_get_instance_id(hass)
-    backup_folder_id = await _async_create_folder_if_not_exists(
-        items=drive_item.items,
-        base_folder_id=approot.id,
-        folder=f"backups_{instance_id[:8]}",
-    )
+            translation_placeholders={"folder": backup_folder_name},
+        ) from err

     entry.runtime_data = OneDriveRuntimeData(
-        items=drive_item.items,
-        backup_folder_id=backup_folder_id,
+        client=client,
+        token_provider=token_provider,
+        backup_folder_id=backup_folder.id,
     )

     _async_notify_backup_listeners_soon(hass)
@@ -116,54 +104,3 @@ def _async_notify_backup_listeners(hass: HomeAssistant) -> None:
 @callback
 def _async_notify_backup_listeners_soon(hass: HomeAssistant) -> None:
     hass.loop.call_soon(_async_notify_backup_listeners, hass)
-
-
-async def _async_create_folder_if_not_exists(
-    items: ItemsRequestBuilder,
-    base_folder_id: str,
-    folder: str,
-) -> str:
-    """Check if a folder exists and create it if it does not exist."""
-    folder_item: DriveItem | None = None
-
-    try:
-        folder_item = await items.by_drive_item_id(f"{base_folder_id}:/{folder}:").get()
-    except APIError as err:
-        if err.response_status_code != 404:
-            _LOGGER.debug("Failed to get folder %s", folder, exc_info=True)
-            raise ConfigEntryNotReady(
-                translation_domain=DOMAIN,
-                translation_key="failed_to_get_folder",
-                translation_placeholders={"folder": folder},
-            ) from err
-        # is 404 not found, create folder
-        _LOGGER.debug("Creating folder %s", folder)
-        request_body = DriveItem(
-            name=folder,
-            folder=Folder(),
-            additional_data={
-                "@microsoft_graph_conflict_behavior": "fail",
-            },
-        )
-        try:
-            folder_item = await items.by_drive_item_id(base_folder_id).children.post(
-                request_body
-            )
-        except APIError as create_err:
-            _LOGGER.debug("Failed to create folder %s", folder, exc_info=True)
-            raise ConfigEntryNotReady(
-                translation_domain=DOMAIN,
-                translation_key="failed_to_create_folder",
-                translation_placeholders={"folder": folder},
-            ) from create_err
-        _LOGGER.debug("Created folder %s", folder)
-    else:
-        _LOGGER.debug("Found folder %s", folder)
-
-    if folder_item is None or not folder_item.id:
-        _LOGGER.debug("Failed to get folder %s, was None", folder)
-        raise ConfigEntryNotReady(
-            translation_domain=DOMAIN,
-            translation_key="failed_to_get_folder",
-            translation_placeholders={"folder": folder},
-        )
-
-    return folder_item.id

View File

@@ -1,28 +1,14 @@
 """API for OneDrive bound to Home Assistant OAuth."""

-from typing import Any, cast
+from typing import cast

-from kiota_abstractions.authentication import AccessTokenProvider, AllowedHostsValidator
+from onedrive_personal_sdk import TokenProvider

 from homeassistant.const import CONF_ACCESS_TOKEN
 from homeassistant.helpers import config_entry_oauth2_flow


-class OneDriveAccessTokenProvider(AccessTokenProvider):
-    """Provide OneDrive authentication tied to an OAuth2 based config entry."""
-
-    def __init__(self) -> None:
-        """Initialize OneDrive auth."""
-        super().__init__()
-        # currently allowing all hosts
-        self._allowed_hosts_validator = AllowedHostsValidator(allowed_hosts=[])
-
-    def get_allowed_hosts_validator(self) -> AllowedHostsValidator:
-        """Retrieve the allowed hosts validator."""
-        return self._allowed_hosts_validator
-
-
-class OneDriveConfigFlowAccessTokenProvider(OneDriveAccessTokenProvider):
+class OneDriveConfigFlowAccessTokenProvider(TokenProvider):
     """Provide OneDrive authentication tied to an OAuth2 based config entry."""

     def __init__(self, token: str) -> None:
@@ -30,14 +16,12 @@ class OneDriveConfigFlowAccessTokenProvider(OneDriveAccessTokenProvider):
         super().__init__()
         self._token = token

-    async def get_authorization_token(  # pylint: disable=dangerous-default-value
-        self, uri: str, additional_authentication_context: dict[str, Any] = {}
-    ) -> str:
-        """Return a valid authorization token."""
+    def async_get_access_token(self) -> str:
+        """Return a valid access token."""
         return self._token


-class OneDriveConfigEntryAccessTokenProvider(OneDriveAccessTokenProvider):
+class OneDriveConfigEntryAccessTokenProvider(TokenProvider):
     """Provide OneDrive authentication tied to an OAuth2 based config entry."""

     def __init__(self, oauth_session: config_entry_oauth2_flow.OAuth2Session) -> None:
@@ -45,9 +29,6 @@ class OneDriveConfigEntryAccessTokenProvider(OneDriveAccessTokenProvider):
         super().__init__()
         self._oauth_session = oauth_session

-    async def get_authorization_token(  # pylint: disable=dangerous-default-value
-        self, uri: str, additional_authentication_context: dict[str, Any] = {}
-    ) -> str:
-        """Return a valid authorization token."""
-        await self._oauth_session.async_ensure_token_valid()
+    def async_get_access_token(self) -> str:
+        """Return a valid access token."""
         return cast(str, self._oauth_session.token[CONF_ACCESS_TOKEN])

View File

@@ -2,37 +2,22 @@

 from __future__ import annotations

-import asyncio
 from collections.abc import AsyncIterator, Callable, Coroutine
 from functools import wraps
 import html
 import json
 import logging
-from typing import Any, Concatenate, cast
+from typing import Any, Concatenate

-from httpx import Response, TimeoutException
-from kiota_abstractions.api_error import APIError
-from kiota_abstractions.authentication import AnonymousAuthenticationProvider
-from kiota_abstractions.headers_collection import HeadersCollection
-from kiota_abstractions.method import Method
-from kiota_abstractions.native_response_handler import NativeResponseHandler
-from kiota_abstractions.request_information import RequestInformation
-from kiota_http.middleware.options import ResponseHandlerOption
-from msgraph import GraphRequestAdapter
-from msgraph.generated.drives.item.items.item.content.content_request_builder import (
-    ContentRequestBuilder,
-)
-from msgraph.generated.drives.item.items.item.create_upload_session.create_upload_session_post_request_body import (
-    CreateUploadSessionPostRequestBody,
-)
-from msgraph.generated.drives.item.items.item.drive_item_item_request_builder import (
-    DriveItemItemRequestBuilder,
-)
-from msgraph.generated.models.drive_item import DriveItem
-from msgraph.generated.models.drive_item_uploadable_properties import (
-    DriveItemUploadableProperties,
-)
-from msgraph_core.models import LargeFileUploadSession
+from aiohttp import ClientTimeout
+from onedrive_personal_sdk.clients.large_file_upload import LargeFileUploadClient
+from onedrive_personal_sdk.exceptions import (
+    AuthenticationError,
+    HashMismatchError,
+    OneDriveException,
+)
+from onedrive_personal_sdk.models.items import File, Folder, ItemUpdate
+from onedrive_personal_sdk.models.upload import FileInfo

 from homeassistant.components.backup import (
     AgentBackup,
@@ -41,14 +26,14 @@ from homeassistant.components.backup import (
     suggested_filename,
 )
 from homeassistant.core import HomeAssistant, callback
-from homeassistant.helpers.httpx_client import get_async_client
+from homeassistant.helpers.aiohttp_client import async_get_clientsession

 from . import OneDriveConfigEntry
 from .const import DATA_BACKUP_AGENT_LISTENERS, DOMAIN

 _LOGGER = logging.getLogger(__name__)
 UPLOAD_CHUNK_SIZE = 16 * 320 * 1024  # 5.2MB
-MAX_RETRIES = 5
+TIMEOUT = ClientTimeout(connect=10, total=43200)  # 12 hours


 async def async_get_backup_agents(
@@ -92,18 +77,18 @@ def handle_backup_errors[_R, **P](
     ) -> _R:
         try:
             return await func(self, *args, **kwargs)
-        except APIError as err:
-            if err.response_status_code == 403:
-                self._entry.async_start_reauth(self._hass)
+        except AuthenticationError as err:
+            self._entry.async_start_reauth(self._hass)
+            raise BackupAgentError("Authentication error") from err
+        except OneDriveException as err:
             _LOGGER.error(
-                "Error during backup in %s: Status %s, message %s",
+                "Error during backup in %s:, message %s",
                 func.__name__,
-                err.response_status_code,
-                err.message,
+                err,
             )
             _LOGGER.debug("Full error: %s", err, exc_info=True)
             raise BackupAgentError("Backup operation failed") from err
-        except TimeoutException as err:
+        except TimeoutError as err:
             _LOGGER.error(
                 "Error during backup in %s: Timeout",
                 func.__name__,
@@ -123,7 +108,8 @@ class OneDriveBackupAgent(BackupAgent):
         super().__init__()
         self._hass = hass
         self._entry = entry
-        self._items = entry.runtime_data.items
+        self._client = entry.runtime_data.client
+        self._token_provider = entry.runtime_data.token_provider
         self._folder_id = entry.runtime_data.backup_folder_id
         self.name = entry.title
         assert entry.unique_id
@@ -134,24 +120,12 @@ class OneDriveBackupAgent(BackupAgent):
         self, backup_id: str, **kwargs: Any
     ) -> AsyncIterator[bytes]:
         """Download a backup file."""
-        # this forces the query to return a raw httpx response, but breaks typing
-        backup = await self._find_item_by_backup_id(backup_id)
-        if backup is None or backup.id is None:
+        item = await self._find_item_by_backup_id(backup_id)
+        if item is None:
             raise BackupAgentError("Backup not found")

-        request_config = (
-            ContentRequestBuilder.ContentRequestBuilderGetRequestConfiguration(
-                options=[ResponseHandlerOption(NativeResponseHandler())],
-            )
-        )
-        response = cast(
-            Response,
-            await self._items.by_drive_item_id(backup.id).content.get(
-                request_configuration=request_config
-            ),
-        )
-        return response.aiter_bytes(chunk_size=1024)
+        stream = await self._client.download_drive_item(item.id, timeout=TIMEOUT)
+        return stream.iter_chunked(1024)

     @handle_backup_errors
     async def async_upload_backup(
@@ -163,27 +137,20 @@ class OneDriveBackupAgent(BackupAgent):
     ) -> None:
         """Upload a backup."""
-        # upload file in chunks to support large files
-        upload_session_request_body = CreateUploadSessionPostRequestBody(
-            item=DriveItemUploadableProperties(
-                additional_data={
-                    "@microsoft.graph.conflictBehavior": "fail",
-                },
-            )
-        )
-        file_item = self._get_backup_file_item(suggested_filename(backup))
-        upload_session = await file_item.create_upload_session.post(
-            upload_session_request_body
-        )
-        if upload_session is None or upload_session.upload_url is None:
-            raise BackupAgentError(
-                translation_domain=DOMAIN, translation_key="backup_no_upload_session"
-            )
-        await self._upload_file(
-            upload_session.upload_url, await open_stream(), backup.size
-        )
+        file = FileInfo(
+            suggested_filename(backup),
+            backup.size,
+            self._folder_id,
+            await open_stream(),
+        )
+        try:
+            item = await LargeFileUploadClient.upload(
+                self._token_provider, file, session=async_get_clientsession(self._hass)
+            )
+        except HashMismatchError as err:
+            raise BackupAgentError(
+                "Hash validation failed, backup file might be corrupt"
+            ) from err

         # store metadata in description
         backup_dict = backup.as_dict()
@@ -191,7 +158,10 @@ class OneDriveBackupAgent(BackupAgent):
         description = json.dumps(backup_dict)
         _LOGGER.debug("Creating metadata: %s", description)

-        await file_item.patch(DriveItem(description=description))
+        await self._client.update_drive_item(
+            path_or_id=item.id,
+            data=ItemUpdate(description=description),
+        )

     @handle_backup_errors
     async def async_delete_backup(
@@ -200,35 +170,31 @@ class OneDriveBackupAgent(BackupAgent):
         **kwargs: Any,
     ) -> None:
         """Delete a backup file."""
-        backup = await self._find_item_by_backup_id(backup_id)
-        if backup is None or backup.id is None:
+        item = await self._find_item_by_backup_id(backup_id)
+        if item is None:
             return

-        await self._items.by_drive_item_id(backup.id).delete()
+        await self._client.delete_drive_item(item.id)

     @handle_backup_errors
     async def async_list_backups(self, **kwargs: Any) -> list[AgentBackup]:
         """List backups."""
-        backups: list[AgentBackup] = []
-        items = await self._items.by_drive_item_id(f"{self._folder_id}").children.get()
-        if items and (values := items.value):
-            for item in values:
-                if (description := item.description) is None:
-                    continue
-                if "homeassistant_version" in description:
-                    backups.append(self._backup_from_description(description))
-        return backups
+        return [
+            self._backup_from_description(item.description)
+            for item in await self._client.list_drive_items(self._folder_id)
+            if item.description and "homeassistant_version" in item.description
+        ]

     @handle_backup_errors
     async def async_get_backup(
         self, backup_id: str, **kwargs: Any
     ) -> AgentBackup | None:
         """Return a backup."""
-        backup = await self._find_item_by_backup_id(backup_id)
-        if backup is None:
-            return None
-
-        assert backup.description  # already checked in _find_item_by_backup_id
-        return self._backup_from_description(backup.description)
+        item = await self._find_item_by_backup_id(backup_id)
+        return (
+            self._backup_from_description(item.description)
+            if item and item.description
+            else None
+        )

     def _backup_from_description(self, description: str) -> AgentBackup:
         """Create a backup object from a description."""
@@ -237,91 +203,13 @@ class OneDriveBackupAgent(BackupAgent):
         )  # OneDrive encodes the description on save automatically
         return AgentBackup.from_dict(json.loads(description))

-    async def _find_item_by_backup_id(self, backup_id: str) -> DriveItem | None:
-        """Find a backup item by its backup ID."""
-
-        items = await self._items.by_drive_item_id(f"{self._folder_id}").children.get()
-        if items and (values := items.value):
-            for item in values:
-                if (description := item.description) is None:
-                    continue
-                if backup_id in description:
-                    return item
-        return None
-
-    def _get_backup_file_item(self, backup_id: str) -> DriveItemItemRequestBuilder:
-        return self._items.by_drive_item_id(f"{self._folder_id}:/{backup_id}:")
-
-    async def _upload_file(
-        self, upload_url: str, stream: AsyncIterator[bytes], total_size: int
-    ) -> None:
-        """Use custom large file upload; SDK does not support stream."""
-        adapter = GraphRequestAdapter(
-            auth_provider=AnonymousAuthenticationProvider(),
-            client=get_async_client(self._hass),
-        )
-
-        async def async_upload(
-            start: int, end: int, chunk_data: bytes
-        ) -> LargeFileUploadSession:
-            info = RequestInformation()
-            info.url = upload_url
-            info.http_method = Method.PUT
-            info.headers = HeadersCollection()
-            info.headers.try_add("Content-Range", f"bytes {start}-{end}/{total_size}")
-            info.headers.try_add("Content-Length", str(len(chunk_data)))
-            info.headers.try_add("Content-Type", "application/octet-stream")
-            _LOGGER.debug(info.headers.get_all())
-            info.set_stream_content(chunk_data)
-            result = await adapter.send_async(info, LargeFileUploadSession, {})
-            _LOGGER.debug("Next expected range: %s", result.next_expected_ranges)
-            return result
-
-        start = 0
-        buffer: list[bytes] = []
-        buffer_size = 0
-        retries = 0
-
-        async for chunk in stream:
-            buffer.append(chunk)
-            buffer_size += len(chunk)
-            if buffer_size >= UPLOAD_CHUNK_SIZE:
-                chunk_data = b"".join(buffer)
-                uploaded_chunks = 0
-                while (
-                    buffer_size > UPLOAD_CHUNK_SIZE
-                ):  # Loop in case the buffer is >= UPLOAD_CHUNK_SIZE * 2
-                    slice_start = uploaded_chunks * UPLOAD_CHUNK_SIZE
-                    try:
-                        await async_upload(
-                            start,
-                            start + UPLOAD_CHUNK_SIZE - 1,
-                            chunk_data[slice_start : slice_start + UPLOAD_CHUNK_SIZE],
-                        )
-                    except APIError as err:
-                        if (
-                            err.response_status_code and err.response_status_code < 500
-                        ):  # no retry on 4xx errors
-                            raise
-                        if retries < MAX_RETRIES:
-                            await asyncio.sleep(2**retries)
-                            retries += 1
-                            continue
-                        raise
-                    except TimeoutException:
-                        if retries < MAX_RETRIES:
-                            retries += 1
-                            continue
-                        raise
-                    retries = 0
-                    start += UPLOAD_CHUNK_SIZE
-                    uploaded_chunks += 1
-                    buffer_size -= UPLOAD_CHUNK_SIZE
-                    buffer = [chunk_data[UPLOAD_CHUNK_SIZE * uploaded_chunks :]]
-
-        # upload the remaining bytes
-        if buffer:
-            _LOGGER.debug("Last chunk")
-            chunk_data = b"".join(buffer)
-            await async_upload(start, start + len(chunk_data) - 1, chunk_data)
+    async def _find_item_by_backup_id(self, backup_id: str) -> File | Folder | None:
+        """Find an item by backup ID."""
+        return next(
+            (
+                item
+                for item in await self._client.list_drive_items(self._folder_id)
+                if item.description and backup_id in item.description
+            ),
+            None,
+        )


@ -4,16 +4,13 @@ from collections.abc import Mapping
import logging import logging
from typing import Any, cast from typing import Any, cast
from kiota_abstractions.api_error import APIError from onedrive_personal_sdk.clients.client import OneDriveClient
from kiota_abstractions.authentication import BaseBearerTokenAuthenticationProvider from onedrive_personal_sdk.exceptions import OneDriveException
from kiota_abstractions.method import Method
from kiota_abstractions.request_information import RequestInformation
from msgraph import GraphRequestAdapter, GraphServiceClient
from homeassistant.config_entries import SOURCE_REAUTH, ConfigFlowResult from homeassistant.config_entries import SOURCE_REAUTH, ConfigFlowResult
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_TOKEN from homeassistant.const import CONF_ACCESS_TOKEN, CONF_TOKEN
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.config_entry_oauth2_flow import AbstractOAuth2FlowHandler from homeassistant.helpers.config_entry_oauth2_flow import AbstractOAuth2FlowHandler
from homeassistant.helpers.httpx_client import get_async_client
from .api import OneDriveConfigFlowAccessTokenProvider from .api import OneDriveConfigFlowAccessTokenProvider
from .const import DOMAIN, OAUTH_SCOPES from .const import DOMAIN, OAUTH_SCOPES
@ -39,48 +36,24 @@ class OneDriveConfigFlow(AbstractOAuth2FlowHandler, domain=DOMAIN):
data: dict[str, Any], data: dict[str, Any],
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Handle the initial step.""" """Handle the initial step."""
auth_provider = BaseBearerTokenAuthenticationProvider( token_provider = OneDriveConfigFlowAccessTokenProvider(
access_token_provider=OneDriveConfigFlowAccessTokenProvider(
cast(str, data[CONF_TOKEN][CONF_ACCESS_TOKEN]) cast(str, data[CONF_TOKEN][CONF_ACCESS_TOKEN])
) )
)
adapter = GraphRequestAdapter( graph_client = OneDriveClient(
auth_provider=auth_provider, token_provider, async_get_clientsession(self.hass)
client=get_async_client(self.hass),
) )
graph_client = GraphServiceClient(
request_adapter=adapter,
scopes=OAUTH_SCOPES,
)
# need to get adapter from client, as client changes it
request_adapter = cast(GraphRequestAdapter, graph_client.request_adapter)
request_info = RequestInformation(
method=Method.GET,
url_template="{+baseurl}/me/drive/special/approot",
path_parameters={},
)
parent_span = request_adapter.start_tracing_span(request_info, "get_approot")
# get the OneDrive id
# use low level methods, to avoid files.read permissions
# which would be required by drives.me.get()
try: try:
response = await request_adapter.get_http_response_message( approot = await graph_client.get_approot()
request_info=request_info, parent_span=parent_span except OneDriveException:
)
except APIError:
self.logger.exception("Failed to connect to OneDrive") self.logger.exception("Failed to connect to OneDrive")
return self.async_abort(reason="connection_error") return self.async_abort(reason="connection_error")
except Exception: except Exception:
self.logger.exception("Unknown error") self.logger.exception("Unknown error")
return self.async_abort(reason="unknown") return self.async_abort(reason="unknown")
drive: dict = response.json() await self.async_set_unique_id(approot.parent_reference.drive_id)
await self.async_set_unique_id(drive["parentReference"]["driveId"])
if self.source == SOURCE_REAUTH: if self.source == SOURCE_REAUTH:
reauth_entry = self._get_reauth_entry() reauth_entry = self._get_reauth_entry()
@ -94,10 +67,11 @@ class OneDriveConfigFlow(AbstractOAuth2FlowHandler, domain=DOMAIN):
self._abort_if_unique_id_configured() self._abort_if_unique_id_configured()
user = drive.get("createdBy", {}).get("user", {}).get("displayName") title = (
f"{approot.created_by.user.display_name}'s OneDrive"
title = f"{user}'s OneDrive" if user else "OneDrive" if approot.created_by.user and approot.created_by.user.display_name
else "OneDrive"
)
return self.async_create_entry(title=title, data=data) return self.async_create_entry(title=title, data=data)
async def async_step_reauth( async def async_step_reauth(


@ -7,7 +7,7 @@
"documentation": "https://www.home-assistant.io/integrations/onedrive", "documentation": "https://www.home-assistant.io/integrations/onedrive",
"integration_type": "service", "integration_type": "service",
"iot_class": "cloud_polling", "iot_class": "cloud_polling",
"loggers": ["msgraph", "msgraph-core", "kiota"], "loggers": ["onedrive_personal_sdk"],
"quality_scale": "bronze", "quality_scale": "bronze",
"requirements": ["msgraph-sdk==1.16.0"] "requirements": ["onedrive-personal-sdk==0.0.1"]
} }


@ -23,31 +23,18 @@
"connection_error": "Failed to connect to OneDrive.", "connection_error": "Failed to connect to OneDrive.",
"wrong_drive": "New account does not contain previously configured OneDrive.", "wrong_drive": "New account does not contain previously configured OneDrive.",
"unknown": "[%key:common::config_flow::error::unknown%]", "unknown": "[%key:common::config_flow::error::unknown%]",
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]", "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]"
"failed_to_create_folder": "Failed to create backup folder"
}, },
"create_entry": { "create_entry": {
"default": "[%key:common::config_flow::create_entry::authenticated%]" "default": "[%key:common::config_flow::create_entry::authenticated%]"
} }
}, },
"exceptions": { "exceptions": {
"backup_not_found": {
"message": "Backup not found"
},
"backup_no_content": {
"message": "Backup has no content"
},
"backup_no_upload_session": {
"message": "Failed to start backup upload"
},
"authentication_failed": { "authentication_failed": {
"message": "Authentication failed" "message": "Authentication failed"
}, },
"failed_to_get_folder": { "failed_to_get_folder": {
"message": "Failed to get {folder} folder" "message": "Failed to get {folder} folder"
},
"failed_to_create_folder": {
"message": "Failed to create {folder} folder"
} }
} }
} }


@ -24,6 +24,7 @@ from homeassistant.helpers.selector import (
SelectOptionDict, SelectOptionDict,
SelectSelector, SelectSelector,
SelectSelectorConfig, SelectSelectorConfig,
SelectSelectorMode,
TemplateSelector, TemplateSelector,
) )
from homeassistant.helpers.typing import VolDictType from homeassistant.helpers.typing import VolDictType
@ -32,14 +33,17 @@ from .const import (
CONF_CHAT_MODEL, CONF_CHAT_MODEL,
CONF_MAX_TOKENS, CONF_MAX_TOKENS,
CONF_PROMPT, CONF_PROMPT,
CONF_REASONING_EFFORT,
CONF_RECOMMENDED, CONF_RECOMMENDED,
CONF_TEMPERATURE, CONF_TEMPERATURE,
CONF_TOP_P, CONF_TOP_P,
DOMAIN, DOMAIN,
RECOMMENDED_CHAT_MODEL, RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS, RECOMMENDED_MAX_TOKENS,
RECOMMENDED_REASONING_EFFORT,
RECOMMENDED_TEMPERATURE, RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_P, RECOMMENDED_TOP_P,
UNSUPPORTED_MODELS,
) )
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -124,13 +128,18 @@ class OpenAIOptionsFlow(OptionsFlow):
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Manage the options.""" """Manage the options."""
options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options
errors: dict[str, str] = {}
if user_input is not None: if user_input is not None:
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended: if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
if user_input[CONF_LLM_HASS_API] == "none": if user_input[CONF_LLM_HASS_API] == "none":
user_input.pop(CONF_LLM_HASS_API) user_input.pop(CONF_LLM_HASS_API)
return self.async_create_entry(title="", data=user_input)
if user_input.get(CONF_CHAT_MODEL) in UNSUPPORTED_MODELS:
errors[CONF_CHAT_MODEL] = "model_not_supported"
else:
return self.async_create_entry(title="", data=user_input)
else:
# Re-render the options again, now with the recommended options shown/hidden # Re-render the options again, now with the recommended options shown/hidden
self.last_rendered_recommended = user_input[CONF_RECOMMENDED] self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
@ -144,6 +153,7 @@ class OpenAIOptionsFlow(OptionsFlow):
return self.async_show_form( return self.async_show_form(
step_id="init", step_id="init",
data_schema=vol.Schema(schema), data_schema=vol.Schema(schema),
errors=errors,
) )
@ -210,6 +220,17 @@ def openai_config_option_schema(
description={"suggested_value": options.get(CONF_TEMPERATURE)}, description={"suggested_value": options.get(CONF_TEMPERATURE)},
default=RECOMMENDED_TEMPERATURE, default=RECOMMENDED_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)), ): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
vol.Optional(
CONF_REASONING_EFFORT,
description={"suggested_value": options.get(CONF_REASONING_EFFORT)},
default=RECOMMENDED_REASONING_EFFORT,
): SelectSelector(
SelectSelectorConfig(
options=["low", "medium", "high"],
translation_key="reasoning_effort",
mode=SelectSelectorMode.DROPDOWN,
)
),
} }
) )
return schema return schema


@ -15,3 +15,17 @@ CONF_TOP_P = "top_p"
RECOMMENDED_TOP_P = 1.0 RECOMMENDED_TOP_P = 1.0
CONF_TEMPERATURE = "temperature" CONF_TEMPERATURE = "temperature"
RECOMMENDED_TEMPERATURE = 1.0 RECOMMENDED_TEMPERATURE = 1.0
CONF_REASONING_EFFORT = "reasoning_effort"
RECOMMENDED_REASONING_EFFORT = "low"
UNSUPPORTED_MODELS = [
"o1-mini",
"o1-mini-2024-09-12",
"o1-preview",
"o1-preview-2024-09-12",
"gpt-4o-realtime-preview",
"gpt-4o-realtime-preview-2024-12-17",
"gpt-4o-realtime-preview-2024-10-01",
"gpt-4o-mini-realtime-preview",
"gpt-4o-mini-realtime-preview-2024-12-17",
]


@ -31,12 +31,14 @@ from .const import (
CONF_CHAT_MODEL, CONF_CHAT_MODEL,
CONF_MAX_TOKENS, CONF_MAX_TOKENS,
CONF_PROMPT, CONF_PROMPT,
CONF_REASONING_EFFORT,
CONF_TEMPERATURE, CONF_TEMPERATURE,
CONF_TOP_P, CONF_TOP_P,
DOMAIN, DOMAIN,
LOGGER, LOGGER,
RECOMMENDED_CHAT_MODEL, RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS, RECOMMENDED_MAX_TOKENS,
RECOMMENDED_REASONING_EFFORT,
RECOMMENDED_TEMPERATURE, RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_P, RECOMMENDED_TOP_P,
) )
@ -68,7 +70,9 @@ def _format_tool(
return ChatCompletionToolParam(type="function", function=tool_spec) return ChatCompletionToolParam(type="function", function=tool_spec)
def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessageParam: def _convert_message_to_param(
message: ChatCompletionMessage,
) -> ChatCompletionMessageParam:
"""Convert from class to TypedDict.""" """Convert from class to TypedDict."""
tool_calls: list[ChatCompletionMessageToolCallParam] = [] tool_calls: list[ChatCompletionMessageToolCallParam] = []
if message.tool_calls: if message.tool_calls:
@ -92,17 +96,42 @@ def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessagePar
return param return param
def _chat_message_convert( def _convert_content_to_param(
message: conversation.Content content: conversation.Content,
| conversation.NativeContent[ChatCompletionMessageParam],
) -> ChatCompletionMessageParam: ) -> ChatCompletionMessageParam:
"""Convert any native chat message for this agent to the native format.""" """Convert any native chat message for this agent to the native format."""
if message.role == "native": if content.role == "tool_result":
# mypy doesn't understand that checking role ensures content type assert type(content) is conversation.ToolResultContent
return message.content # type: ignore[return-value] return ChatCompletionToolMessageParam(
role="tool",
tool_call_id=content.tool_call_id,
content=json.dumps(content.tool_result),
)
if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr]
role = content.role
if role == "system":
role = "developer"
return cast( return cast(
ChatCompletionMessageParam, ChatCompletionMessageParam,
{"role": message.role, "content": message.content}, {"role": content.role, "content": content.content}, # type: ignore[union-attr]
)
# Handle the Assistant content including tool calls.
assert type(content) is conversation.AssistantContent
return ChatCompletionAssistantMessageParam(
role="assistant",
content=content.content,
tool_calls=[
ChatCompletionMessageToolCallParam(
id=tool_call.id,
function=Function(
arguments=json.dumps(tool_call.tool_args),
name=tool_call.tool_name,
),
type="function",
)
for tool_call in content.tool_calls
],
) )
@ -166,14 +195,14 @@ class OpenAIConversationEntity(
async def _async_handle_message( async def _async_handle_message(
self, self,
user_input: conversation.ConversationInput, user_input: conversation.ConversationInput,
session: conversation.ChatLog[ChatCompletionMessageParam], chat_log: conversation.ChatLog,
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Call the API.""" """Call the API."""
assert user_input.agent_id assert user_input.agent_id
options = self.entry.options options = self.entry.options
try: try:
await session.async_update_llm_data( await chat_log.async_update_llm_data(
DOMAIN, DOMAIN,
user_input, user_input,
options.get(CONF_LLM_HASS_API), options.get(CONF_LLM_HASS_API),
@ -183,73 +212,77 @@ class OpenAIConversationEntity(
return err.as_conversation_result() return err.as_conversation_result()
tools: list[ChatCompletionToolParam] | None = None tools: list[ChatCompletionToolParam] | None = None
if session.llm_api: if chat_log.llm_api:
tools = [ tools = [
_format_tool(tool, session.llm_api.custom_serializer) _format_tool(tool, chat_log.llm_api.custom_serializer)
for tool in session.llm_api.tools for tool in chat_log.llm_api.tools
] ]
messages = [ model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
_chat_message_convert(message) for message in session.async_get_messages() messages = [_convert_content_to_param(content) for content in chat_log.content]
]
client = self.entry.runtime_data client = self.entry.runtime_data
# To prevent infinite loops, we limit the number of iterations # To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS): for _iteration in range(MAX_TOOL_ITERATIONS):
try: model_args = {
result = await client.chat.completions.create( "model": model,
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), "messages": messages,
messages=messages, "tools": tools or NOT_GIVEN,
tools=tools or NOT_GIVEN, "max_completion_tokens": options.get(
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS), CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P), ),
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE), "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
user=session.conversation_id, "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
"user": chat_log.conversation_id,
}
if model.startswith("o"):
model_args["reasoning_effort"] = options.get(
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
) )
try:
result = await client.chat.completions.create(**model_args)
except openai.OpenAIError as err: except openai.OpenAIError as err:
LOGGER.error("Error talking to OpenAI: %s", err) LOGGER.error("Error talking to OpenAI: %s", err)
raise HomeAssistantError("Error talking to OpenAI") from err raise HomeAssistantError("Error talking to OpenAI") from err
LOGGER.debug("Response %s", result) LOGGER.debug("Response %s", result)
response = result.choices[0].message response = result.choices[0].message
messages.append(_message_convert(response)) messages.append(_convert_message_to_param(response))
session.async_add_message( tool_calls: list[llm.ToolInput] | None = None
conversation.Content( if response.tool_calls:
role=response.role, tool_calls = [
agent_id=user_input.agent_id, llm.ToolInput(
content=response.content or "", id=tool_call.id,
),
)
if not response.tool_calls or not session.llm_api:
break
for tool_call in response.tool_calls:
tool_input = llm.ToolInput(
tool_name=tool_call.function.name, tool_name=tool_call.function.name,
tool_args=json.loads(tool_call.function.arguments), tool_args=json.loads(tool_call.function.arguments),
) )
tool_response = await session.async_call_tool(tool_input) for tool_call in response.tool_calls
messages.append( ]
ChatCompletionToolMessageParam(
role="tool", messages.extend(
tool_call_id=tool_call.id, [
content=json.dumps(tool_response), _convert_content_to_param(tool_response)
) async for tool_response in chat_log.async_add_assistant_content(
) conversation.AssistantContent(
session.async_add_message(
conversation.NativeContent(
agent_id=user_input.agent_id, agent_id=user_input.agent_id,
content=messages[-1], content=response.content or "",
tool_calls=tool_calls,
) )
) )
]
)
if not tool_calls:
break
intent_response = intent.IntentResponse(language=user_input.language) intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response.content or "") intent_response.async_set_speech(response.content or "")
return conversation.ConversationResult( return conversation.ConversationResult(
response=intent_response, conversation_id=session.conversation_id response=intent_response, conversation_id=chat_log.conversation_id
) )
async def _async_entry_update_listener( async def _async_entry_update_listener(


@ -23,12 +23,26 @@
"temperature": "Temperature", "temperature": "Temperature",
"top_p": "Top P", "top_p": "Top P",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]", "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
"recommended": "Recommended model settings" "recommended": "Recommended model settings",
"reasoning_effort": "Reasoning effort"
}, },
"data_description": { "data_description": {
"prompt": "Instruct how the LLM should respond. This can be a template." "prompt": "Instruct how the LLM should respond. This can be a template.",
"reasoning_effort": "How many reasoning tokens the model should generate before creating a response to the prompt (for certain reasoning models)"
} }
} }
},
"error": {
"model_not_supported": "This model is not supported, please select a different model"
}
},
"selector": {
"reasoning_effort": {
"options": {
"low": "Low",
"medium": "Medium",
"high": "High"
}
} }
}, },
"services": { "services": {


@ -18,7 +18,13 @@ from homeassistant.const import (
STATE_OFF, STATE_OFF,
STATE_ON, STATE_ON,
) )
from homeassistant.core import HomeAssistant, ServiceCall, callback from homeassistant.core import (
HomeAssistant,
ServiceCall,
ServiceResponse,
SupportsResponse,
callback,
)
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.collection import ( from homeassistant.helpers.collection import (
CollectionEntity, CollectionEntity,
@ -44,6 +50,7 @@ from .const import (
CONF_TO, CONF_TO,
DOMAIN, DOMAIN,
LOGGER, LOGGER,
SERVICE_GET,
WEEKDAY_TO_CONF, WEEKDAY_TO_CONF,
) )
@ -205,6 +212,14 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
reload_service_handler, reload_service_handler,
) )
component.async_register_entity_service(
SERVICE_GET,
{},
async_get_schedule_service,
supports_response=SupportsResponse.ONLY,
)
await component.async_setup(config)
return True return True
@ -296,6 +311,10 @@ class Schedule(CollectionEntity):
self.async_on_remove(self._clean_up_listener) self.async_on_remove(self._clean_up_listener)
self._update() self._update()
def get_schedule(self) -> ConfigType:
"""Return the schedule."""
return {d: self._config[d] for d in WEEKDAY_TO_CONF.values()}
@callback @callback
def _update(self, _: datetime | None = None) -> None: def _update(self, _: datetime | None = None) -> None:
"""Update the states of the schedule.""" """Update the states of the schedule."""
@ -390,3 +409,10 @@ class Schedule(CollectionEntity):
data_keys.update(time_range_custom_data.keys()) data_keys.update(time_range_custom_data.keys())
return frozenset(data_keys) return frozenset(data_keys)
async def async_get_schedule_service(
schedule: Schedule, service_call: ServiceCall
) -> ServiceResponse:
"""Return the schedule configuration."""
return schedule.get_schedule()


@ -37,3 +37,5 @@ WEEKDAY_TO_CONF: Final = {
5: CONF_SATURDAY, 5: CONF_SATURDAY,
6: CONF_SUNDAY, 6: CONF_SUNDAY,
} }
SERVICE_GET: Final = "get_schedule"


@ -2,6 +2,9 @@
"services": { "services": {
"reload": { "reload": {
"service": "mdi:reload" "service": "mdi:reload"
},
"get_schedule": {
"service": "mdi:calendar-export"
} }
} }
} }


@ -1 +1,5 @@
reload: reload:
get_schedule:
target:
entity:
domain: schedule


@ -25,6 +25,10 @@
"reload": { "reload": {
"name": "[%key:common::action::reload%]", "name": "[%key:common::action::reload%]",
"description": "Reloads schedules from the YAML-configuration." "description": "Reloads schedules from the YAML-configuration."
},
"get_schedule": {
"name": "Get schedule",
"description": "Retrieve one or multiple schedules."
} }
} }
} }


@ -1,9 +1,9 @@
{ {
"config": { "config": {
"flow_title": "Add Shark IQ Account", "flow_title": "Add Shark IQ account",
"step": { "step": {
"user": { "user": {
"description": "Sign into your Shark Clean account to control your devices.", "description": "Sign into your SharkClean account to control your devices.",
"data": { "data": {
"username": "[%key:common::config_flow::data::username%]", "username": "[%key:common::config_flow::data::username%]",
"password": "[%key:common::config_flow::data::password%]", "password": "[%key:common::config_flow::data::password%]",
@ -37,18 +37,18 @@
"region": { "region": {
"options": { "options": {
"europe": "Europe", "europe": "Europe",
"elsewhere": "Everywhere Else" "elsewhere": "Everywhere else"
} }
} }
}, },
"exceptions": { "exceptions": {
"invalid_room": { "invalid_room": {
"message": "The room {room} is unavailable to your vacuum. Make sure all rooms match the Shark App, including capitalization." "message": "The room {room} is unavailable to your vacuum. Make sure all rooms match the SharkClean app, including capitalization."
} }
}, },
"services": { "services": {
"clean_room": { "clean_room": {
"name": "Clean Room", "name": "Clean room",
"description": "Cleans a specific user-defined room or set of rooms.", "description": "Cleans a specific user-defined room or set of rooms.",
"fields": { "fields": {
"rooms": { "rooms": {


@ -272,6 +272,18 @@ RPC_SENSORS: Final = {
entity_category=EntityCategory.DIAGNOSTIC, entity_category=EntityCategory.DIAGNOSTIC,
entity_class=RpcBluTrvBinarySensor, entity_class=RpcBluTrvBinarySensor,
), ),
"flood": RpcBinarySensorDescription(
key="flood",
sub_key="alarm",
name="Flood",
device_class=BinarySensorDeviceClass.MOISTURE,
),
"mute": RpcBinarySensorDescription(
key="flood",
sub_key="mute",
name="Mute",
entity_category=EntityCategory.DIAGNOSTIC,
),
} }


@ -14,6 +14,7 @@ from homeassistant.config_entries import SOURCE_USER, ConfigFlow, ConfigFlowResu
from homeassistant.const import CONF_HOST, CONF_NAME, CONF_PASSWORD, CONF_USERNAME from homeassistant.const import CONF_HOST, CONF_NAME, CONF_PASSWORD, CONF_USERNAME
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.device_registry import format_mac from homeassistant.helpers.device_registry import format_mac
from homeassistant.helpers.service_info.dhcp import DhcpServiceInfo
from homeassistant.helpers.service_info.zeroconf import ZeroconfServiceInfo from homeassistant.helpers.service_info.zeroconf import ZeroconfServiceInfo
from .const import DOMAIN from .const import DOMAIN
@ -35,7 +36,8 @@ STEP_AUTH_DATA_SCHEMA = vol.Schema(
class SmlightConfigFlow(ConfigFlow, domain=DOMAIN): class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for SMLIGHT Zigbee.""" """Handle a config flow for SMLIGHT Zigbee."""
host: str _host: str
_device_name: str
client: Api2 client: Api2
async def async_step_user( async def async_step_user(
@ -45,11 +47,13 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
errors: dict[str, str] = {} errors: dict[str, str] = {}
if user_input is not None: if user_input is not None:
self.host = user_input[CONF_HOST] self._host = user_input[CONF_HOST]
self.client = Api2(self.host, session=async_get_clientsession(self.hass)) self.client = Api2(self._host, session=async_get_clientsession(self.hass))
try: try:
info = await self.client.get_info() info = await self.client.get_info()
self._host = str(info.device_ip)
self._device_name = str(info.hostname)
if info.model not in Devices: if info.model not in Devices:
return self.async_abort(reason="unsupported_device") return self.async_abort(reason="unsupported_device")
@ -93,15 +97,14 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
self, discovery_info: ZeroconfServiceInfo self, discovery_info: ZeroconfServiceInfo
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Handle a discovered Lan coordinator.""" """Handle a discovered Lan coordinator."""
local_name = discovery_info.hostname[:-1] mac: str | None = discovery_info.properties.get("mac")
node_name = local_name.removesuffix(".local") self._device_name = discovery_info.hostname.removesuffix(".local.")
self._host = discovery_info.host
self.host = local_name self.context["title_placeholders"] = {CONF_NAME: self._device_name}
self.context["title_placeholders"] = {CONF_NAME: node_name} self.client = Api2(self._host, session=async_get_clientsession(self.hass))
self.client = Api2(self.host, session=async_get_clientsession(self.hass))
mac = discovery_info.properties.get("mac") # fallback for legacy firmware older than v2.3.x
# fallback for legacy firmware
if mac is None: if mac is None:
try: try:
info = await self.client.get_info() info = await self.client.get_info()
@ -111,7 +114,7 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
mac = info.MAC mac = info.MAC
await self.async_set_unique_id(format_mac(mac)) await self.async_set_unique_id(format_mac(mac))
self._abort_if_unique_id_configured() self._abort_if_unique_id_configured(updates={CONF_HOST: self._host})
return await self.async_step_confirm_discovery() return await self.async_step_confirm_discovery()
@ -122,7 +125,6 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
errors: dict[str, str] = {} errors: dict[str, str] = {}
if user_input is not None: if user_input is not None:
user_input[CONF_HOST] = self.host
try: try:
info = await self.client.get_info() info = await self.client.get_info()
@ -142,7 +144,7 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
return self.async_show_form( return self.async_show_form(
step_id="confirm_discovery", step_id="confirm_discovery",
description_placeholders={"host": self.host}, description_placeholders={"host": self._device_name},
errors=errors, errors=errors,
) )
@ -151,8 +153,8 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Handle reauth when API Authentication failed.""" """Handle reauth when API Authentication failed."""
self.host = entry_data[CONF_HOST] self._host = entry_data[CONF_HOST]
self.client = Api2(self.host, session=async_get_clientsession(self.hass)) self.client = Api2(self._host, session=async_get_clientsession(self.hass))
return await self.async_step_reauth_confirm() return await self.async_step_reauth_confirm()
@ -182,6 +184,16 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
errors=errors, errors=errors,
) )
async def async_step_dhcp(
self, discovery_info: DhcpServiceInfo
) -> ConfigFlowResult:
"""Handle DHCP discovery."""
await self.async_set_unique_id(format_mac(discovery_info.macaddress))
self._abort_if_unique_id_configured(updates={CONF_HOST: discovery_info.ip})
# This should never happen since we only listen to DHCP requests
# for configured devices.
return self.async_abort(reason="already_configured")
async def _async_check_auth_required(self, user_input: dict[str, Any]) -> bool: async def _async_check_auth_required(self, user_input: dict[str, Any]) -> bool:
"""Check if auth required and attempt to authenticate.""" """Check if auth required and attempt to authenticate."""
if await self.client.check_auth_needed(): if await self.client.check_auth_needed():
@ -200,11 +212,10 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
await self.async_set_unique_id( await self.async_set_unique_id(
format_mac(info.MAC), raise_on_progress=self.source != SOURCE_USER format_mac(info.MAC), raise_on_progress=self.source != SOURCE_USER
) )
self._abort_if_unique_id_configured() self._abort_if_unique_id_configured(updates={CONF_HOST: self._host})
if user_input.get(CONF_HOST) is None: user_input[CONF_HOST] = self._host
user_input[CONF_HOST] = self.host
assert info.model is not None assert info.model is not None
title = self.context.get("title_placeholders", {}).get(CONF_NAME) or info.model title = self._device_name or info.model
return self.async_create_entry(title=title, data=user_input) return self.async_create_entry(title=title, data=user_input)


@ -3,10 +3,15 @@
"name": "SMLIGHT SLZB", "name": "SMLIGHT SLZB",
"codeowners": ["@tl-sl"], "codeowners": ["@tl-sl"],
"config_flow": true, "config_flow": true,
"dhcp": [
{
"registered_devices": true
}
],
"documentation": "https://www.home-assistant.io/integrations/smlight", "documentation": "https://www.home-assistant.io/integrations/smlight",
"integration_type": "device", "integration_type": "device",
"iot_class": "local_push", "iot_class": "local_push",
"requirements": ["pysmlight==0.2.1"], "requirements": ["pysmlight==0.2.2"],
"zeroconf": [ "zeroconf": [
{ {
"type": "_slzb-06._tcp.local." "type": "_slzb-06._tcp.local."


@ -65,6 +65,7 @@ BINARY_SENSORS = [
key="currently_obstructed", key="currently_obstructed",
translation_key="currently_obstructed", translation_key="currently_obstructed",
device_class=BinarySensorDeviceClass.PROBLEM, device_class=BinarySensorDeviceClass.PROBLEM,
entity_category=EntityCategory.DIAGNOSTIC,
value_fn=lambda data: data.status["currently_obstructed"], value_fn=lambda data: data.status["currently_obstructed"],
), ),
StarlinkBinarySensorEntityDescription( StarlinkBinarySensorEntityDescription(
@ -114,4 +115,9 @@ BINARY_SENSORS = [
entity_category=EntityCategory.DIAGNOSTIC, entity_category=EntityCategory.DIAGNOSTIC,
value_fn=lambda data: data.alert["alert_unexpected_location"], value_fn=lambda data: data.alert["alert_unexpected_location"],
), ),
StarlinkBinarySensorEntityDescription(
key="connection",
device_class=BinarySensorDeviceClass.CONNECTIVITY,
value_fn=lambda data: data.status["state"] == "CONNECTED",
),
] ]


@ -14,7 +14,7 @@
}, },
"reconfigure": { "reconfigure": {
"title": "Reconfigure your Tado", "title": "Reconfigure your Tado",
"description": "Reconfigure the entry, for your account: `{username}`.", "description": "Reconfigure the entry for your account: `{username}`.",
"data": { "data": {
"password": "[%key:common::config_flow::data::password%]" "password": "[%key:common::config_flow::data::password%]"
}, },
@ -25,7 +25,7 @@
}, },
"error": { "error": {
"unknown": "[%key:common::config_flow::error::unknown%]", "unknown": "[%key:common::config_flow::error::unknown%]",
"no_homes": "There are no homes linked to this tado account.", "no_homes": "There are no homes linked to this Tado account.",
"invalid_auth": "[%key:common::config_flow::error::invalid_auth%]", "invalid_auth": "[%key:common::config_flow::error::invalid_auth%]",
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]" "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]"
} }
@ -33,7 +33,7 @@
"options": { "options": {
"step": { "step": {
"init": { "init": {
"description": "Fallback mode lets you choose when to fallback to Smart Schedule from your manual zone overlay. (NEXT_TIME_BLOCK:= Change at next Smart Schedule change; MANUAL:= Dont change until you cancel; TADO_DEFAULT:= Change based on your setting in Tado App).", "description": "Fallback mode lets you choose when to fallback to Smart Schedule from your manual zone overlay. (NEXT_TIME_BLOCK:= Change at next Smart Schedule change; MANUAL:= Don't change until you cancel; TADO_DEFAULT:= Change based on your setting in the Tado app).",
"data": { "data": {
"fallback": "Choose fallback mode." "fallback": "Choose fallback mode."
}, },
@ -102,11 +102,11 @@
}, },
"time_period": { "time_period": {
"name": "Time period", "name": "Time period",
"description": "Choose this or Overlay. Set the time period for the change if you want to be specific. Alternatively use Overlay." "description": "Choose this or 'Overlay'. Set the time period for the change if you want to be specific."
}, },
"requested_overlay": { "requested_overlay": {
"name": "Overlay", "name": "Overlay",
"description": "Choose this or Time Period. Allows you to choose an overlay. MANUAL:=Overlay until user removes; NEXT_TIME_BLOCK:=Overlay until next timeblock; TADO_DEFAULT:=Overlay based on tado app setting." "description": "Choose this or 'Time period'. Allows you to choose an overlay. MANUAL:=Overlay until user removes; NEXT_TIME_BLOCK:=Overlay until next timeblock; TADO_DEFAULT:=Overlay based on Tado app setting."
} }
} }
}, },
@ -151,8 +151,8 @@
}, },
"issues": { "issues": {
"water_heater_fallback": { "water_heater_fallback": {
"title": "Tado Water Heater entities now support fallback options", "title": "Tado water heater entities now support fallback options",
"description": "Due to added support for water heaters entities, these entities may use different overlay. Please configure integration entity and tado app water heater zone overlay options. Otherwise, please configure the integration entity and Tado app water heater zone overlay options (under Settings -> Rooms & Devices -> Hot Water)." "description": "Due to added support for water heaters entities, these entities may use a different overlay. Please configure the integration entity and Tado app water heater zone overlay options (under Settings -> Rooms & Devices -> Hot Water)."
} }
} }
} }


@ -7,6 +7,7 @@ from pyvesync import VeSync
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_PASSWORD, CONF_USERNAME, Platform from homeassistant.const import CONF_PASSWORD, CONF_USERNAME, Platform
from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from .common import async_generate_device_list from .common import async_generate_device_list
@ -91,3 +92,37 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass.data.pop(DOMAIN) hass.data.pop(DOMAIN)
return unload_ok return unload_ok
async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Migrate old entry."""
_LOGGER.debug(
"Migrating VeSync config entry: %s minor version: %s",
config_entry.version,
config_entry.minor_version,
)
if config_entry.minor_version == 1:
# Migrate switch/outlets entity to a new unique ID
_LOGGER.debug("Migrating VeSync config entry from version 1 to version 2")
entity_registry = er.async_get(hass)
registry_entries = er.async_entries_for_config_entry(
entity_registry, config_entry.entry_id
)
for reg_entry in registry_entries:
if "-" not in reg_entry.unique_id and reg_entry.entity_id.startswith(
Platform.SWITCH
):
_LOGGER.debug(
"Migrating switch/outlet entity from unique_id: %s to unique_id: %s",
reg_entry.unique_id,
reg_entry.unique_id + "-device_status",
)
entity_registry.async_update_entity(
reg_entry.entity_id,
new_unique_id=reg_entry.unique_id + "-device_status",
)
else:
_LOGGER.debug("Skipping entity with unique_id: %s", reg_entry.unique_id)
hass.config_entries.async_update_entry(config_entry, minor_version=2)
return True


@ -24,6 +24,7 @@ class VeSyncFlowHandler(ConfigFlow, domain=DOMAIN):
"""Handle a config flow.""" """Handle a config flow."""
VERSION = 1 VERSION = 1
MINOR_VERSION = 2
@callback @callback
def _show_form(self, errors: dict[str, str] | None = None) -> ConfigFlowResult: def _show_form(self, errors: dict[str, str] | None = None) -> ConfigFlowResult:


@ -12,5 +12,5 @@
"documentation": "https://www.home-assistant.io/integrations/vesync", "documentation": "https://www.home-assistant.io/integrations/vesync",
"iot_class": "cloud_polling", "iot_class": "cloud_polling",
"loggers": ["pyvesync"], "loggers": ["pyvesync"],
"requirements": ["pyvesync==2.1.16"] "requirements": ["pyvesync==2.1.17"]
} }


@ -83,6 +83,7 @@ class VeSyncSwitchHA(VeSyncBaseSwitch, SwitchEntity):
) -> None: ) -> None:
"""Initialize the VeSync switch device.""" """Initialize the VeSync switch device."""
super().__init__(plug, coordinator) super().__init__(plug, coordinator)
self._attr_unique_id = f"{super().unique_id}-device_status"
self.smartplug = plug self.smartplug = plug
@ -94,4 +95,5 @@ class VeSyncLightSwitch(VeSyncBaseSwitch, SwitchEntity):
) -> None: ) -> None:
"""Initialize Light Switch device class.""" """Initialize Light Switch device class."""
super().__init__(switch, coordinator) super().__init__(switch, coordinator)
self._attr_unique_id = f"{super().unique_id}-device_status"
self.switch = switch self.switch = switch


@ -616,6 +616,10 @@ DHCP: Final[list[dict[str, str | bool]]] = [
"hostname": "hub*", "hostname": "hub*",
"macaddress": "286D97*", "macaddress": "286D97*",
}, },
{
"domain": "smlight",
"registered_devices": True,
},
{ {
"domain": "solaredge", "domain": "solaredge",
"hostname": "target", "hostname": "target",


@ -4,7 +4,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass, field as dc_field
from datetime import timedelta from datetime import timedelta
from decimal import Decimal from decimal import Decimal
from enum import Enum from enum import Enum
@ -36,6 +36,7 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import dt as dt_util, yaml as yaml_util from homeassistant.util import dt as dt_util, yaml as yaml_util
from homeassistant.util.hass_dict import HassKey from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JsonObjectType from homeassistant.util.json import JsonObjectType
from homeassistant.util.ulid import ulid_now
from . import ( from . import (
area_registry as ar, area_registry as ar,
@ -139,6 +140,8 @@ class ToolInput:
tool_name: str tool_name: str
tool_args: dict[str, Any] tool_args: dict[str, Any]
# Using lambda for default to allow patching in tests
id: str = dc_field(default_factory=lambda: ulid_now()) # pylint: disable=unnecessary-lambda
class Tool: class Tool:


@ -29,7 +29,7 @@ certifi>=2021.5.30
ciso8601==2.3.2 ciso8601==2.3.2
cronsim==2.6 cronsim==2.6
cryptography==44.0.0 cryptography==44.0.0
dbus-fast==2.31.0 dbus-fast==2.32.0
fnv-hash-fast==1.2.2 fnv-hash-fast==1.2.2
go2rtc-client==0.1.2 go2rtc-client==0.1.2
ha-ffmpeg==3.2.2 ha-ffmpeg==3.2.2
@ -53,7 +53,7 @@ psutil-home-assistant==0.0.1
PyJWT==2.10.1 PyJWT==2.10.1
pymicro-vad==1.0.1 pymicro-vad==1.0.1
PyNaCl==1.5.0 PyNaCl==1.5.0
pyOpenSSL==24.3.0 pyOpenSSL==25.0.0
pyserial==3.5 pyserial==3.5
pyspeex-noise==1.0.2 pyspeex-noise==1.0.2
python-slugify==8.0.4 python-slugify==8.0.4


@ -59,7 +59,7 @@ dependencies = [
"cryptography==44.0.0", "cryptography==44.0.0",
"Pillow==11.1.0", "Pillow==11.1.0",
"propcache==0.2.1", "propcache==0.2.1",
"pyOpenSSL==24.3.0", "pyOpenSSL==25.0.0",
"orjson==3.10.12", "orjson==3.10.12",
"packaging>=23.1", "packaging>=23.1",
"psutil-home-assistant==0.0.1", "psutil-home-assistant==0.0.1",

requirements.txt (generated)

@ -31,7 +31,7 @@ PyJWT==2.10.1
cryptography==44.0.0 cryptography==44.0.0
Pillow==11.1.0 Pillow==11.1.0
propcache==0.2.1 propcache==0.2.1
pyOpenSSL==24.3.0 pyOpenSSL==25.0.0
orjson==3.10.12 orjson==3.10.12
packaging>=23.1 packaging>=23.1
psutil-home-assistant==0.0.1 psutil-home-assistant==0.0.1

requirements_all.txt (generated)

@ -600,7 +600,7 @@ bizkaibus==0.1.1
# homeassistant.components.eq3btsmart # homeassistant.components.eq3btsmart
# homeassistant.components.esphome # homeassistant.components.esphome
bleak-esphome==2.6.0 bleak-esphome==2.7.0
# homeassistant.components.bluetooth # homeassistant.components.bluetooth
bleak-retry-connector==3.8.0 bleak-retry-connector==3.8.0
@ -741,7 +741,7 @@ datadog==0.15.0
datapoint==0.9.9 datapoint==0.9.9
# homeassistant.components.bluetooth # homeassistant.components.bluetooth
dbus-fast==2.31.0 dbus-fast==2.32.0
# homeassistant.components.debugpy # homeassistant.components.debugpy
debugpy==1.8.11 debugpy==1.8.11
@ -1437,9 +1437,6 @@ motioneye-client==0.3.14
# homeassistant.components.bang_olufsen # homeassistant.components.bang_olufsen
mozart-api==4.1.1.116.4 mozart-api==4.1.1.116.4
# homeassistant.components.onedrive
msgraph-sdk==1.16.0
# homeassistant.components.mullvad # homeassistant.components.mullvad
mullvad-api==1.0.0 mullvad-api==1.0.0
@ -1561,6 +1558,9 @@ omnilogic==0.4.5
# homeassistant.components.ondilo_ico # homeassistant.components.ondilo_ico
ondilo==0.5.0 ondilo==0.5.0
# homeassistant.components.onedrive
onedrive-personal-sdk==0.0.1
# homeassistant.components.onvif # homeassistant.components.onvif
onvif-zeep-async==3.2.5 onvif-zeep-async==3.2.5
@ -2205,7 +2205,7 @@ pypalazzetti==0.1.19
pypca==0.0.7 pypca==0.0.7
# homeassistant.components.lcn # homeassistant.components.lcn
pypck==0.8.3 pypck==0.8.5
# homeassistant.components.pjlink # homeassistant.components.pjlink
pypjlink2==1.2.1 pypjlink2==1.2.1
@ -2313,7 +2313,7 @@ pysmarty2==0.10.1
pysml==0.0.12 pysml==0.0.12
# homeassistant.components.smlight # homeassistant.components.smlight
pysmlight==0.2.1 pysmlight==0.2.2
# homeassistant.components.snmp # homeassistant.components.snmp
pysnmp==6.2.6 pysnmp==6.2.6
@ -2391,7 +2391,7 @@ python-gitlab==1.6.0
python-google-drive-api==0.0.2 python-google-drive-api==0.0.2
# homeassistant.components.analytics_insights # homeassistant.components.analytics_insights
python-homeassistant-analytics==0.8.1 python-homeassistant-analytics==0.9.0
# homeassistant.components.homewizard # homeassistant.components.homewizard
python-homewizard-energy==v8.3.2 python-homewizard-energy==v8.3.2
@ -2516,7 +2516,7 @@ pyvera==0.3.15
pyversasense==0.0.6 pyversasense==0.0.6
# homeassistant.components.vesync # homeassistant.components.vesync
pyvesync==2.1.16 pyvesync==2.1.17
# homeassistant.components.vizio # homeassistant.components.vizio
pyvizio==0.1.61 pyvizio==0.1.61


@ -8,31 +8,31 @@
-c homeassistant/package_constraints.txt -c homeassistant/package_constraints.txt
-r requirements_test_pre_commit.txt -r requirements_test_pre_commit.txt
astroid==3.3.8 astroid==3.3.8
coverage==7.6.8 coverage==7.6.10
freezegun==1.5.1 freezegun==1.5.1
license-expression==30.4.0 license-expression==30.4.1
mock-open==1.4.0 mock-open==1.4.0
mypy-dev==1.16.0a1 mypy-dev==1.16.0a1
pre-commit==4.0.0 pre-commit==4.0.0
pydantic==2.10.6 pydantic==2.10.6
pylint==3.3.3 pylint==3.3.4
pylint-per-file-ignores==1.3.2 pylint-per-file-ignores==1.4.0
pipdeptree==2.23.4 pipdeptree==2.25.0
pytest-asyncio==0.24.0 pytest-asyncio==0.25.3
pytest-aiohttp==1.0.5 pytest-aiohttp==1.0.5
pytest-cov==6.0.0 pytest-cov==6.0.0
pytest-freezer==0.4.8 pytest-freezer==0.4.9
pytest-github-actions-annotate-failures==0.2.0 pytest-github-actions-annotate-failures==0.3.0
pytest-socket==0.7.0 pytest-socket==0.7.0
pytest-sugar==1.0.0 pytest-sugar==1.0.0
pytest-timeout==2.3.1 pytest-timeout==2.3.1
pytest-unordered==0.6.1 pytest-unordered==0.6.1
pytest-picked==0.5.0 pytest-picked==0.5.1
pytest-xdist==3.6.1 pytest-xdist==3.6.1
pytest==8.3.4 pytest==8.3.4
requests-mock==1.12.1 requests-mock==1.12.1
respx==0.22.0 respx==0.22.0
syrupy==4.8.0 syrupy==4.8.1
tqdm==4.66.5 tqdm==4.66.5
types-aiofiles==24.1.0.20241221 types-aiofiles==24.1.0.20241221
types-atomicwrites==1.4.5.1 types-atomicwrites==1.4.5.1


@ -528,7 +528,7 @@ bimmer-connected[china]==0.17.2
# homeassistant.components.eq3btsmart # homeassistant.components.eq3btsmart
# homeassistant.components.esphome # homeassistant.components.esphome
bleak-esphome==2.6.0 bleak-esphome==2.7.0
# homeassistant.components.bluetooth # homeassistant.components.bluetooth
bleak-retry-connector==3.8.0 bleak-retry-connector==3.8.0
@ -634,7 +634,7 @@ datadog==0.15.0
datapoint==0.9.9 datapoint==0.9.9
# homeassistant.components.bluetooth # homeassistant.components.bluetooth
dbus-fast==2.31.0 dbus-fast==2.32.0
# homeassistant.components.debugpy # homeassistant.components.debugpy
debugpy==1.8.11 debugpy==1.8.11
@ -1206,9 +1206,6 @@ motioneye-client==0.3.14
# homeassistant.components.bang_olufsen # homeassistant.components.bang_olufsen
mozart-api==4.1.1.116.4 mozart-api==4.1.1.116.4
# homeassistant.components.onedrive
msgraph-sdk==1.16.0
# homeassistant.components.mullvad # homeassistant.components.mullvad
mullvad-api==1.0.0 mullvad-api==1.0.0
@ -1306,6 +1303,9 @@ omnilogic==0.4.5
# homeassistant.components.ondilo_ico # homeassistant.components.ondilo_ico
ondilo==0.5.0 ondilo==0.5.0
# homeassistant.components.onedrive
onedrive-personal-sdk==0.0.1
# homeassistant.components.onvif # homeassistant.components.onvif
onvif-zeep-async==3.2.5 onvif-zeep-async==3.2.5
@ -1795,7 +1795,7 @@ pyownet==0.10.0.post1
pypalazzetti==0.1.19 pypalazzetti==0.1.19
# homeassistant.components.lcn # homeassistant.components.lcn
pypck==0.8.3 pypck==0.8.5
# homeassistant.components.pjlink # homeassistant.components.pjlink
pypjlink2==1.2.1 pypjlink2==1.2.1
@ -1882,7 +1882,7 @@ pysmarty2==0.10.1
pysml==0.0.12 pysml==0.0.12
# homeassistant.components.smlight # homeassistant.components.smlight
pysmlight==0.2.1 pysmlight==0.2.2
# homeassistant.components.snmp # homeassistant.components.snmp
pysnmp==6.2.6 pysnmp==6.2.6
@ -1933,7 +1933,7 @@ python-fullykiosk==0.0.14
python-google-drive-api==0.0.2 python-google-drive-api==0.0.2
# homeassistant.components.analytics_insights # homeassistant.components.analytics_insights
python-homeassistant-analytics==0.8.1 python-homeassistant-analytics==0.9.0
# homeassistant.components.homewizard # homeassistant.components.homewizard
python-homewizard-energy==v8.3.2 python-homewizard-energy==v8.3.2
@ -2031,7 +2031,7 @@ pyuptimerobot==22.2.0
pyvera==0.3.15 pyvera==0.3.15
# homeassistant.components.vesync # homeassistant.components.vesync
pyvesync==2.1.16 pyvesync==2.1.17
# homeassistant.components.vizio # homeassistant.components.vizio
pyvizio==0.1.61 pyvizio==0.1.61


@ -24,7 +24,7 @@ RUN --mount=from=ghcr.io/astral-sh/uv:0.5.21,source=/uv,target=/bin/uv \
--no-cache \ --no-cache \
-c /usr/src/homeassistant/homeassistant/package_constraints.txt \ -c /usr/src/homeassistant/homeassistant/package_constraints.txt \
-r /usr/src/homeassistant/requirements.txt \ -r /usr/src/homeassistant/requirements.txt \
stdlib-list==0.10.0 pipdeptree==2.23.4 tqdm==4.66.5 ruff==0.9.1 \ stdlib-list==0.10.0 pipdeptree==2.25.0 tqdm==4.66.5 ruff==0.9.1 \
PyTurboJPEG==1.7.5 go2rtc-client==0.1.2 ha-ffmpeg==3.2.2 hassil==2.2.0 home-assistant-intents==2025.1.28 mutagen==1.47.0 pymicro-vad==1.0.1 pyspeex-noise==1.0.2 PyTurboJPEG==1.7.5 go2rtc-client==0.1.2 ha-ffmpeg==3.2.2 hassil==2.2.0 home-assistant-intents==2025.1.28 mutagen==1.47.0 pymicro-vad==1.0.1 pyspeex-noise==1.0.2
LABEL "name"="hassfest" LABEL "name"="hassfest"


@ -236,6 +236,7 @@ async def test_function_call(
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
hass, hass,
llm.ToolInput( llm.ToolInput(
id="toolu_0123456789AbCdEfGhIjKlM",
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": "test_value"}, tool_args={"param1": "test_value"},
), ),
@ -373,6 +374,7 @@ async def test_function_exception(
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
hass, hass,
llm.ToolInput( llm.ToolInput(
id="toolu_0123456789AbCdEfGhIjKlM",
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": "test_value"}, tool_args={"param1": "test_value"},
), ),


@ -3,6 +3,7 @@
list([ list([
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
}), }),
@ -32,7 +33,7 @@
}), }),
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': None, 'conversation_id': 'mock-ulid',
'device_id': None, 'device_id': None,
'engine': 'conversation.home_assistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
@ -94,6 +95,7 @@
list([ list([
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
}), }),
@ -123,7 +125,7 @@
}), }),
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': None, 'conversation_id': 'mock-ulid',
'device_id': None, 'device_id': None,
'engine': 'conversation.home_assistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
@ -185,6 +187,7 @@
list([ list([
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
}), }),
@ -214,7 +217,7 @@
}), }),
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': None, 'conversation_id': 'mock-ulid',
'device_id': None, 'device_id': None,
'engine': 'conversation.home_assistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
@ -276,6 +279,7 @@
list([ list([
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
}), }),
@ -329,7 +333,7 @@
}), }),
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': None, 'conversation_id': 'mock-ulid',
'device_id': None, 'device_id': None,
'engine': 'conversation.home_assistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
@ -391,6 +395,7 @@
list([ list([
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
}), }),
@ -427,6 +432,7 @@
list([ list([
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': 'mock-conversation-id',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
}), }),
@ -434,7 +440,7 @@
}), }),
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': None, 'conversation_id': 'mock-conversation-id',
'device_id': None, 'device_id': None,
'engine': 'conversation.home_assistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test input', 'intent_input': 'test input',
@ -478,6 +484,7 @@
list([ list([
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': 'mock-conversation-id',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
}), }),
@ -485,7 +492,7 @@
}), }),
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': None, 'conversation_id': 'mock-conversation-id',
'device_id': None, 'device_id': None,
'engine': 'conversation.home_assistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test input', 'intent_input': 'test input',
@ -529,6 +536,7 @@
list([ list([
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': 'mock-conversation-id',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
}), }),
@ -536,7 +544,7 @@
}), }),
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': None, 'conversation_id': 'mock-conversation-id',
'device_id': None, 'device_id': None,
'engine': 'conversation.home_assistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test input', 'intent_input': 'test input',
@ -580,6 +588,7 @@
list([ list([
dict({ dict({
'data': dict({ 'data': dict({
'conversation_id': 'mock-conversation-id',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
}), }),


@ -1,6 +1,7 @@
# serializer version: 1 # serializer version: 1
# name: test_audio_pipeline # name: test_audio_pipeline
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -31,7 +32,7 @@
# --- # ---
# name: test_audio_pipeline.3 # name: test_audio_pipeline.3
dict({ dict({
'conversation_id': None, 'conversation_id': 'mock-ulid',
'device_id': None, 'device_id': None,
'engine': 'conversation.home_assistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
@ -84,6 +85,7 @@
# --- # ---
# name: test_audio_pipeline_debug # name: test_audio_pipeline_debug
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -114,7 +116,7 @@
# --- # ---
# name: test_audio_pipeline_debug.3 # name: test_audio_pipeline_debug.3
dict({ dict({
'conversation_id': None, 'conversation_id': 'mock-ulid',
'device_id': None, 'device_id': None,
'engine': 'conversation.home_assistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
@ -179,6 +181,7 @@
# --- # ---
# name: test_audio_pipeline_with_enhancements # name: test_audio_pipeline_with_enhancements
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -209,7 +212,7 @@
# --- # ---
# name: test_audio_pipeline_with_enhancements.3 # name: test_audio_pipeline_with_enhancements.3
dict({ dict({
'conversation_id': None, 'conversation_id': 'mock-ulid',
'device_id': None, 'device_id': None,
'engine': 'conversation.home_assistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
@ -262,6 +265,7 @@
# --- # ---
# name: test_audio_pipeline_with_wake_word_no_timeout # name: test_audio_pipeline_with_wake_word_no_timeout
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -314,7 +318,7 @@
# --- # ---
# name: test_audio_pipeline_with_wake_word_no_timeout.5 # name: test_audio_pipeline_with_wake_word_no_timeout.5
dict({ dict({
'conversation_id': None, 'conversation_id': 'mock-ulid',
'device_id': None, 'device_id': None,
'engine': 'conversation.home_assistant', 'engine': 'conversation.home_assistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
@ -367,6 +371,7 @@
# --- # ---
# name: test_audio_pipeline_with_wake_word_timeout # name: test_audio_pipeline_with_wake_word_timeout
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -399,6 +404,7 @@
# --- # ---
# name: test_device_capture # name: test_device_capture
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -425,6 +431,7 @@
# --- # ---
# name: test_device_capture_override # name: test_device_capture_override
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -473,6 +480,7 @@
# --- # ---
# name: test_device_capture_queue_full # name: test_device_capture_queue_full
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -512,6 +520,7 @@
# --- # ---
# name: test_intent_failed # name: test_intent_failed
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -522,7 +531,7 @@
# --- # ---
# name: test_intent_failed.1 # name: test_intent_failed.1
dict({ dict({
'conversation_id': None, 'conversation_id': 'mock-ulid',
'device_id': None, 'device_id': None,
'engine': 'conversation.home_assistant', 'engine': 'conversation.home_assistant',
'intent_input': 'Are the lights on?', 'intent_input': 'Are the lights on?',
@ -535,6 +544,7 @@
# --- # ---
# name: test_intent_timeout # name: test_intent_timeout
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -545,7 +555,7 @@
# --- # ---
# name: test_intent_timeout.1 # name: test_intent_timeout.1
dict({ dict({
'conversation_id': None, 'conversation_id': 'mock-ulid',
'device_id': None, 'device_id': None,
'engine': 'conversation.home_assistant', 'engine': 'conversation.home_assistant',
'intent_input': 'Are the lights on?', 'intent_input': 'Are the lights on?',
@ -564,6 +574,7 @@
# --- # ---
# name: test_pipeline_empty_tts_output # name: test_pipeline_empty_tts_output
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -574,7 +585,7 @@
# --- # ---
# name: test_pipeline_empty_tts_output.1 # name: test_pipeline_empty_tts_output.1
dict({ dict({
'conversation_id': None, 'conversation_id': 'mock-ulid',
'device_id': None, 'device_id': None,
'engine': 'conversation.home_assistant', 'engine': 'conversation.home_assistant',
'intent_input': 'never mind', 'intent_input': 'never mind',
@ -611,6 +622,7 @@
# --- # ---
# name: test_stt_cooldown_different_ids # name: test_stt_cooldown_different_ids
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -621,6 +633,7 @@
# --- # ---
# name: test_stt_cooldown_different_ids.1 # name: test_stt_cooldown_different_ids.1
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -631,6 +644,7 @@
# --- # ---
# name: test_stt_cooldown_same_id # name: test_stt_cooldown_same_id
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -641,6 +655,7 @@
# --- # ---
# name: test_stt_cooldown_same_id.1 # name: test_stt_cooldown_same_id.1
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -651,6 +666,7 @@
# --- # ---
# name: test_stt_stream_failed # name: test_stt_stream_failed
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -677,6 +693,7 @@
# --- # ---
# name: test_text_only_pipeline[extra_msg0] # name: test_text_only_pipeline[extra_msg0]
dict({ dict({
'conversation_id': 'mock-conversation-id',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -723,6 +740,7 @@
# --- # ---
# name: test_text_only_pipeline[extra_msg1] # name: test_text_only_pipeline[extra_msg1]
dict({ dict({
'conversation_id': 'mock-conversation-id',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -775,6 +793,7 @@
# --- # ---
# name: test_tts_failed # name: test_tts_failed
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -796,6 +815,7 @@
# --- # ---
# name: test_wake_word_cooldown_different_entities # name: test_wake_word_cooldown_different_entities
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -806,6 +826,7 @@
# --- # ---
# name: test_wake_word_cooldown_different_entities.1 # name: test_wake_word_cooldown_different_entities.1
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -857,6 +878,7 @@
# --- # ---
# name: test_wake_word_cooldown_different_ids # name: test_wake_word_cooldown_different_ids
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -867,6 +889,7 @@
# --- # ---
# name: test_wake_word_cooldown_different_ids.1 # name: test_wake_word_cooldown_different_ids.1
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -921,6 +944,7 @@
# --- # ---
# name: test_wake_word_cooldown_same_id # name: test_wake_word_cooldown_same_id
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
@ -931,6 +955,7 @@
# --- # ---
# name: test_wake_word_cooldown_same_id.1 # name: test_wake_word_cooldown_same_id.1
dict({ dict({
'conversation_id': 'mock-ulid',
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({

View File

@ -1,11 +1,12 @@
"""Test Voice Assistant init.""" """Test Voice Assistant init."""
import asyncio import asyncio
from collections.abc import Generator
from dataclasses import asdict from dataclasses import asdict
import itertools as it import itertools as it
from pathlib import Path from pathlib import Path
import tempfile import tempfile
from unittest.mock import ANY, patch from unittest.mock import ANY, Mock, patch
import wave import wave
import hass_nabucasa import hass_nabucasa
@ -41,6 +42,14 @@ from .conftest import (
from tests.typing import ClientSessionGenerator, WebSocketGenerator from tests.typing import ClientSessionGenerator, WebSocketGenerator
@pytest.fixture(autouse=True)
def mock_ulid() -> Generator[Mock]:
"""Mock the ulid of chat sessions."""
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-ulid"
yield mock_ulid_now
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]: def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
"""Process events to remove dynamic values.""" """Process events to remove dynamic values."""
processed = [] processed = []
@ -684,7 +693,7 @@ async def test_wake_word_detection_aborted(
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
pipeline_input = assist_pipeline.pipeline.PipelineInput( pipeline_input = assist_pipeline.pipeline.PipelineInput(
conversation_id=None, conversation_id="mock-conversation-id",
device_id=None, device_id=None,
stt_metadata=stt.SpeechMetadata( stt_metadata=stt.SpeechMetadata(
language="", language="",
@ -771,7 +780,7 @@ async def test_tts_audio_output(
pipeline_input = assist_pipeline.pipeline.PipelineInput( pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.", tts_input="This is a test.",
conversation_id=None, conversation_id="mock-conversation-id",
device_id=None, device_id=None,
run=assist_pipeline.pipeline.PipelineRun( run=assist_pipeline.pipeline.PipelineRun(
hass, hass,
@ -828,7 +837,7 @@ async def test_tts_wav_preferred_format(
pipeline_input = assist_pipeline.pipeline.PipelineInput( pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.", tts_input="This is a test.",
conversation_id=None, conversation_id="mock-conversation-id",
device_id=None, device_id=None,
run=assist_pipeline.pipeline.PipelineRun( run=assist_pipeline.pipeline.PipelineRun(
hass, hass,
@ -896,7 +905,7 @@ async def test_tts_dict_preferred_format(
pipeline_input = assist_pipeline.pipeline.PipelineInput( pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.", tts_input="This is a test.",
conversation_id=None, conversation_id="mock-conversation-id",
device_id=None, device_id=None,
run=assist_pipeline.pipeline.PipelineRun( run=assist_pipeline.pipeline.PipelineRun(
hass, hass,
@ -982,6 +991,7 @@ async def test_sentence_trigger_overrides_conversation_agent(
pipeline_input = assist_pipeline.pipeline.PipelineInput( pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test trigger sentence", intent_input="test trigger sentence",
conversation_id="mock-conversation-id",
run=assist_pipeline.pipeline.PipelineRun( run=assist_pipeline.pipeline.PipelineRun(
hass, hass,
context=Context(), context=Context(),
@ -1059,6 +1069,7 @@ async def test_prefer_local_intents(
pipeline_input = assist_pipeline.pipeline.PipelineInput( pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="I'd like to order a stout please", intent_input="I'd like to order a stout please",
conversation_id="mock-conversation-id",
run=assist_pipeline.pipeline.PipelineRun( run=assist_pipeline.pipeline.PipelineRun(
hass, hass,
context=Context(), context=Context(),
@ -1136,6 +1147,7 @@ async def test_stt_language_used_instead_of_conversation_language(
pipeline_input = assist_pipeline.pipeline.PipelineInput( pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test input", intent_input="test input",
conversation_id="mock-conversation-id",
run=assist_pipeline.pipeline.PipelineRun( run=assist_pipeline.pipeline.PipelineRun(
hass, hass,
context=Context(), context=Context(),
@ -1210,6 +1222,7 @@ async def test_tts_language_used_instead_of_conversation_language(
pipeline_input = assist_pipeline.pipeline.PipelineInput( pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test input", intent_input="test input",
conversation_id="mock-conversation-id",
run=assist_pipeline.pipeline.PipelineRun( run=assist_pipeline.pipeline.PipelineRun(
hass, hass,
context=Context(), context=Context(),
@ -1284,6 +1297,7 @@ async def test_pipeline_language_used_instead_of_conversation_language(
pipeline_input = assist_pipeline.pipeline.PipelineInput( pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test input", intent_input="test input",
conversation_id="mock-conversation-id",
run=assist_pipeline.pipeline.PipelineRun( run=assist_pipeline.pipeline.PipelineRun(
hass, hass,
context=Context(), context=Context(),

View File

@ -2,8 +2,9 @@
import asyncio import asyncio
import base64 import base64
from collections.abc import Generator
from typing import Any from typing import Any
from unittest.mock import ANY, patch from unittest.mock import ANY, Mock, patch
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
@ -35,6 +36,14 @@ from tests.common import MockConfigEntry
from tests.typing import WebSocketGenerator from tests.typing import WebSocketGenerator
@pytest.fixture(autouse=True)
def mock_ulid() -> Generator[Mock]:
"""Mock the ulid of chat sessions."""
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-ulid"
yield mock_ulid_now
@pytest.mark.parametrize( @pytest.mark.parametrize(
"extra_msg", "extra_msg",
[ [

View File

@ -94,7 +94,9 @@ class MockAssistSatellite(AssistSatelliteEntity):
self, start_announcement: AssistSatelliteConfiguration self, start_announcement: AssistSatelliteConfiguration
) -> None: ) -> None:
"""Start a conversation from the satellite.""" """Start a conversation from the satellite."""
self.start_conversations.append((self._extra_system_prompt, start_announcement)) self.start_conversations.append(
(self._conversation_id, self._extra_system_prompt, start_announcement)
)
@pytest.fixture @pytest.fixture

View File

@ -1,7 +1,8 @@
"""Test the Assist Satellite entity.""" """Test the Assist Satellite entity."""
import asyncio import asyncio
from unittest.mock import patch from collections.abc import Generator
from unittest.mock import Mock, patch
import pytest import pytest
@ -31,6 +32,14 @@ from . import ENTITY_ID
from .conftest import MockAssistSatellite from .conftest import MockAssistSatellite
@pytest.fixture
def mock_chat_session_conversation_id() -> Generator[Mock]:
"""Mock the ulid library."""
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-conversation-id"
yield mock_ulid_now
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
async def set_pipeline_tts(hass: HomeAssistant, init_components: ConfigEntry) -> None: async def set_pipeline_tts(hass: HomeAssistant, init_components: ConfigEntry) -> None:
"""Set up a pipeline with a TTS engine.""" """Set up a pipeline with a TTS engine."""
@ -487,6 +496,7 @@ async def test_vad_sensitivity_entity_not_found(
"extra_system_prompt": "Better system prompt", "extra_system_prompt": "Better system prompt",
}, },
( (
"mock-conversation-id",
"Better system prompt", "Better system prompt",
AssistSatelliteAnnouncement( AssistSatelliteAnnouncement(
message="Hello", message="Hello",
@ -502,6 +512,7 @@ async def test_vad_sensitivity_entity_not_found(
"start_media_id": "media-source://given", "start_media_id": "media-source://given",
}, },
( (
"mock-conversation-id",
"Hello", "Hello",
AssistSatelliteAnnouncement( AssistSatelliteAnnouncement(
message="Hello", message="Hello",
@ -514,6 +525,7 @@ async def test_vad_sensitivity_entity_not_found(
( (
{"start_media_id": "http://example.com/given.mp3"}, {"start_media_id": "http://example.com/given.mp3"},
( (
"mock-conversation-id",
None, None,
AssistSatelliteAnnouncement( AssistSatelliteAnnouncement(
message="", message="",
@ -525,6 +537,7 @@ async def test_vad_sensitivity_entity_not_found(
), ),
], ],
) )
@pytest.mark.usefixtures("mock_chat_session_conversation_id")
async def test_start_conversation( async def test_start_conversation(
hass: HomeAssistant, hass: HomeAssistant,
init_components: ConfigEntry, init_components: ConfigEntry,

View File

@ -9,13 +9,13 @@ from syrupy.assertion import SnapshotAssertion
import voluptuous as vol import voluptuous as vol
from homeassistant.components.conversation import ( from homeassistant.components.conversation import (
Content, AssistantContent,
ConversationInput, ConversationInput,
ConverseError, ConverseError,
NativeContent, ToolResultContent,
async_get_chat_log, async_get_chat_log,
) )
from homeassistant.components.conversation.session import DATA_CHAT_HISTORY from homeassistant.components.conversation.chat_log import DATA_CHAT_HISTORY
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session, llm from homeassistant.helpers import chat_session, llm
@ -40,7 +40,7 @@ def mock_conversation_input(hass: HomeAssistant) -> ConversationInput:
@pytest.fixture @pytest.fixture
def mock_ulid() -> Generator[Mock]: def mock_ulid() -> Generator[Mock]:
"""Mock the ulid library.""" """Mock the ulid library."""
with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now: with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-ulid" mock_ulid_now.return_value = "mock-ulid"
yield mock_ulid_now yield mock_ulid_now
@ -56,13 +56,13 @@ async def test_cleanup(
): ):
conversation_id = session.conversation_id conversation_id = session.conversation_id
# Add message so it persists # Add message so it persists
chat_log.async_add_message( async for _tool_result in chat_log.async_add_assistant_content(
Content( AssistantContent(
role="assistant", agent_id="mock-agent-id",
agent_id=mock_conversation_input.agent_id, content="Hey!",
content="",
)
) )
):
pytest.fail("should not reach here")
assert conversation_id in hass.data[DATA_CHAT_HISTORY] assert conversation_id in hass.data[DATA_CHAT_HISTORY]
@ -79,7 +79,7 @@ async def test_cleanup(
assert conversation_id not in hass.data[DATA_CHAT_HISTORY] assert conversation_id not in hass.data[DATA_CHAT_HISTORY]
async def test_add_message( async def test_default_content(
hass: HomeAssistant, mock_conversation_input: ConversationInput hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None: ) -> None:
"""Test filtering of messages.""" """Test filtering of messages."""
@ -87,95 +87,11 @@ async def test_add_message(
chat_session.async_get_chat_session(hass) as session, chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
): ):
assert len(chat_log.messages) == 2 assert len(chat_log.content) == 2
assert chat_log.content[0].role == "system"
with pytest.raises(ValueError): assert chat_log.content[0].content == ""
chat_log.async_add_message( assert chat_log.content[1].role == "user"
Content(role="system", agent_id=None, content="") assert chat_log.content[1].content == mock_conversation_input.text
)
# No 2 user messages in a row
assert chat_log.messages[1].role == "user"
with pytest.raises(ValueError):
chat_log.async_add_message(Content(role="user", agent_id=None, content=""))
# No 2 assistant messages in a row
chat_log.async_add_message(Content(role="assistant", agent_id=None, content=""))
assert len(chat_log.messages) == 3
assert chat_log.messages[-1].role == "assistant"
with pytest.raises(ValueError):
chat_log.async_add_message(
Content(role="assistant", agent_id=None, content="")
)
async def test_message_filtering(
hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None:
"""Test filtering of messages."""
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
messages = chat_log.async_get_messages(agent_id=None)
assert len(messages) == 2
assert messages[0] == Content(
role="system",
agent_id=None,
content="",
)
assert messages[1] == Content(
role="user",
agent_id="mock-agent-id",
content=mock_conversation_input.text,
)
# Cannot add a second user message in a row
with pytest.raises(ValueError):
chat_log.async_add_message(
Content(
role="user",
agent_id="mock-agent-id",
content="Hey!",
)
)
chat_log.async_add_message(
Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
)
)
# Different agent, native messages will be filtered out.
chat_log.async_add_message(
NativeContent(agent_id="another-mock-agent-id", content=1)
)
chat_log.async_add_message(NativeContent(agent_id="mock-agent-id", content=1))
# A non-native message from another agent is not filtered out.
chat_log.async_add_message(
Content(
role="assistant",
agent_id="another-mock-agent-id",
content="Hi!",
)
)
assert len(chat_log.messages) == 6
messages = chat_log.async_get_messages(agent_id="mock-agent-id")
assert len(messages) == 5
assert messages[2] == Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
)
assert messages[3] == NativeContent(agent_id="mock-agent-id", content=1)
assert messages[4] == Content(
role="assistant", agent_id="another-mock-agent-id", content="Hi!"
)
async def test_llm_api( async def test_llm_api(
@ -268,12 +184,10 @@ async def test_template_variables(
), ),
) )
assert chat_log.user_name == "Test User" assert "The instance name is test home." in chat_log.content[0].content
assert "The user name is Test User." in chat_log.content[0].content
assert "The instance name is test home." in chat_log.messages[0].content assert "The user id is 12345." in chat_log.content[0].content
assert "The user name is Test User." in chat_log.messages[0].content assert "The calling platform is test." in chat_log.content[0].content
assert "The user id is 12345." in chat_log.messages[0].content
assert "The calling platform is test." in chat_log.messages[0].content
async def test_extra_systen_prompt( async def test_extra_systen_prompt(
@ -296,16 +210,16 @@ async def test_extra_systen_prompt(
user_llm_hass_api=None, user_llm_hass_api=None,
user_llm_prompt=None, user_llm_prompt=None,
) )
chat_log.async_add_message( async for _tool_result in chat_log.async_add_assistant_content(
Content( AssistantContent(
role="assistant",
agent_id="mock-agent-id", agent_id="mock-agent-id",
content="Hey!", content="Hey!",
) )
) ):
pytest.fail("should not reach here")
assert chat_log.extra_system_prompt == extra_system_prompt assert chat_log.extra_system_prompt == extra_system_prompt
assert chat_log.messages[0].content.endswith(extra_system_prompt) assert chat_log.content[0].content.endswith(extra_system_prompt)
# Verify that follow-up conversations with no system prompt take previous one # Verify that follow-up conversations with no system prompt take previous one
conversation_id = chat_log.conversation_id conversation_id = chat_log.conversation_id
@ -323,7 +237,7 @@ async def test_extra_systen_prompt(
) )
assert chat_log.extra_system_prompt == extra_system_prompt assert chat_log.extra_system_prompt == extra_system_prompt
assert chat_log.messages[0].content.endswith(extra_system_prompt) assert chat_log.content[0].content.endswith(extra_system_prompt)
# Verify that we take new system prompts # Verify that we take new system prompts
mock_conversation_input.extra_system_prompt = extra_system_prompt2 mock_conversation_input.extra_system_prompt = extra_system_prompt2
@ -338,17 +252,17 @@ async def test_extra_systen_prompt(
user_llm_hass_api=None, user_llm_hass_api=None,
user_llm_prompt=None, user_llm_prompt=None,
) )
chat_log.async_add_message( async for _tool_result in chat_log.async_add_assistant_content(
Content( AssistantContent(
role="assistant",
agent_id="mock-agent-id", agent_id="mock-agent-id",
content="Hey!", content="Hey!",
) )
) ):
pytest.fail("should not reach here")
assert chat_log.extra_system_prompt == extra_system_prompt2 assert chat_log.extra_system_prompt == extra_system_prompt2
assert chat_log.messages[0].content.endswith(extra_system_prompt2) assert chat_log.content[0].content.endswith(extra_system_prompt2)
assert extra_system_prompt not in chat_log.messages[0].content assert extra_system_prompt not in chat_log.content[0].content
# Verify that follow-up conversations with no system prompt take previous one # Verify that follow-up conversations with no system prompt take previous one
mock_conversation_input.extra_system_prompt = None mock_conversation_input.extra_system_prompt = None
@ -365,7 +279,7 @@ async def test_extra_systen_prompt(
) )
assert chat_log.extra_system_prompt == extra_system_prompt2 assert chat_log.extra_system_prompt == extra_system_prompt2
assert chat_log.messages[0].content.endswith(extra_system_prompt2) assert chat_log.content[0].content.endswith(extra_system_prompt2)
async def test_tool_call( async def test_tool_call(
@ -383,8 +297,7 @@ async def test_tool_call(
mock_tool.async_call.return_value = "Test response" mock_tool.async_call.return_value = "Test response"
with patch( with patch(
"homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools", "homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
return_value=[],
) as mock_get_tools: ) as mock_get_tools:
mock_get_tools.return_value = [mock_tool] mock_get_tools.return_value = [mock_tool]
@ -398,14 +311,29 @@ async def test_tool_call(
user_llm_hass_api="assist", user_llm_hass_api="assist",
user_llm_prompt=None, user_llm_prompt=None,
) )
result = await chat_log.async_call_tool( result = None
async for tool_result_content in chat_log.async_add_assistant_content(
AssistantContent(
agent_id=mock_conversation_input.agent_id,
content="",
tool_calls=[
llm.ToolInput( llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": "Test Param"}, tool_args={"param1": "Test Param"},
) )
],
) )
):
assert result is None
result = tool_result_content
assert result == "Test response" assert result == ToolResultContent(
agent_id=mock_conversation_input.agent_id,
tool_call_id="mock-tool-call-id",
tool_result="Test response",
tool_name="test_tool",
)
async def test_tool_call_exception( async def test_tool_call_exception(
@ -423,8 +351,7 @@ async def test_tool_call_exception(
mock_tool.async_call.side_effect = HomeAssistantError("Test error") mock_tool.async_call.side_effect = HomeAssistantError("Test error")
with patch( with patch(
"homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools", "homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
return_value=[],
) as mock_get_tools: ) as mock_get_tools:
mock_get_tools.return_value = [mock_tool] mock_get_tools.return_value = [mock_tool]
@ -438,11 +365,26 @@ async def test_tool_call_exception(
user_llm_hass_api="assist", user_llm_hass_api="assist",
user_llm_prompt=None, user_llm_prompt=None,
) )
result = await chat_log.async_call_tool( result = None
async for tool_result_content in chat_log.async_add_assistant_content(
AssistantContent(
agent_id=mock_conversation_input.agent_id,
content="",
tool_calls=[
llm.ToolInput( llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": "Test Param"}, tool_args={"param1": "Test Param"},
) )
],
) )
):
assert result is None
result = tool_result_content
assert result == {"error": "HomeAssistantError", "error_text": "Test error"} assert result == ToolResultContent(
agent_id=mock_conversation_input.agent_id,
tool_call_id="mock-tool-call-id",
tool_result={"error": "HomeAssistantError", "error_text": "Test error"},
tool_name="test_tool",
)
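A minimal consumption sketch, assuming only the chat_log API exercised in the tests above (async_add_assistant_content takes an AssistantContent whose tool_calls are ToolInput objects and yields one ToolResultContent per call); the function and variable names here are illustrative, not the shipped agent code.

async def run_tool_round(chat_log, assistant_content) -> list[dict]:
    """Collect the tool results produced while recording assistant content."""
    results = []
    async for tool_result in chat_log.async_add_assistant_content(assistant_content):
        # Each yielded item mirrors ToolResultContent as asserted above: it carries
        # the originating tool_call_id, the tool_name and the tool_result payload.
        results.append(
            {
                "tool_call_id": tool_result.tool_call_id,
                "tool_name": tool_result.tool_name,
                "content": tool_result.tool_result,
            }
        )
    return results

An agent loop could feed the collected entries back to its model as tool responses; when the assistant content carries no tool calls, the loop body never runs, which is what the pytest.fail branches above rely on.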

View File

@ -36,6 +36,13 @@ def freeze_the_time():
yield yield
@pytest.fixture(autouse=True)
def mock_ulid_tools():
"""Mock generated ULIDs for tool calls."""
with patch("homeassistant.helpers.llm.ulid_now", return_value="mock-tool-call"):
yield
@pytest.mark.parametrize( @pytest.mark.parametrize(
"agent_id", [None, "conversation.google_generative_ai_conversation"] "agent_id", [None, "conversation.google_generative_ai_conversation"]
) )
@ -177,6 +184,7 @@ async def test_chat_history(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools" "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
) )
@pytest.mark.usefixtures("mock_init_component") @pytest.mark.usefixtures("mock_init_component")
@pytest.mark.usefixtures("mock_ulid_tools")
async def test_function_call( async def test_function_call(
mock_get_tools, mock_get_tools,
hass: HomeAssistant, hass: HomeAssistant,
@ -256,6 +264,7 @@ async def test_function_call(
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
hass, hass,
llm.ToolInput( llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool", tool_name="test_tool",
tool_args={ tool_args={
"param1": ["test_value", "param1's value"], "param1": ["test_value", "param1's value"],
@ -287,9 +296,7 @@ async def test_function_call(
detail_event = trace_events[1] detail_event = trace_events[1]
assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"] assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
assert [ assert [
p.function_response.name p["tool_name"] for p in detail_event["data"]["messages"][2]["tool_calls"]
for p in detail_event["data"]["messages"][2]["content"].parts
if p.function_response
] == ["test_tool"] ] == ["test_tool"]
@ -362,6 +369,7 @@ async def test_function_call_without_parameters(
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
hass, hass,
llm.ToolInput( llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool", tool_name="test_tool",
tool_args={}, tool_args={},
), ),
@ -451,6 +459,7 @@ async def test_function_exception(
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
hass, hass,
llm.ToolInput( llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": 1}, tool_args={"param1": 1},
), ),
@ -605,6 +614,7 @@ async def test_template_variables(
mock_chat.send_message_async.return_value = chat_response mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock() mock_part = MagicMock()
mock_part.text = "Model response" mock_part.text = "Model response"
mock_part.function_call = None
chat_response.parts = [mock_part] chat_response.parts = [mock_part]
result = await conversation.async_converse( result = await conversation.async_converse(
hass, "hello", None, context, agent_id=mock_config_entry.entry_id hass, "hello", None, context, agent_id=mock_config_entry.entry_id

View File

@ -324,6 +324,24 @@ TEST_JOB_DONE = supervisor_jobs.Job(
errors=[], errors=[],
child_jobs=[], child_jobs=[],
) )
TEST_RESTORE_JOB_DONE_WITH_ERROR = supervisor_jobs.Job(
name="backup_manager_partial_restore",
reference="1ef41507",
uuid=UUID(TEST_JOB_ID),
progress=0.0,
stage="copy_additional_locations",
done=True,
errors=[
supervisor_jobs.JobError(
type="BackupInvalidError",
message=(
"Backup was made on supervisor version 2025.02.2.dev3105, "
"can't restore on 2025.01.2.dev3105"
),
)
],
child_jobs=[],
)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@ -1946,6 +1964,97 @@ async def test_reader_writer_restore_error(
assert response["error"]["code"] == expected_error_code assert response["error"]["code"] == expected_error_code
@pytest.mark.usefixtures("hassio_client", "setup_integration")
async def test_reader_writer_restore_late_error(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
supervisor_client: AsyncMock,
) -> None:
"""Test restoring a backup with error."""
client = await hass_ws_client(hass)
supervisor_client.backups.partial_restore.return_value.job_id = TEST_JOB_ID
supervisor_client.backups.list.return_value = [TEST_BACKUP]
supervisor_client.backups.backup_info.return_value = TEST_BACKUP_DETAILS
supervisor_client.jobs.get_job.return_value = TEST_JOB_NOT_DONE
await client.send_json_auto_id({"type": "backup/subscribe_events"})
response = await client.receive_json()
assert response["event"] == {"manager_state": "idle"}
response = await client.receive_json()
assert response["success"]
await client.send_json_auto_id(
{"type": "backup/restore", "agent_id": "hassio.local", "backup_id": "abc123"}
)
response = await client.receive_json()
assert response["event"] == {
"manager_state": "restore_backup",
"reason": None,
"stage": None,
"state": "in_progress",
}
supervisor_client.backups.partial_restore.assert_called_once_with(
"abc123",
supervisor_backups.PartialRestoreOptions(
addons=None,
background=True,
folders=None,
homeassistant=True,
location=None,
password=None,
),
)
event = {
"event": "job",
"data": {
"name": "backup_manager_partial_restore",
"reference": "7c54aeed",
"uuid": TEST_JOB_ID,
"progress": 0,
"stage": None,
"done": True,
"parent_id": None,
"errors": [
{
"type": "BackupInvalidError",
"message": (
"Backup was made on supervisor version 2025.02.2.dev3105, can't"
" restore on 2025.01.2.dev3105. Must update supervisor first."
),
}
],
"created": "2025-02-03T08:27:49.297997+00:00",
},
}
await client.send_json_auto_id({"type": "supervisor/event", "data": event})
response = await client.receive_json()
assert response["success"]
response = await client.receive_json()
assert response["event"] == {
"manager_state": "restore_backup",
"reason": "backup_reader_writer_error",
"stage": None,
"state": "failed",
}
response = await client.receive_json()
assert response["event"] == {"manager_state": "idle"}
response = await client.receive_json()
assert not response["success"]
assert response["error"] == {
"code": "home_assistant_error",
"message": (
"Restore failed: [{'type': 'BackupInvalidError', 'message': \"Backup "
"was made on supervisor version 2025.02.2.dev3105, can't restore on "
'2025.01.2.dev3105. Must update supervisor first."}]'
),
}
@pytest.mark.parametrize( @pytest.mark.parametrize(
("backup", "backup_details", "parameters", "expected_error"), ("backup", "backup_details", "parameters", "expected_error"),
[ [
@ -1999,15 +2108,40 @@ async def test_reader_writer_restore_wrong_parameters(
} }
@pytest.mark.parametrize(
("get_job_result", "last_non_idle_event"),
[
(
TEST_JOB_DONE,
{
"manager_state": "restore_backup",
"reason": "",
"stage": None,
"state": "completed",
},
),
(
TEST_RESTORE_JOB_DONE_WITH_ERROR,
{
"manager_state": "restore_backup",
"reason": "unknown_error",
"stage": None,
"state": "failed",
},
),
],
)
@pytest.mark.usefixtures("hassio_client") @pytest.mark.usefixtures("hassio_client")
async def test_restore_progress_after_restart( async def test_restore_progress_after_restart(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
supervisor_client: AsyncMock, supervisor_client: AsyncMock,
get_job_result: supervisor_jobs.Job,
last_non_idle_event: dict[str, Any],
) -> None: ) -> None:
"""Test restore backup progress after restart.""" """Test restore backup progress after restart."""
supervisor_client.jobs.get_job.return_value = TEST_JOB_DONE supervisor_client.jobs.get_job.return_value = get_job_result
with patch.dict(os.environ, MOCK_ENVIRON | {RESTORE_JOB_ID_ENV: TEST_JOB_ID}): with patch.dict(os.environ, MOCK_ENVIRON | {RESTORE_JOB_ID_ENV: TEST_JOB_ID}):
assert await async_setup_component(hass, BACKUP_DOMAIN, {BACKUP_DOMAIN: {}}) assert await async_setup_component(hass, BACKUP_DOMAIN, {BACKUP_DOMAIN: {}})
@ -2018,12 +2152,7 @@ async def test_restore_progress_after_restart(
response = await client.receive_json() response = await client.receive_json()
assert response["success"] assert response["success"]
assert response["result"]["last_non_idle_event"] == { assert response["result"]["last_non_idle_event"] == last_non_idle_event
"manager_state": "restore_backup",
"reason": "",
"stage": None,
"state": "completed",
}
assert response["result"]["state"] == "idle" assert response["result"]["state"] == "idle"

View File

@ -18,6 +18,13 @@ from homeassistant.helpers import intent, llm
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@pytest.fixture(autouse=True)
def mock_ulid_tools():
"""Mock generated ULIDs for tool calls."""
with patch("homeassistant.helpers.llm.ulid_now", return_value="mock-tool-call"):
yield
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"]) @pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
async def test_chat( async def test_chat(
hass: HomeAssistant, hass: HomeAssistant,
@ -205,6 +212,7 @@ async def test_function_call(
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
hass, hass,
llm.ToolInput( llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool", tool_name="test_tool",
tool_args=expected_tool_args, tool_args=expected_tool_args,
), ),
@ -285,6 +293,7 @@ async def test_function_exception(
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
hass, hass,
llm.ToolInput( llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": "test_value"}, tool_args={"param1": "test_value"},
), ),

View File

@ -1,18 +1,9 @@
"""Fixtures for OneDrive tests.""" """Fixtures for OneDrive tests."""
from collections.abc import AsyncIterator, Generator from collections.abc import AsyncIterator, Generator
from html import escape
from json import dumps
import time import time
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from httpx import Response
from msgraph.generated.models.drive_item import DriveItem
from msgraph.generated.models.drive_item_collection_response import (
DriveItemCollectionResponse,
)
from msgraph.generated.models.upload_session import UploadSession
from msgraph_core.models import LargeFileUploadSession
import pytest import pytest
from homeassistant.components.application_credentials import ( from homeassistant.components.application_credentials import (
@ -23,7 +14,13 @@ from homeassistant.components.onedrive.const import DOMAIN, OAUTH_SCOPES
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from .const import BACKUP_METADATA, CLIENT_ID, CLIENT_SECRET from .const import (
CLIENT_ID,
CLIENT_SECRET,
MOCK_APPROOT,
MOCK_BACKUP_FILE,
MOCK_BACKUP_FOLDER,
)
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -70,96 +67,41 @@ def mock_config_entry(expires_at: int, scopes: list[str]) -> MockConfigEntry:
) )
@pytest.fixture
def mock_adapter() -> Generator[MagicMock]:
"""Return a mocked GraphAdapter."""
with (
patch(
"homeassistant.components.onedrive.config_flow.GraphRequestAdapter",
autospec=True,
) as mock_adapter,
patch(
"homeassistant.components.onedrive.backup.GraphRequestAdapter",
new=mock_adapter,
),
):
adapter = mock_adapter.return_value
adapter.get_http_response_message.return_value = Response(
status_code=200,
json={
"parentReference": {"driveId": "mock_drive_id"},
"createdBy": {"user": {"displayName": "John Doe"}},
},
)
yield adapter
adapter.send_async.return_value = LargeFileUploadSession(
next_expected_ranges=["2-"]
)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_graph_client(mock_adapter: MagicMock) -> Generator[MagicMock]: def mock_onedrive_client() -> Generator[MagicMock]:
"""Return a mocked GraphServiceClient.""" """Return a mocked GraphServiceClient."""
with ( with (
patch( patch(
"homeassistant.components.onedrive.config_flow.GraphServiceClient", "homeassistant.components.onedrive.config_flow.OneDriveClient",
autospec=True, autospec=True,
) as graph_client, ) as onedrive_client,
patch( patch(
"homeassistant.components.onedrive.GraphServiceClient", "homeassistant.components.onedrive.OneDriveClient",
new=graph_client, new=onedrive_client,
), ),
): ):
client = graph_client.return_value client = onedrive_client.return_value
client.get_approot.return_value = MOCK_APPROOT
client.create_folder.return_value = MOCK_BACKUP_FOLDER
client.list_drive_items.return_value = [MOCK_BACKUP_FILE]
client.get_drive_item.return_value = MOCK_BACKUP_FILE
client.request_adapter = mock_adapter class MockStreamReader:
async def iter_chunked(self, chunk_size: int) -> AsyncIterator[bytes]:
drives = client.drives.by_drive_id.return_value
drives.special.by_drive_item_id.return_value.get = AsyncMock(
return_value=DriveItem(id="approot")
)
drive_items = drives.items.by_drive_item_id.return_value
drive_items.get = AsyncMock(return_value=DriveItem(id="folder_id"))
drive_items.children.post = AsyncMock(return_value=DriveItem(id="folder_id"))
drive_items.children.get = AsyncMock(
return_value=DriveItemCollectionResponse(
value=[
DriveItem(
id=BACKUP_METADATA["backup_id"],
description=escape(dumps(BACKUP_METADATA)),
),
DriveItem(),
]
)
)
drive_items.delete = AsyncMock(return_value=None)
drive_items.create_upload_session.post = AsyncMock(
return_value=UploadSession(upload_url="https://test.tld")
)
drive_items.patch = AsyncMock(return_value=None)
async def generate_bytes() -> AsyncIterator[bytes]:
"""Asynchronous generator that yields bytes."""
yield b"backup data" yield b"backup data"
drive_items.content.get = AsyncMock( client.download_drive_item.return_value = MockStreamReader()
return_value=Response(status_code=200, content=generate_bytes())
)
yield client yield client
@pytest.fixture @pytest.fixture
def mock_drive_items(mock_graph_client: MagicMock) -> MagicMock: def mock_large_file_upload_client() -> Generator[AsyncMock]:
"""Return a mocked DriveItems.""" """Return a mocked LargeFileUploadClient upload."""
return mock_graph_client.drives.by_drive_id.return_value.items.by_drive_item_id.return_value with patch(
"homeassistant.components.onedrive.backup.LargeFileUploadClient.upload"
) as mock_upload:
@pytest.fixture yield mock_upload
def mock_get_special_folder(mock_graph_client: MagicMock) -> MagicMock:
"""Mock the get special folder method."""
return mock_graph_client.drives.by_drive_id.return_value.special.by_drive_item_id.return_value.get
@pytest.fixture @pytest.fixture
@ -179,10 +121,3 @@ def mock_instance_id() -> Generator[AsyncMock]:
return_value="9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0", return_value="9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0",
): ):
yield yield
@pytest.fixture(autouse=True)
def mock_asyncio_sleep() -> Generator[AsyncMock]:
"""Mock asyncio.sleep."""
with patch("homeassistant.components.onedrive.backup.asyncio.sleep", AsyncMock()):
yield
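A minimal consumption sketch for the streaming mock above, assuming only what MockStreamReader exposes (an async iter_chunked that yields bytes); the helper name is illustrative and not part of the integration.

async def collect_chunks(stream, chunk_size: int = 1024) -> bytes:
    """Drain a stream reader like the MockStreamReader above into memory."""
    data = b""
    async for chunk in stream.iter_chunked(chunk_size):
        data += chunk
    return data

With this fixture, awaiting collect_chunks(client.download_drive_item.return_value) yields b"backup data", which is the body the download test further below asserts against the HTTP response.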

View File

@ -1,5 +1,18 @@
"""Consts for OneDrive tests.""" """Consts for OneDrive tests."""
from html import escape
from json import dumps
from onedrive_personal_sdk.models.items import (
AppRoot,
Contributor,
File,
Folder,
Hashes,
ItemParentReference,
User,
)
CLIENT_ID = "1234" CLIENT_ID = "1234"
CLIENT_SECRET = "5678" CLIENT_SECRET = "5678"
@ -17,3 +30,48 @@ BACKUP_METADATA = {
"protected": False, "protected": False,
"size": 34519040, "size": 34519040,
} }
CONTRIBUTOR = Contributor(
user=User(
display_name="John Doe",
id="id",
email="john@doe.com",
)
)
MOCK_APPROOT = AppRoot(
id="id",
child_count=0,
size=0,
name="name",
parent_reference=ItemParentReference(
drive_id="mock_drive_id", id="id", path="path"
),
created_by=CONTRIBUTOR,
)
MOCK_BACKUP_FOLDER = Folder(
id="id",
name="name",
size=0,
child_count=0,
parent_reference=ItemParentReference(
drive_id="mock_drive_id", id="id", path="path"
),
created_by=CONTRIBUTOR,
)
MOCK_BACKUP_FILE = File(
id="id",
name="23e64aec.tar",
size=34519040,
parent_reference=ItemParentReference(
drive_id="mock_drive_id", id="id", path="path"
),
hashes=Hashes(
quick_xor_hash="hash",
),
mime_type="application/x-tar",
description=escape(dumps(BACKUP_METADATA)),
created_by=CONTRIBUTOR,
)
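A short decoding sketch, assuming only how MOCK_BACKUP_FILE builds its description above (HTML-escaped JSON of BACKUP_METADATA): reading the metadata back is simply the reverse of escape(dumps(...)). The function name is illustrative.

from html import unescape
from json import loads


def parse_backup_description(description: str) -> dict:
    """Decode backup metadata stored as an escaped JSON description."""
    return loads(unescape(description))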

View File

@ -3,15 +3,14 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from html import escape
from io import StringIO from io import StringIO
from json import dumps
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from httpx import TimeoutException from onedrive_personal_sdk.exceptions import (
from kiota_abstractions.api_error import APIError AuthenticationError,
from msgraph.generated.models.drive_item import DriveItem HashMismatchError,
from msgraph_core.models import LargeFileUploadSession OneDriveException,
)
import pytest import pytest
from homeassistant.components.backup import DOMAIN as BACKUP_DOMAIN, AgentBackup from homeassistant.components.backup import DOMAIN as BACKUP_DOMAIN, AgentBackup
@ -102,14 +101,10 @@ async def test_agents_list_backups(
async def test_agents_get_backup( async def test_agents_get_backup(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
mock_drive_items: MagicMock,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
) -> None: ) -> None:
"""Test agent get backup.""" """Test agent get backup."""
mock_drive_items.get = AsyncMock(
return_value=DriveItem(description=escape(dumps(BACKUP_METADATA)))
)
backup_id = BACKUP_METADATA["backup_id"] backup_id = BACKUP_METADATA["backup_id"]
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
await client.send_json_auto_id({"type": "backup/details", "backup_id": backup_id}) await client.send_json_auto_id({"type": "backup/details", "backup_id": backup_id})
@ -140,7 +135,7 @@ async def test_agents_get_backup(
async def test_agents_delete( async def test_agents_delete(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
mock_drive_items: MagicMock, mock_onedrive_client: MagicMock,
) -> None: ) -> None:
"""Test agent delete backup.""" """Test agent delete backup."""
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
@ -155,37 +150,15 @@ async def test_agents_delete(
assert response["success"] assert response["success"]
assert response["result"] == {"agent_errors": {}} assert response["result"] == {"agent_errors": {}}
mock_drive_items.delete.assert_called_once() mock_onedrive_client.delete_drive_item.assert_called_once()
async def test_agents_delete_not_found_does_not_throw(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_drive_items: MagicMock,
) -> None:
"""Test agent delete backup."""
mock_drive_items.children.get = AsyncMock(return_value=[])
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "backup/delete",
"backup_id": BACKUP_METADATA["backup_id"],
}
)
response = await client.receive_json()
assert response["success"]
assert response["result"] == {"agent_errors": {}}
assert mock_drive_items.delete.call_count == 0
async def test_agents_upload( async def test_agents_upload(
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
mock_drive_items: MagicMock, mock_onedrive_client: MagicMock,
mock_large_file_upload_client: AsyncMock,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
mock_adapter: MagicMock,
) -> None: ) -> None:
"""Test agent upload backup.""" """Test agent upload backup."""
client = await hass_client() client = await hass_client()
@ -200,7 +173,6 @@ async def test_agents_upload(
return_value=test_backup, return_value=test_backup,
), ),
patch("pathlib.Path.open") as mocked_open, patch("pathlib.Path.open") as mocked_open,
patch("homeassistant.components.onedrive.backup.UPLOAD_CHUNK_SIZE", 3),
): ):
mocked_open.return_value.read = Mock(side_effect=[b"test", b""]) mocked_open.return_value.read = Mock(side_effect=[b"test", b""])
fetch_backup.return_value = test_backup fetch_backup.return_value = test_backup
@ -211,31 +183,22 @@ async def test_agents_upload(
assert resp.status == 201 assert resp.status == 201
assert f"Uploading backup {test_backup.backup_id}" in caplog.text assert f"Uploading backup {test_backup.backup_id}" in caplog.text
mock_drive_items.create_upload_session.post.assert_called_once() mock_large_file_upload_client.assert_called_once()
mock_drive_items.patch.assert_called_once() mock_onedrive_client.update_drive_item.assert_called_once()
assert mock_adapter.send_async.call_count == 2
assert mock_adapter.method_calls[0].args[0].content == b"tes"
assert mock_adapter.method_calls[0].args[0].headers.get("Content-Range") == {
"bytes 0-2/34519040"
}
assert mock_adapter.method_calls[1].args[0].content == b"t"
assert mock_adapter.method_calls[1].args[0].headers.get("Content-Range") == {
"bytes 3-3/34519040"
}
async def test_broken_upload_session( async def test_agents_upload_corrupt_upload(
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
mock_drive_items: MagicMock, mock_onedrive_client: MagicMock,
mock_large_file_upload_client: AsyncMock,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
) -> None: ) -> None:
"""Test broken upload session.""" """Test hash validation fails."""
mock_large_file_upload_client.side_effect = HashMismatchError("test")
client = await hass_client() client = await hass_client()
test_backup = AgentBackup.from_dict(BACKUP_METADATA) test_backup = AgentBackup.from_dict(BACKUP_METADATA)
mock_drive_items.create_upload_session.post = AsyncMock(return_value=None)
with ( with (
patch( patch(
"homeassistant.components.backup.manager.BackupManager.async_get_backup", "homeassistant.components.backup.manager.BackupManager.async_get_backup",
@ -254,152 +217,18 @@ async def test_broken_upload_session(
) )
assert resp.status == 201 assert resp.status == 201
assert "Failed to start backup upload" in caplog.text
@pytest.mark.parametrize(
"side_effect",
[
APIError(response_status_code=500),
TimeoutException("Timeout"),
],
)
async def test_agents_upload_errors_retried(
hass_client: ClientSessionGenerator,
caplog: pytest.LogCaptureFixture,
mock_drive_items: MagicMock,
mock_config_entry: MockConfigEntry,
mock_adapter: MagicMock,
side_effect: Exception,
) -> None:
"""Test agent upload backup."""
client = await hass_client()
test_backup = AgentBackup.from_dict(BACKUP_METADATA)
mock_adapter.send_async.side_effect = [
side_effect,
LargeFileUploadSession(next_expected_ranges=["2-"]),
LargeFileUploadSession(next_expected_ranges=["2-"]),
]
with (
patch(
"homeassistant.components.backup.manager.BackupManager.async_get_backup",
) as fetch_backup,
patch(
"homeassistant.components.backup.manager.read_backup",
return_value=test_backup,
),
patch("pathlib.Path.open") as mocked_open,
patch("homeassistant.components.onedrive.backup.UPLOAD_CHUNK_SIZE", 3),
):
mocked_open.return_value.read = Mock(side_effect=[b"test", b""])
fetch_backup.return_value = test_backup
resp = await client.post(
f"/api/backup/upload?agent_id={DOMAIN}.{mock_config_entry.unique_id}",
data={"file": StringIO("test")},
)
assert resp.status == 201
assert mock_adapter.send_async.call_count == 3
assert f"Uploading backup {test_backup.backup_id}" in caplog.text assert f"Uploading backup {test_backup.backup_id}" in caplog.text
mock_drive_items.patch.assert_called_once() mock_large_file_upload_client.assert_called_once()
assert mock_onedrive_client.update_drive_item.call_count == 0
assert "Hash validation failed, backup file might be corrupt" in caplog.text
async def test_agents_upload_4xx_errors_not_retried(
hass_client: ClientSessionGenerator,
caplog: pytest.LogCaptureFixture,
mock_drive_items: MagicMock,
mock_config_entry: MockConfigEntry,
mock_adapter: MagicMock,
) -> None:
"""Test agent upload backup."""
client = await hass_client()
test_backup = AgentBackup.from_dict(BACKUP_METADATA)
mock_adapter.send_async.side_effect = APIError(response_status_code=404)
with (
patch(
"homeassistant.components.backup.manager.BackupManager.async_get_backup",
) as fetch_backup,
patch(
"homeassistant.components.backup.manager.read_backup",
return_value=test_backup,
),
patch("pathlib.Path.open") as mocked_open,
patch("homeassistant.components.onedrive.backup.UPLOAD_CHUNK_SIZE", 3),
):
mocked_open.return_value.read = Mock(side_effect=[b"test", b""])
fetch_backup.return_value = test_backup
resp = await client.post(
f"/api/backup/upload?agent_id={DOMAIN}.{mock_config_entry.unique_id}",
data={"file": StringIO("test")},
)
assert resp.status == 201
assert mock_adapter.send_async.call_count == 1
assert f"Uploading backup {test_backup.backup_id}" in caplog.text
assert mock_drive_items.patch.call_count == 0
assert "Backup operation failed" in caplog.text
@pytest.mark.parametrize(
("side_effect", "error"),
[
(APIError(response_status_code=500), "Backup operation failed"),
(TimeoutException("Timeout"), "Backup operation timed out"),
],
)
async def test_agents_upload_fails_after_max_retries(
hass_client: ClientSessionGenerator,
caplog: pytest.LogCaptureFixture,
mock_drive_items: MagicMock,
mock_config_entry: MockConfigEntry,
mock_adapter: MagicMock,
side_effect: Exception,
error: str,
) -> None:
"""Test agent upload backup."""
client = await hass_client()
test_backup = AgentBackup.from_dict(BACKUP_METADATA)
mock_adapter.send_async.side_effect = side_effect
with (
patch(
"homeassistant.components.backup.manager.BackupManager.async_get_backup",
) as fetch_backup,
patch(
"homeassistant.components.backup.manager.read_backup",
return_value=test_backup,
),
patch("pathlib.Path.open") as mocked_open,
patch("homeassistant.components.onedrive.backup.UPLOAD_CHUNK_SIZE", 3),
):
mocked_open.return_value.read = Mock(side_effect=[b"test", b""])
fetch_backup.return_value = test_backup
resp = await client.post(
f"/api/backup/upload?agent_id={DOMAIN}.{mock_config_entry.unique_id}",
data={"file": StringIO("test")},
)
assert resp.status == 201
assert mock_adapter.send_async.call_count == 6
assert f"Uploading backup {test_backup.backup_id}" in caplog.text
assert mock_drive_items.patch.call_count == 0
assert error in caplog.text
async def test_agents_download( async def test_agents_download(
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
mock_drive_items: MagicMock, mock_onedrive_client: MagicMock,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
) -> None: ) -> None:
"""Test agent download backup.""" """Test agent download backup."""
mock_drive_items.get = AsyncMock(
return_value=DriveItem(description=escape(dumps(BACKUP_METADATA)))
)
client = await hass_client() client = await hass_client()
backup_id = BACKUP_METADATA["backup_id"] backup_id = BACKUP_METADATA["backup_id"]
@ -408,29 +237,30 @@ async def test_agents_download(
) )
assert resp.status == 200 assert resp.status == 200
assert await resp.content.read() == b"backup data" assert await resp.content.read() == b"backup data"
mock_drive_items.content.get.assert_called_once()
@pytest.mark.parametrize( @pytest.mark.parametrize(
("side_effect", "error"), ("side_effect", "error"),
[ [
( (
APIError(response_status_code=500), OneDriveException(),
"Backup operation failed", "Backup operation failed",
), ),
(TimeoutException("Timeout"), "Backup operation timed out"), (TimeoutError(), "Backup operation timed out"),
], ],
) )
async def test_delete_error( async def test_delete_error(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
mock_drive_items: MagicMock, mock_onedrive_client: MagicMock,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
side_effect: Exception, side_effect: Exception,
error: str, error: str,
) -> None: ) -> None:
"""Test error during delete.""" """Test error during delete."""
mock_drive_items.delete = AsyncMock(side_effect=side_effect) mock_onedrive_client.delete_drive_item.side_effect = AsyncMock(
side_effect=side_effect
)
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
@ -448,14 +278,35 @@ async def test_delete_error(
} }
async def test_agents_delete_not_found_does_not_throw(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_onedrive_client: MagicMock,
) -> None:
"""Test agent delete backup."""
mock_onedrive_client.list_drive_items.return_value = []
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "backup/delete",
"backup_id": BACKUP_METADATA["backup_id"],
}
)
response = await client.receive_json()
assert response["success"]
assert response["result"] == {"agent_errors": {}}
async def test_agents_backup_not_found( async def test_agents_backup_not_found(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
mock_drive_items: MagicMock, mock_onedrive_client: MagicMock,
) -> None: ) -> None:
"""Test backup not found.""" """Test backup not found."""
mock_drive_items.children.get = AsyncMock(return_value=[]) mock_onedrive_client.list_drive_items.return_value = []
backup_id = BACKUP_METADATA["backup_id"] backup_id = BACKUP_METADATA["backup_id"]
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
await client.send_json_auto_id({"type": "backup/details", "backup_id": backup_id}) await client.send_json_auto_id({"type": "backup/details", "backup_id": backup_id})
@ -468,13 +319,13 @@ async def test_agents_backup_not_found(
async def test_reauth_on_403( async def test_reauth_on_403(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
mock_drive_items: MagicMock, mock_onedrive_client: MagicMock,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
) -> None: ) -> None:
"""Test we re-authenticate on 403.""" """Test we re-authenticate on 403."""
mock_drive_items.children.get = AsyncMock( mock_onedrive_client.list_drive_items.side_effect = AuthenticationError(
side_effect=APIError(response_status_code=403) 403, "Auth failed"
) )
backup_id = BACKUP_METADATA["backup_id"] backup_id = BACKUP_METADATA["backup_id"]
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
@ -483,7 +334,7 @@ async def test_reauth_on_403(
assert response["success"] assert response["success"]
assert response["result"]["agent_errors"] == { assert response["result"]["agent_errors"] == {
f"{DOMAIN}.{mock_config_entry.unique_id}": "Backup operation failed" f"{DOMAIN}.{mock_config_entry.unique_id}": "Authentication error"
} }
await hass.async_block_till_done() await hass.async_block_till_done()

@ -3,8 +3,7 @@
from http import HTTPStatus from http import HTTPStatus
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
from httpx import Response from onedrive_personal_sdk.exceptions import OneDriveException
from kiota_abstractions.api_error import APIError
import pytest import pytest
from homeassistant import config_entries from homeassistant import config_entries
@ -20,7 +19,7 @@ from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers import config_entry_oauth2_flow from homeassistant.helpers import config_entry_oauth2_flow
from . import setup_integration from . import setup_integration
from .const import CLIENT_ID from .const import CLIENT_ID, MOCK_APPROOT
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
from tests.test_util.aiohttp import AiohttpClientMocker from tests.test_util.aiohttp import AiohttpClientMocker
@ -89,25 +88,52 @@ async def test_full_flow(
assert result["data"][CONF_TOKEN]["refresh_token"] == "mock-refresh-token" assert result["data"][CONF_TOKEN]["refresh_token"] == "mock-refresh-token"
@pytest.mark.usefixtures("current_request_with_host")
async def test_full_flow_with_owner_not_found(
hass: HomeAssistant,
hass_client_no_auth: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
mock_setup_entry: AsyncMock,
mock_onedrive_client: MagicMock,
) -> None:
"""Ensure we get a default title if the drive's owner can't be read."""
mock_onedrive_client.get_approot.return_value.created_by.user = None
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
await _do_get_token(hass, result, hass_client_no_auth, aioclient_mock)
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] is FlowResultType.CREATE_ENTRY
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
assert len(mock_setup_entry.mock_calls) == 1
assert result["title"] == "OneDrive"
assert result["result"].unique_id == "mock_drive_id"
assert result["data"][CONF_TOKEN][CONF_ACCESS_TOKEN] == "mock-access-token"
assert result["data"][CONF_TOKEN]["refresh_token"] == "mock-refresh-token"
@pytest.mark.usefixtures("current_request_with_host") @pytest.mark.usefixtures("current_request_with_host")
@pytest.mark.parametrize( @pytest.mark.parametrize(
("exception", "error"), ("exception", "error"),
[ [
(Exception, "unknown"), (Exception, "unknown"),
(APIError, "connection_error"), (OneDriveException, "connection_error"),
], ],
) )
async def test_flow_errors( async def test_flow_errors(
hass: HomeAssistant, hass: HomeAssistant,
hass_client_no_auth: ClientSessionGenerator, hass_client_no_auth: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker, aioclient_mock: AiohttpClientMocker,
mock_adapter: MagicMock, mock_onedrive_client: MagicMock,
exception: Exception, exception: Exception,
error: str, error: str,
) -> None: ) -> None:
"""Test errors during flow.""" """Test errors during flow."""
mock_adapter.get_http_response_message.side_effect = exception mock_onedrive_client.get_approot.side_effect = exception
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER} DOMAIN, context={"source": config_entries.SOURCE_USER}
@ -172,15 +198,12 @@ async def test_reauth_flow_id_changed(
aioclient_mock: AiohttpClientMocker, aioclient_mock: AiohttpClientMocker,
mock_setup_entry: AsyncMock, mock_setup_entry: AsyncMock,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
mock_adapter: MagicMock, mock_onedrive_client: MagicMock,
) -> None: ) -> None:
"""Test that the reauth flow fails on a different drive id.""" """Test that the reauth flow fails on a different drive id."""
mock_adapter.get_http_response_message.return_value = Response( app_root = MOCK_APPROOT
status_code=200, app_root.parent_reference.drive_id = "other_drive_id"
json={ mock_onedrive_client.get_approot.return_value = app_root
"parentReference": {"driveId": "other_drive_id"},
},
)
await setup_integration(hass, mock_config_entry) await setup_integration(hass, mock_config_entry)

@ -2,7 +2,7 @@
from unittest.mock import MagicMock from unittest.mock import MagicMock
from kiota_abstractions.api_error import APIError from onedrive_personal_sdk.exceptions import AuthenticationError, OneDriveException
import pytest import pytest
from homeassistant.config_entries import ConfigEntryState from homeassistant.config_entries import ConfigEntryState
@ -31,82 +31,31 @@ async def test_load_unload_config_entry(
@pytest.mark.parametrize( @pytest.mark.parametrize(
("side_effect", "state"), ("side_effect", "state"),
[ [
(APIError(response_status_code=403), ConfigEntryState.SETUP_ERROR), (AuthenticationError(403, "Auth failed"), ConfigEntryState.SETUP_ERROR),
(APIError(response_status_code=500), ConfigEntryState.SETUP_RETRY), (OneDriveException(), ConfigEntryState.SETUP_RETRY),
], ],
) )
async def test_approot_errors( async def test_approot_errors(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
mock_get_special_folder: MagicMock, mock_onedrive_client: MagicMock,
side_effect: Exception, side_effect: Exception,
state: ConfigEntryState, state: ConfigEntryState,
) -> None: ) -> None:
"""Test errors during approot retrieval.""" """Test errors during approot retrieval."""
mock_get_special_folder.side_effect = side_effect mock_onedrive_client.get_approot.side_effect = side_effect
await setup_integration(hass, mock_config_entry) await setup_integration(hass, mock_config_entry)
assert mock_config_entry.state is state assert mock_config_entry.state is state
async def test_faulty_approot( async def test_get_integration_folder_error(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
mock_get_special_folder: MagicMock, mock_onedrive_client: MagicMock,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
) -> None: ) -> None:
"""Test faulty approot retrieval.""" """Test faulty approot retrieval."""
mock_get_special_folder.return_value = None mock_onedrive_client.create_folder.side_effect = OneDriveException()
await setup_integration(hass, mock_config_entry)
assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY
assert "Failed to get approot folder" in caplog.text
async def test_faulty_integration_folder(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_drive_items: MagicMock,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test faulty approot retrieval."""
mock_drive_items.get.return_value = None
await setup_integration(hass, mock_config_entry) await setup_integration(hass, mock_config_entry)
assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY
assert "Failed to get backups_9f86d081 folder" in caplog.text assert "Failed to get backups_9f86d081 folder" in caplog.text
async def test_500_error_during_backup_folder_get(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_drive_items: MagicMock,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test error during backup folder creation."""
mock_drive_items.get.side_effect = APIError(response_status_code=500)
await setup_integration(hass, mock_config_entry)
assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY
assert "Failed to get backups_9f86d081 folder" in caplog.text
async def test_error_during_backup_folder_creation(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_drive_items: MagicMock,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test error during backup folder creation."""
mock_drive_items.get.side_effect = APIError(response_status_code=404)
mock_drive_items.children.post.side_effect = APIError()
await setup_integration(hass, mock_config_entry)
assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY
assert "Failed to create backups_9f86d081 folder" in caplog.text
async def test_successful_backup_folder_creation(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_drive_items: MagicMock,
) -> None:
"""Test successful backup folder creation."""
mock_drive_items.get.side_effect = APIError(response_status_code=404)
await setup_integration(hass, mock_config_entry)
assert mock_config_entry.state is ConfigEntryState.LOADED

@ -12,12 +12,14 @@ from homeassistant.components.openai_conversation.const import (
CONF_CHAT_MODEL, CONF_CHAT_MODEL,
CONF_MAX_TOKENS, CONF_MAX_TOKENS,
CONF_PROMPT, CONF_PROMPT,
CONF_REASONING_EFFORT,
CONF_RECOMMENDED, CONF_RECOMMENDED,
CONF_TEMPERATURE, CONF_TEMPERATURE,
CONF_TOP_P, CONF_TOP_P,
DOMAIN, DOMAIN,
RECOMMENDED_CHAT_MODEL, RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS, RECOMMENDED_MAX_TOKENS,
RECOMMENDED_REASONING_EFFORT,
RECOMMENDED_TOP_P, RECOMMENDED_TOP_P,
) )
from homeassistant.const import CONF_LLM_HASS_API from homeassistant.const import CONF_LLM_HASS_API
@ -88,6 +90,27 @@ async def test_options(
assert options["data"][CONF_CHAT_MODEL] == RECOMMENDED_CHAT_MODEL assert options["data"][CONF_CHAT_MODEL] == RECOMMENDED_CHAT_MODEL
async def test_options_unsupported_model(
hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None:
"""Test the options form giving error about models not supported."""
options_flow = await hass.config_entries.options.async_init(
mock_config_entry.entry_id
)
result = await hass.config_entries.options.async_configure(
options_flow["flow_id"],
{
CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate",
CONF_CHAT_MODEL: "o1-mini",
CONF_LLM_HASS_API: "assist",
},
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.FORM
assert result["errors"] == {"chat_model": "model_not_supported"}
@pytest.mark.parametrize( @pytest.mark.parametrize(
("side_effect", "error"), ("side_effect", "error"),
[ [
@ -148,6 +171,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL, CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TOP_P: RECOMMENDED_TOP_P, CONF_TOP_P: RECOMMENDED_TOP_P,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS, CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
}, },
), ),
( (
@ -158,6 +182,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL, CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TOP_P: RECOMMENDED_TOP_P, CONF_TOP_P: RECOMMENDED_TOP_P,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS, CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
}, },
{ {
CONF_RECOMMENDED: True, CONF_RECOMMENDED: True,

@ -195,6 +195,7 @@ async def test_function_call(
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
hass, hass,
llm.ToolInput( llm.ToolInput(
id="call_AbCdEfGhIjKlMnOpQrStUvWx",
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": "test_value"}, tool_args={"param1": "test_value"},
), ),
@ -359,6 +360,7 @@ async def test_function_exception(
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
hass, hass,
llm.ToolInput( llm.ToolInput(
id="call_AbCdEfGhIjKlMnOpQrStUvWx",
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": "test_value"}, tool_args={"param1": "test_value"},
), ),

@ -0,0 +1,59 @@
# serializer version: 1
# name: test_service_get[schedule.from_storage-get-after-update]
dict({
'friday': list([
]),
'monday': list([
]),
'saturday': list([
]),
'sunday': list([
]),
'thursday': list([
]),
'tuesday': list([
]),
'wednesday': list([
dict({
'from': datetime.time(17, 0),
'to': datetime.time(19, 0),
}),
]),
})
# ---
# name: test_service_get[schedule.from_storage-get]
dict({
'friday': list([
dict({
'data': dict({
'party_level': 'epic',
}),
'from': datetime.time(17, 0),
'to': datetime.time(23, 59, 59),
}),
]),
'monday': list([
]),
'saturday': list([
dict({
'from': datetime.time(0, 0),
'to': datetime.time(23, 59, 59),
}),
]),
'sunday': list([
dict({
'data': dict({
'entry': 'VIPs only',
}),
'from': datetime.time(0, 0),
'to': datetime.time(23, 59, 59, 999999),
}),
]),
'thursday': list([
]),
'tuesday': list([
]),
'wednesday': list([
]),
})
# ---

@ -8,10 +8,12 @@ from unittest.mock import patch
from freezegun.api import FrozenDateTimeFactory from freezegun.api import FrozenDateTimeFactory
import pytest import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components.schedule import STORAGE_VERSION, STORAGE_VERSION_MINOR from homeassistant.components.schedule import STORAGE_VERSION, STORAGE_VERSION_MINOR
from homeassistant.components.schedule.const import ( from homeassistant.components.schedule.const import (
ATTR_NEXT_EVENT, ATTR_NEXT_EVENT,
CONF_ALL_DAYS,
CONF_DATA, CONF_DATA,
CONF_FRIDAY, CONF_FRIDAY,
CONF_FROM, CONF_FROM,
@ -23,12 +25,14 @@ from homeassistant.components.schedule.const import (
CONF_TUESDAY, CONF_TUESDAY,
CONF_WEDNESDAY, CONF_WEDNESDAY,
DOMAIN, DOMAIN,
SERVICE_GET,
) )
from homeassistant.const import ( from homeassistant.const import (
ATTR_EDITABLE, ATTR_EDITABLE,
ATTR_FRIENDLY_NAME, ATTR_FRIENDLY_NAME,
ATTR_ICON, ATTR_ICON,
ATTR_NAME, ATTR_NAME,
CONF_ENTITY_ID,
CONF_ICON, CONF_ICON,
CONF_ID, CONF_ID,
CONF_NAME, CONF_NAME,
@ -754,3 +758,66 @@ async def test_ws_create(
assert result["party_mode"][CONF_MONDAY] == [ assert result["party_mode"][CONF_MONDAY] == [
{CONF_FROM: "12:00:00", CONF_TO: saved_to} {CONF_FROM: "12:00:00", CONF_TO: saved_to}
] ]
async def test_service_get(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
schedule_setup: Callable[..., Coroutine[Any, Any, bool]],
) -> None:
"""Test getting a single schedule via service."""
assert await schedule_setup()
entity_id = "schedule.from_storage"
# Test retrieving a single schedule via service call
service_result = await hass.services.async_call(
DOMAIN,
SERVICE_GET,
{
CONF_ENTITY_ID: entity_id,
},
blocking=True,
return_response=True,
)
result = service_result.get(entity_id)
assert set(result) == CONF_ALL_DAYS
assert result == snapshot(name=f"{entity_id}-get")
# Now we update the schedule via WS
client = await hass_ws_client(hass)
await client.send_json(
{
"id": 1,
"type": f"{DOMAIN}/update",
f"{DOMAIN}_id": entity_id.rsplit(".", maxsplit=1)[-1],
CONF_NAME: "Party pooper",
CONF_ICON: "mdi:party-pooper",
CONF_MONDAY: [],
CONF_TUESDAY: [],
CONF_WEDNESDAY: [{CONF_FROM: "17:00:00", CONF_TO: "19:00:00"}],
CONF_THURSDAY: [],
CONF_FRIDAY: [],
CONF_SATURDAY: [],
CONF_SUNDAY: [],
}
)
resp = await client.receive_json()
assert resp["success"]
# Test retrieving the schedule via service call after WS update
service_result = await hass.services.async_call(
DOMAIN,
SERVICE_GET,
{
CONF_ENTITY_ID: entity_id,
},
blocking=True,
return_response=True,
)
result = service_result.get(entity_id)
assert set(result) == CONF_ALL_DAYS
assert result == snapshot(name=f"{entity_id}-get-after-update")

@ -180,6 +180,7 @@ MOCK_CONFIG = {
"xcounts": {"expr": None, "unit": None}, "xcounts": {"expr": None, "unit": None},
"xfreq": {"expr": None, "unit": None}, "xfreq": {"expr": None, "unit": None},
}, },
"flood:0": {"id": 0, "name": "Test name"},
"light:0": {"name": "test light_0"}, "light:0": {"name": "test light_0"},
"light:1": {"name": "test light_1"}, "light:1": {"name": "test light_1"},
"light:2": {"name": "test light_2"}, "light:2": {"name": "test light_2"},
@ -326,6 +327,7 @@ MOCK_STATUS_RPC = {
"em1:1": {"act_power": 123.3}, "em1:1": {"act_power": 123.3},
"em1data:0": {"total_act_energy": 123456.4}, "em1data:0": {"total_act_energy": 123456.4},
"em1data:1": {"total_act_energy": 987654.3}, "em1data:1": {"total_act_energy": 987654.3},
"flood:0": {"id": 0, "alarm": False, "mute": False},
"thermostat:0": { "thermostat:0": {
"id": 0, "id": 0,
"enable": True, "enable": True,

@ -46,3 +46,96 @@
'state': 'off', 'state': 'off',
}) })
# --- # ---
# name: test_rpc_flood_entities[binary_sensor.test_name_flood-entry]
EntityRegistryEntrySnapshot({
'aliases': set({
}),
'area_id': None,
'capabilities': None,
'config_entry_id': <ANY>,
'device_class': None,
'device_id': <ANY>,
'disabled_by': None,
'domain': 'binary_sensor',
'entity_category': None,
'entity_id': 'binary_sensor.test_name_flood',
'has_entity_name': False,
'hidden_by': None,
'icon': None,
'id': <ANY>,
'labels': set({
}),
'name': None,
'options': dict({
}),
'original_device_class': <BinarySensorDeviceClass.MOISTURE: 'moisture'>,
'original_icon': None,
'original_name': 'Test name flood',
'platform': 'shelly',
'previous_unique_id': None,
'supported_features': 0,
'translation_key': None,
'unique_id': '123456789ABC-flood:0-flood',
'unit_of_measurement': None,
})
# ---
# name: test_rpc_flood_entities[binary_sensor.test_name_flood-state]
StateSnapshot({
'attributes': ReadOnlyDict({
'device_class': 'moisture',
'friendly_name': 'Test name flood',
}),
'context': <ANY>,
'entity_id': 'binary_sensor.test_name_flood',
'last_changed': <ANY>,
'last_reported': <ANY>,
'last_updated': <ANY>,
'state': 'off',
})
# ---
# name: test_rpc_flood_entities[binary_sensor.test_name_mute-entry]
EntityRegistryEntrySnapshot({
'aliases': set({
}),
'area_id': None,
'capabilities': None,
'config_entry_id': <ANY>,
'device_class': None,
'device_id': <ANY>,
'disabled_by': None,
'domain': 'binary_sensor',
'entity_category': <EntityCategory.DIAGNOSTIC: 'diagnostic'>,
'entity_id': 'binary_sensor.test_name_mute',
'has_entity_name': False,
'hidden_by': None,
'icon': None,
'id': <ANY>,
'labels': set({
}),
'name': None,
'options': dict({
}),
'original_device_class': None,
'original_icon': None,
'original_name': 'Test name mute',
'platform': 'shelly',
'previous_unique_id': None,
'supported_features': 0,
'translation_key': None,
'unique_id': '123456789ABC-flood:0-mute',
'unit_of_measurement': None,
})
# ---
# name: test_rpc_flood_entities[binary_sensor.test_name_mute-state]
StateSnapshot({
'attributes': ReadOnlyDict({
'friendly_name': 'Test name mute',
}),
'context': <ANY>,
'entity_id': 'binary_sensor.test_name_mute',
'last_changed': <ANY>,
'last_reported': <ANY>,
'last_updated': <ANY>,
'state': 'off',
})
# ---

@ -496,3 +496,22 @@ async def test_blu_trv_binary_sensor_entity(
entry = entity_registry.async_get(entity_id) entry = entity_registry.async_get(entity_id)
assert entry == snapshot(name=f"{entity_id}-entry") assert entry == snapshot(name=f"{entity_id}-entry")
async def test_rpc_flood_entities(
hass: HomeAssistant,
mock_rpc_device: Mock,
entity_registry: EntityRegistry,
snapshot: SnapshotAssertion,
) -> None:
"""Test RPC flood sensor entities."""
await init_integration(hass, 4)
for entity in ("flood", "mute"):
entity_id = f"{BINARY_SENSOR_DOMAIN}.test_name_{entity}"
state = hass.states.get(entity_id)
assert state == snapshot(name=f"{entity_id}-state")
entry = entity_registry.async_get(entity_id)
assert entry == snapshot(name=f"{entity_id}-entry")

@ -18,7 +18,8 @@ from tests.common import (
load_json_object_fixture, load_json_object_fixture,
) )
MOCK_HOST = "slzb-06.local" MOCK_DEVICE_NAME = "slzb-06"
MOCK_HOST = "192.168.1.161"
MOCK_USERNAME = "test-user" MOCK_USERNAME = "test-user"
MOCK_PASSWORD = "test-pass" MOCK_PASSWORD = "test-pass"

@ -3,7 +3,7 @@
DeviceRegistryEntrySnapshot({ DeviceRegistryEntrySnapshot({
'area_id': None, 'area_id': None,
'config_entries': <ANY>, 'config_entries': <ANY>,
'configuration_url': 'http://slzb-06.local', 'configuration_url': 'http://192.168.1.161',
'connections': set({ 'connections': set({
tuple( tuple(
'mac', 'mac',

@ -8,19 +8,20 @@ from pysmlight.exceptions import SmlightAuthError, SmlightConnectionError
import pytest import pytest
from homeassistant.components.smlight.const import DOMAIN from homeassistant.components.smlight.const import DOMAIN
from homeassistant.config_entries import SOURCE_USER, SOURCE_ZEROCONF from homeassistant.config_entries import SOURCE_DHCP, SOURCE_USER, SOURCE_ZEROCONF
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers.service_info.dhcp import DhcpServiceInfo
from homeassistant.helpers.service_info.zeroconf import ZeroconfServiceInfo from homeassistant.helpers.service_info.zeroconf import ZeroconfServiceInfo
from .conftest import MOCK_HOST, MOCK_PASSWORD, MOCK_USERNAME from .conftest import MOCK_DEVICE_NAME, MOCK_HOST, MOCK_PASSWORD, MOCK_USERNAME
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
DISCOVERY_INFO = ZeroconfServiceInfo( DISCOVERY_INFO = ZeroconfServiceInfo(
ip_address=ip_address("127.0.0.1"), ip_address=ip_address("192.168.1.161"),
ip_addresses=[ip_address("127.0.0.1")], ip_addresses=[ip_address("192.168.1.161")],
hostname="slzb-06.local.", hostname="slzb-06.local.",
name="mock_name", name="mock_name",
port=6638, port=6638,
@ -29,8 +30,8 @@ DISCOVERY_INFO = ZeroconfServiceInfo(
) )
DISCOVERY_INFO_LEGACY = ZeroconfServiceInfo( DISCOVERY_INFO_LEGACY = ZeroconfServiceInfo(
ip_address=ip_address("127.0.0.1"), ip_address=ip_address("192.168.1.161"),
ip_addresses=[ip_address("127.0.0.1")], ip_addresses=[ip_address("192.168.1.161")],
hostname="slzb-06.local.", hostname="slzb-06.local.",
name="mock_name", name="mock_name",
port=6638, port=6638,
@ -52,7 +53,7 @@ async def test_user_flow(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> No
result2 = await hass.config_entries.flow.async_configure( result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
{ {
CONF_HOST: MOCK_HOST, CONF_HOST: "slzb-06p7.local",
}, },
) )
@ -76,7 +77,7 @@ async def test_zeroconf_flow(
DOMAIN, context={"source": SOURCE_ZEROCONF}, data=DISCOVERY_INFO DOMAIN, context={"source": SOURCE_ZEROCONF}, data=DISCOVERY_INFO
) )
assert result["description_placeholders"] == {"host": MOCK_HOST} assert result["description_placeholders"] == {"host": MOCK_DEVICE_NAME}
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "confirm_discovery" assert result["step_id"] == "confirm_discovery"
@ -113,7 +114,7 @@ async def test_zeroconf_flow_auth(
DOMAIN, context={"source": SOURCE_ZEROCONF}, data=DISCOVERY_INFO DOMAIN, context={"source": SOURCE_ZEROCONF}, data=DISCOVERY_INFO
) )
assert result["description_placeholders"] == {"host": MOCK_HOST} assert result["description_placeholders"] == {"host": MOCK_DEVICE_NAME}
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "confirm_discovery" assert result["step_id"] == "confirm_discovery"
@ -167,7 +168,7 @@ async def test_zeroconf_unsupported_abort(
DOMAIN, context={"source": SOURCE_ZEROCONF}, data=DISCOVERY_INFO DOMAIN, context={"source": SOURCE_ZEROCONF}, data=DISCOVERY_INFO
) )
assert result["description_placeholders"] == {"host": MOCK_HOST} assert result["description_placeholders"] == {"host": MOCK_DEVICE_NAME}
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "confirm_discovery" assert result["step_id"] == "confirm_discovery"
@ -489,7 +490,7 @@ async def test_zeroconf_legacy_mac(
data=DISCOVERY_INFO_LEGACY, data=DISCOVERY_INFO_LEGACY,
) )
assert result["description_placeholders"] == {"host": MOCK_HOST} assert result["description_placeholders"] == {"host": MOCK_DEVICE_NAME}
result2 = await hass.config_entries.flow.async_configure( result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input={} result["flow_id"], user_input={}
@ -507,6 +508,76 @@ async def test_zeroconf_legacy_mac(
assert len(mock_smlight_client.get_info.mock_calls) == 3 assert len(mock_smlight_client.get_info.mock_calls) == 3
@pytest.mark.usefixtures("mock_smlight_client")
async def test_zeroconf_updates_host(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test zeroconf discovery updates host ip."""
mock_config_entry.add_to_hass(hass)
service_info = DISCOVERY_INFO
service_info.ip_address = ip_address("192.168.1.164")
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_ZEROCONF}, data=service_info
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "already_configured"
assert mock_config_entry.data[CONF_HOST] == "192.168.1.164"
@pytest.mark.usefixtures("mock_smlight_client")
async def test_dhcp_discovery_updates_host(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test dhcp discovery updates host ip."""
mock_config_entry.add_to_hass(hass)
service_info = DhcpServiceInfo(
ip="192.168.1.164",
hostname="slzb-06",
macaddress="aabbccddeeff",
)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_DHCP}, data=service_info
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "already_configured"
assert mock_config_entry.data[CONF_HOST] == "192.168.1.164"
@pytest.mark.usefixtures("mock_smlight_client")
async def test_dhcp_discovery_aborts(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test dhcp discovery updates host ip."""
mock_config_entry.add_to_hass(hass)
service_info = DhcpServiceInfo(
ip="192.168.1.161",
hostname="slzb-06",
macaddress="000000000000",
)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_DHCP}, data=service_info
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "already_configured"
assert mock_config_entry.data[CONF_HOST] == "192.168.1.161"
async def test_reauth_flow( async def test_reauth_flow(
hass: HomeAssistant, hass: HomeAssistant,
mock_smlight_client: MagicMock, mock_smlight_client: MagicMock,

@ -153,3 +153,25 @@ async def humidifier_config_entry(
await hass.async_block_till_done() await hass.async_block_till_done()
return entry return entry
@pytest.fixture(name="switch_old_id_config_entry")
async def switch_old_id_config_entry(
hass: HomeAssistant, requests_mock: requests_mock.Mocker, config
) -> MockConfigEntry:
"""Create a mock VeSync config entry for `switch` with the old unique ID approach."""
entry = MockConfigEntry(
title="VeSync",
domain=DOMAIN,
data=config[DOMAIN],
version=1,
minor_version=1,
)
entry.add_to_hass(hass)
wall_switch = "Wall Switch"
humidifer = "Humidifier 200s"
mock_multiple_device_responses(requests_mock, [wall_switch, humidifer])
return entry

@ -128,7 +128,7 @@
'sleep', 'sleep',
'manual', 'manual',
]), ]),
'mode': None, 'mode': 'humidity',
'night_light': True, 'night_light': True,
'pid': None, 'pid': None,
'speed': None, 'speed': None,
@ -160,6 +160,30 @@
# --- # ---
# name: test_async_get_device_diagnostics__single_fan # name: test_async_get_device_diagnostics__single_fan
dict({ dict({
'_config_dict': dict({
'features': list([
'air_quality',
]),
'levels': list([
1,
2,
]),
'models': list([
'LV-PUR131S',
'LV-RH131S',
]),
'modes': list([
'manual',
'auto',
'sleep',
'off',
]),
'module': 'VeSyncAir131',
}),
'_features': list([
'air_quality',
]),
'air_quality_feature': True,
'cid': 'abcdefghabcdefghabcdefghabcdefgh', 'cid': 'abcdefghabcdefghabcdefghabcdefgh',
'config': dict({ 'config': dict({
}), }),
@ -180,6 +204,7 @@
'device_region': 'US', 'device_region': 'US',
'device_status': 'unknown', 'device_status': 'unknown',
'device_type': 'LV-PUR131S', 'device_type': 'LV-PUR131S',
'enabled': True,
'extension': None, 'extension': None,
'home_assistant': dict({ 'home_assistant': dict({
'disabled': False, 'disabled': False,
@ -271,6 +296,12 @@
'mac_id': '**REDACTED**', 'mac_id': '**REDACTED**',
'manager': '**REDACTED**', 'manager': '**REDACTED**',
'mode': None, 'mode': None,
'modes': list([
'manual',
'auto',
'sleep',
'off',
]),
'pid': None, 'pid': None,
'speed': None, 'speed': None,
'sub_device_no': None, 'sub_device_no': None,

@ -367,7 +367,7 @@
'previous_unique_id': None, 'previous_unique_id': None,
'supported_features': 0, 'supported_features': 0,
'translation_key': None, 'translation_key': None,
'unique_id': 'outlet', 'unique_id': 'outlet-device_status',
'unit_of_measurement': None, 'unit_of_measurement': None,
}), }),
]) ])
@ -525,7 +525,7 @@
'previous_unique_id': None, 'previous_unique_id': None,
'supported_features': 0, 'supported_features': 0,
'translation_key': None, 'translation_key': None,
'unique_id': 'switch', 'unique_id': 'switch-device_status',
'unit_of_measurement': None, 'unit_of_measurement': None,
}), }),
]) ])

@ -10,6 +10,9 @@ from homeassistant.components.vesync.const import DOMAIN, VS_DEVICES, VS_MANAGER
from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from tests.common import MockConfigEntry
async def test_async_setup_entry__not_login( async def test_async_setup_entry__not_login(
@ -125,3 +128,55 @@ async def test_async_new_device_discovery(
assert manager.login.call_count == 1 assert manager.login.call_count == 1
assert hass.data[DOMAIN][VS_MANAGER] == manager assert hass.data[DOMAIN][VS_MANAGER] == manager
assert hass.data[DOMAIN][VS_DEVICES] == [fan, humidifier] assert hass.data[DOMAIN][VS_DEVICES] == [fan, humidifier]
async def test_migrate_config_entry(
hass: HomeAssistant,
switch_old_id_config_entry: MockConfigEntry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test migration of config entry. Only migrates switches to a new unique_id."""
switch: er.RegistryEntry = entity_registry.async_get_or_create(
domain="switch",
platform="vesync",
unique_id="switch",
config_entry=switch_old_id_config_entry,
suggested_object_id="switch",
)
humidifer: er.RegistryEntry = entity_registry.async_get_or_create(
domain="humidifer",
platform="vesync",
unique_id="humidifer",
config_entry=switch_old_id_config_entry,
suggested_object_id="humidifer",
)
assert switch.unique_id == "switch"
assert switch_old_id_config_entry.minor_version == 1
assert humidifer.unique_id == "humidifer"
await hass.config_entries.async_setup(switch_old_id_config_entry.entry_id)
await hass.async_block_till_done()
assert switch_old_id_config_entry.minor_version == 2
migrated_switch = entity_registry.async_get(switch.entity_id)
assert migrated_switch is not None
assert migrated_switch.entity_id.startswith("switch")
assert migrated_switch.unique_id == "switch-device_status"
# Confirm humidifer was not impacted
migrated_humidifer = entity_registry.async_get(humidifer.entity_id)
assert migrated_humidifer is not None
assert migrated_humidifer.unique_id == "humidifer"
# Assert that only one entity exists in the switch domain
switch_entities = [
e for e in entity_registry.entities.values() if e.domain == "switch"
]
assert len(switch_entities) == 1
humidifer_entities = [
e for e in entity_registry.entities.values() if e.domain == "humidifer"
]
assert len(humidifer_entities) == 1