Merge branch 'dev' into dev

This commit is contained in:
Jonathan Sady do Nascimento 2025-02-03 12:54:57 -03:00 committed by GitHub
commit 3ceb25adae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
84 changed files with 1802 additions and 1296 deletions

View File

@ -7,6 +7,6 @@
"integration_type": "service",
"iot_class": "cloud_polling",
"loggers": ["python_homeassistant_analytics"],
"requirements": ["python-homeassistant-analytics==0.8.1"],
"requirements": ["python-homeassistant-analytics==0.9.0"],
"single_config_entry": true
}

View File

@ -272,6 +272,7 @@ class AnthropicConversationEntity(
continue
tool_input = llm.ToolInput(
id=tool_call.id,
tool_name=tool_call.name,
tool_args=cast(dict[str, Any], tool_call.input),
)

View File

@ -9,6 +9,7 @@ import voluptuous as vol
from homeassistant.components import stt
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import chat_session
from homeassistant.helpers.typing import ConfigType
from .const import (
@ -114,24 +115,25 @@ async def async_pipeline_from_audio_stream(
Raises PipelineNotFound if no pipeline is found.
"""
pipeline_input = PipelineInput(
conversation_id=conversation_id,
device_id=device_id,
stt_metadata=stt_metadata,
stt_stream=stt_stream,
wake_word_phrase=wake_word_phrase,
conversation_extra_system_prompt=conversation_extra_system_prompt,
run=PipelineRun(
hass,
context=context,
pipeline=async_get_pipeline(hass, pipeline_id=pipeline_id),
start_stage=start_stage,
end_stage=end_stage,
event_callback=event_callback,
tts_audio_output=tts_audio_output,
wake_word_settings=wake_word_settings,
audio_settings=audio_settings or AudioSettings(),
),
)
await pipeline_input.validate()
await pipeline_input.execute()
with chat_session.async_get_chat_session(hass, conversation_id) as session:
pipeline_input = PipelineInput(
conversation_id=session.conversation_id,
device_id=device_id,
stt_metadata=stt_metadata,
stt_stream=stt_stream,
wake_word_phrase=wake_word_phrase,
conversation_extra_system_prompt=conversation_extra_system_prompt,
run=PipelineRun(
hass,
context=context,
pipeline=async_get_pipeline(hass, pipeline_id=pipeline_id),
start_stage=start_stage,
end_stage=end_stage,
event_callback=event_callback,
tts_audio_output=tts_audio_output,
wake_word_settings=wake_word_settings,
audio_settings=audio_settings or AudioSettings(),
),
)
await pipeline_input.validate()
await pipeline_input.execute()

View File

@ -624,7 +624,7 @@ class PipelineRun:
return
pipeline_data.pipeline_debug[self.pipeline.id][self.id].events.append(event)
def start(self, device_id: str | None) -> None:
def start(self, conversation_id: str, device_id: str | None) -> None:
"""Emit run start event."""
self._device_id = device_id
self._start_debug_recording_thread()
@ -632,6 +632,7 @@ class PipelineRun:
data = {
"pipeline": self.pipeline.id,
"language": self.language,
"conversation_id": conversation_id,
}
if self.runner_data is not None:
data["runner_data"] = self.runner_data
@ -1015,7 +1016,7 @@ class PipelineRun:
async def recognize_intent(
self,
intent_input: str,
conversation_id: str | None,
conversation_id: str,
device_id: str | None,
conversation_extra_system_prompt: str | None,
) -> str:
@ -1063,11 +1064,11 @@ class PipelineRun:
agent_id=self.intent_agent,
extra_system_prompt=conversation_extra_system_prompt,
)
processed_locally = self.intent_agent == conversation.HOME_ASSISTANT_AGENT
agent_id = user_input.agent_id
agent_id = self.intent_agent
processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT
intent_response: intent.IntentResponse | None = None
if user_input.agent_id != conversation.HOME_ASSISTANT_AGENT:
if not processed_locally:
# Sentence triggers override conversation agent
if (
trigger_response_text
@ -1105,13 +1106,13 @@ class PipelineRun:
speech: str = intent_response.speech.get("plain", {}).get(
"speech", ""
)
chat_log.async_add_message(
conversation.Content(
role="assistant",
async for _ in chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=agent_id,
content=speech,
)
)
):
pass
conversation_result = conversation.ConversationResult(
response=intent_response,
conversation_id=session.conversation_id,
@ -1409,12 +1410,15 @@ def _pipeline_debug_recording_thread_proc(
wav_writer.close()
@dataclass
@dataclass(kw_only=True)
class PipelineInput:
"""Input to a pipeline run."""
run: PipelineRun
conversation_id: str
"""Identifier for the conversation."""
stt_metadata: stt.SpeechMetadata | None = None
"""Metadata of stt input audio. Required when start_stage = stt."""
@ -1430,9 +1434,6 @@ class PipelineInput:
tts_input: str | None = None
"""Input for text-to-speech. Required when start_stage = tts."""
conversation_id: str | None = None
"""Identifier for the conversation."""
conversation_extra_system_prompt: str | None = None
"""Extra prompt information for the conversation agent."""
@ -1441,7 +1442,7 @@ class PipelineInput:
async def execute(self) -> None:
"""Run pipeline."""
self.run.start(device_id=self.device_id)
self.run.start(conversation_id=self.conversation_id, device_id=self.device_id)
current_stage: PipelineStage | None = self.run.start_stage
stt_audio_buffer: list[EnhancedAudioChunk] = []
stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None

View File

@ -14,7 +14,11 @@ import voluptuous as vol
from homeassistant.components import conversation, stt, tts, websocket_api
from homeassistant.const import ATTR_DEVICE_ID, ATTR_SECONDS, MATCH_ALL
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers import (
chat_session,
config_validation as cv,
entity_registry as er,
)
from homeassistant.util import language as language_util
from .const import (
@ -145,7 +149,6 @@ async def websocket_run(
# Arguments to PipelineInput
input_args: dict[str, Any] = {
"conversation_id": msg.get("conversation_id"),
"device_id": msg.get("device_id"),
}
@ -233,38 +236,42 @@ async def websocket_run(
audio_settings=audio_settings or AudioSettings(),
)
pipeline_input = PipelineInput(**input_args)
with chat_session.async_get_chat_session(
hass, msg.get("conversation_id")
) as session:
input_args["conversation_id"] = session.conversation_id
pipeline_input = PipelineInput(**input_args)
try:
await pipeline_input.validate()
except PipelineError as error:
# Report more specific error when possible
connection.send_error(msg["id"], error.code, error.message)
return
try:
await pipeline_input.validate()
except PipelineError as error:
# Report more specific error when possible
connection.send_error(msg["id"], error.code, error.message)
return
# Confirm subscription
connection.send_result(msg["id"])
# Confirm subscription
connection.send_result(msg["id"])
run_task = hass.async_create_task(pipeline_input.execute())
run_task = hass.async_create_task(pipeline_input.execute())
# Cancel pipeline if user unsubscribes
connection.subscriptions[msg["id"]] = run_task.cancel
# Cancel pipeline if user unsubscribes
connection.subscriptions[msg["id"]] = run_task.cancel
try:
# Task contains a timeout
async with asyncio.timeout(timeout):
await run_task
except TimeoutError:
pipeline_input.run.process_event(
PipelineEvent(
PipelineEventType.ERROR,
{"code": "timeout", "message": "Timeout running pipeline"},
try:
# Task contains a timeout
async with asyncio.timeout(timeout):
await run_task
except TimeoutError:
pipeline_input.run.process_event(
PipelineEvent(
PipelineEventType.ERROR,
{"code": "timeout", "message": "Timeout running pipeline"},
)
)
)
finally:
if unregister_handler is not None:
# Unregister binary handler
unregister_handler()
finally:
if unregister_handler is not None:
# Unregister binary handler
unregister_handler()
@callback

View File

@ -8,7 +8,7 @@ from dataclasses import dataclass
from enum import StrEnum
import logging
import time
from typing import Any, Final, Literal, final
from typing import Any, Literal, final
from homeassistant.components import conversation, media_source, stt, tts
from homeassistant.components.assist_pipeline import (
@ -28,14 +28,12 @@ from homeassistant.components.tts import (
)
from homeassistant.core import Context, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity
from homeassistant.helpers import chat_session, entity
from homeassistant.helpers.entity import EntityDescription
from .const import AssistSatelliteEntityFeature
from .errors import AssistSatelliteError, SatelliteBusyError
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
_LOGGER = logging.getLogger(__name__)
@ -114,7 +112,6 @@ class AssistSatelliteEntity(entity.Entity):
_attr_vad_sensitivity_entity_id: str | None = None
_conversation_id: str | None = None
_conversation_id_time: float | None = None
_run_has_tts: bool = False
_is_announcing = False
@ -260,6 +257,21 @@ class AssistSatelliteEntity(entity.Entity):
else:
self._extra_system_prompt = start_message or None
with (
# Not passing in a conversation ID will force a new one to be created
chat_session.async_get_chat_session(self.hass) as session,
conversation.async_get_chat_log(self.hass, session) as chat_log,
):
self._conversation_id = session.conversation_id
if start_message:
async for _tool_response in chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=self.entity_id, content=start_message
)
):
pass # no tool responses.
try:
await self.async_start_conversation(announcement)
finally:
@ -325,51 +337,52 @@ class AssistSatelliteEntity(entity.Entity):
assert self._context is not None
# Reset conversation id if necessary
if self._conversation_id_time and (
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
):
self._conversation_id = None
self._conversation_id_time = None
# Set entity state based on pipeline events
self._run_has_tts = False
assert self.platform.config_entry is not None
self._pipeline_task = self.platform.config_entry.async_create_background_task(
self.hass,
async_pipeline_from_audio_stream(
self.hass,
context=self._context,
event_callback=self._internal_on_pipeline_event,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_stream,
pipeline_id=self._resolve_pipeline(),
conversation_id=self._conversation_id,
device_id=device_id,
tts_audio_output=self.tts_options,
wake_word_phrase=wake_word_phrase,
audio_settings=AudioSettings(
silence_seconds=self._resolve_vad_sensitivity()
),
start_stage=start_stage,
end_stage=end_stage,
conversation_extra_system_prompt=extra_system_prompt,
),
f"{self.entity_id}_pipeline",
)
try:
await self._pipeline_task
finally:
self._pipeline_task = None
with chat_session.async_get_chat_session(
self.hass, self._conversation_id
) as session:
# Store the conversation ID. If it is no longer valid, get_chat_session will reset it
self._conversation_id = session.conversation_id
self._pipeline_task = (
self.platform.config_entry.async_create_background_task(
self.hass,
async_pipeline_from_audio_stream(
self.hass,
context=self._context,
event_callback=self._internal_on_pipeline_event,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_stream,
pipeline_id=self._resolve_pipeline(),
conversation_id=session.conversation_id,
device_id=device_id,
tts_audio_output=self.tts_options,
wake_word_phrase=wake_word_phrase,
audio_settings=AudioSettings(
silence_seconds=self._resolve_vad_sensitivity()
),
start_stage=start_stage,
end_stage=end_stage,
conversation_extra_system_prompt=extra_system_prompt,
),
f"{self.entity_id}_pipeline",
)
)
try:
await self._pipeline_task
finally:
self._pipeline_task = None
async def _cancel_running_pipeline(self) -> None:
"""Cancel the current pipeline if it's running."""
@ -393,11 +406,6 @@ class AssistSatelliteEntity(entity.Entity):
self._set_state(AssistSatelliteState.LISTENING)
elif event.type is PipelineEventType.INTENT_START:
self._set_state(AssistSatelliteState.PROCESSING)
elif event.type is PipelineEventType.INTENT_END:
assert event.data is not None
# Update timeout
self._conversation_id_time = time.monotonic()
self._conversation_id = event.data["intent_output"]["conversation_id"]
elif event.type is PipelineEventType.TTS_START:
# Wait until tts_response_finished is called to return to waiting state
self._run_has_tts = True

View File

@ -19,6 +19,8 @@ from .const import (
)
from .entity import BangOlufsenEntity
PARALLEL_UPDATES = 0
async def async_setup_entry(
hass: HomeAssistant,

View File

@ -20,7 +20,7 @@
"bluetooth-adapters==0.21.1",
"bluetooth-auto-recovery==1.4.2",
"bluetooth-data-tools==1.23.3",
"dbus-fast==2.31.0",
"dbus-fast==2.32.0",
"habluetooth==3.21.0"
]
}

View File

@ -30,6 +30,16 @@ from .agent_manager import (
async_get_agent,
get_agent_manager,
)
from .chat_log import (
AssistantContent,
ChatLog,
Content,
ConverseError,
SystemContent,
ToolResultContent,
UserContent,
async_get_chat_log,
)
from .const import (
ATTR_AGENT_ID,
ATTR_CONVERSATION_ID,
@ -48,13 +58,13 @@ from .default_agent import DefaultAgent, async_setup_default_agent
from .entity import ConversationEntity
from .http import async_setup as async_setup_conversation_http
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
from .session import ChatLog, Content, ConverseError, NativeContent, async_get_chat_log
from .trace import ConversationTraceEventType, async_conversation_trace_append
__all__ = [
"DOMAIN",
"HOME_ASSISTANT_AGENT",
"OLD_HOME_ASSISTANT_AGENT",
"AssistantContent",
"ChatLog",
"Content",
"ConversationEntity",
@ -63,7 +73,9 @@ __all__ = [
"ConversationResult",
"ConversationTraceEventType",
"ConverseError",
"NativeContent",
"SystemContent",
"ToolResultContent",
"UserContent",
"async_conversation_trace_append",
"async_converse",
"async_get_agent_info",

View File

@ -2,19 +2,16 @@
from __future__ import annotations
from collections.abc import Generator
from collections.abc import AsyncGenerator, Generator
from contextlib import contextmanager
from dataclasses import dataclass, field, replace
from datetime import datetime
import logging
from typing import Literal
import voluptuous as vol
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import chat_session, intent, llm, template
from homeassistant.util import dt as dt_util
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JsonObjectType
@ -31,7 +28,7 @@ LOGGER = logging.getLogger(__name__)
def async_get_chat_log(
hass: HomeAssistant,
session: chat_session.ChatSession,
user_input: ConversationInput,
user_input: ConversationInput | None = None,
) -> Generator[ChatLog]:
"""Return chat log for a specific chat session."""
all_history = hass.data.get(DATA_CHAT_HISTORY)
@ -42,9 +39,9 @@ def async_get_chat_log(
history = all_history.get(session.conversation_id)
if history:
history = replace(history, messages=history.messages.copy())
history = replace(history, content=history.content.copy())
else:
history = ChatLog(hass, session.conversation_id, user_input.agent_id)
history = ChatLog(hass, session.conversation_id)
@callback
def do_cleanup() -> None:
@ -53,22 +50,19 @@ def async_get_chat_log(
session.async_on_cleanup(do_cleanup)
message: Content = Content(
role="user",
agent_id=user_input.agent_id,
content=user_input.text,
)
history.async_add_message(message)
if user_input is not None:
history.async_add_user_content(UserContent(content=user_input.text))
last_message = history.content[-1]
yield history
if history.messages[-1] is message:
if history.content[-1] is last_message:
LOGGER.debug(
"History opened but no assistant message was added, ignoring update"
)
return
history.last_updated = dt_util.utcnow()
all_history[session.conversation_id] = history
@ -94,63 +88,94 @@ class ConverseError(HomeAssistantError):
)
@dataclass
class Content:
@dataclass(frozen=True)
class SystemContent:
"""Base class for chat messages."""
role: Literal["system", "assistant", "user"]
agent_id: str | None
role: str = field(init=False, default="system")
content: str
@dataclass(frozen=True)
class NativeContent[_NativeT]:
"""Native content."""
class UserContent:
"""Assistant content."""
role: str = field(init=False, default="native")
role: str = field(init=False, default="user")
content: str
@dataclass(frozen=True)
class AssistantContent:
"""Assistant content."""
role: str = field(init=False, default="assistant")
agent_id: str
content: _NativeT
content: str
tool_calls: list[llm.ToolInput] | None = None
@dataclass(frozen=True)
class ToolResultContent:
"""Tool result content."""
role: str = field(init=False, default="tool_result")
agent_id: str
tool_call_id: str
tool_name: str
tool_result: JsonObjectType
Content = SystemContent | UserContent | AssistantContent | ToolResultContent
@dataclass
class ChatLog[_NativeT]:
class ChatLog:
"""Class holding the chat history of a specific conversation."""
hass: HomeAssistant
conversation_id: str
agent_id: str | None
user_name: str | None = None
messages: list[Content | NativeContent[_NativeT]] = field(
default_factory=lambda: [Content(role="system", agent_id=None, content="")]
)
content: list[Content] = field(default_factory=lambda: [SystemContent(content="")])
extra_system_prompt: str | None = None
llm_api: llm.APIInstance | None = None
last_updated: datetime = field(default_factory=dt_util.utcnow)
@callback
def async_add_message(self, message: Content | NativeContent[_NativeT]) -> None:
"""Process intent."""
if message.role == "system":
raise ValueError("Cannot add system messages to history")
if message.role != "native" and self.messages[-1].role == message.role:
raise ValueError("Cannot add two assistant or user messages in a row")
def async_add_user_content(self, content: UserContent) -> None:
"""Add user content to the log."""
self.content.append(content)
self.messages.append(message)
async def async_add_assistant_content(
self, content: AssistantContent
) -> AsyncGenerator[ToolResultContent]:
"""Add assistant content."""
self.content.append(content)
@callback
def async_get_messages(
self, agent_id: str | None = None
) -> list[Content | NativeContent[_NativeT]]:
"""Get messages for a specific agent ID.
if content.tool_calls is None:
return
This will filter out any native message tied to other agent IDs.
It can still include assistant/user messages generated by other agents.
"""
return [
message
for message in self.messages
if message.role != "native" or message.agent_id == agent_id
]
if self.llm_api is None:
raise ValueError("No LLM API configured")
for tool_input in content.tool_calls:
LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
)
try:
tool_result = await self.llm_api.async_call_tool(tool_input)
except (HomeAssistantError, vol.Invalid) as e:
tool_result = {"error": type(e).__name__}
if str(e):
tool_result["error_text"] = str(e)
LOGGER.debug("Tool response: %s", tool_result)
response_content = ToolResultContent(
agent_id=content.agent_id,
tool_call_id=tool_input.id,
tool_name=tool_input.tool_name,
tool_result=tool_result,
)
self.content.append(response_content)
yield response_content
async def async_update_llm_data(
self,
@ -250,36 +275,16 @@ class ChatLog[_NativeT]:
prompt = "\n".join(prompt_parts)
self.llm_api = llm_api
self.user_name = user_name
self.extra_system_prompt = extra_system_prompt
self.messages[0] = Content(
role="system",
agent_id=user_input.agent_id,
content=prompt,
)
self.content[0] = SystemContent(content=prompt)
LOGGER.debug("Prompt: %s", self.messages)
LOGGER.debug("Prompt: %s", self.content)
LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None)
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL,
{
"messages": self.messages,
"messages": self.content,
"tools": self.llm_api.tools if self.llm_api else None,
},
)
async def async_call_tool(self, tool_input: llm.ToolInput) -> JsonObjectType:
"""Invoke LLM tool for the configured LLM API."""
if not self.llm_api:
raise ValueError("No LLM API configured")
LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args)
try:
tool_response = await self.llm_api.async_call_tool(tool_input)
except (HomeAssistantError, vol.Invalid) as e:
tool_response = {"error": type(e).__name__}
if str(e):
tool_response["error_text"] = str(e)
LOGGER.debug("Tool response: %s", tool_response)
return tool_response

View File

@ -55,6 +55,7 @@ from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.event import async_track_state_added_domain
from homeassistant.util.json import JsonObjectType, json_loads_object
from .chat_log import AssistantContent, async_get_chat_log
from .const import (
DATA_DEFAULT_ENTITY,
DEFAULT_EXPOSED_ATTRIBUTES,
@ -63,7 +64,6 @@ from .const import (
)
from .entity import ConversationEntity
from .models import ConversationInput, ConversationResult
from .session import Content, async_get_chat_log
from .trace import ConversationTraceEventType, async_conversation_trace_append
_LOGGER = logging.getLogger(__name__)
@ -379,13 +379,13 @@ class DefaultAgent(ConversationEntity):
)
speech: str = response.speech.get("plain", {}).get("speech", "")
chat_log.async_add_message(
Content(
role="assistant",
agent_id=user_input.agent_id,
async for _tool_result in chat_log.async_add_assistant_content(
AssistantContent(
agent_id=user_input.agent_id, # type: ignore[arg-type]
content=speech,
)
)
):
pass
return ConversationResult(
response=response, conversation_id=session.conversation_id

View File

@ -22,5 +22,5 @@
"integration_type": "device",
"iot_class": "local_polling",
"loggers": ["eq3btsmart"],
"requirements": ["eq3btsmart==1.4.1", "bleak-esphome==2.6.0"]
"requirements": ["eq3btsmart==1.4.1", "bleak-esphome==2.7.0"]
}

View File

@ -18,7 +18,7 @@
"requirements": [
"aioesphomeapi==29.0.0",
"esphome-dashboard-api==1.2.3",
"bleak-esphome==2.6.0"
"bleak-esphome==2.7.0"
],
"zeroconf": ["_esphomelib._tcp.local."]
}

View File

@ -28,14 +28,14 @@
"user": {
"description": "Enter the settings to connect to the camera.",
"data": {
"still_image_url": "Still Image URL (e.g. http://...)",
"stream_source": "Stream Source URL (e.g. rtsp://...)",
"still_image_url": "Still image URL (e.g. http://...)",
"stream_source": "Stream source URL (e.g. rtsp://...)",
"rtsp_transport": "RTSP transport protocol",
"authentication": "Authentication",
"limit_refetch_to_url_change": "Limit refetch to url change",
"limit_refetch_to_url_change": "Limit refetch to URL change",
"password": "[%key:common::config_flow::data::password%]",
"username": "[%key:common::config_flow::data::username%]",
"framerate": "Frame Rate (Hz)",
"framerate": "Frame rate (Hz)",
"verify_ssl": "[%key:common::config_flow::data::verify_ssl%]"
}
},

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import codecs
from collections.abc import Callable
from typing import Any, Literal
from typing import Any, Literal, cast
from google.api_core.exceptions import GoogleAPIError
import google.generativeai as genai
@ -149,15 +149,53 @@ def _escape_decode(value: Any) -> Any:
return value
def _chat_message_convert(
message: conversation.Content | conversation.NativeContent[genai_types.ContentDict],
) -> genai_types.ContentDict:
"""Convert any native chat message for this agent to the native format."""
if message.role == "native":
return message.content
def _create_google_tool_response_content(
content: list[conversation.ToolResultContent],
) -> protos.Content:
"""Create a Google tool response content."""
return protos.Content(
parts=[
protos.Part(
function_response=protos.FunctionResponse(
name=tool_result.tool_name, response=tool_result.tool_result
)
)
for tool_result in content
]
)
role = "model" if message.role == "assistant" else message.role
return {"role": role, "parts": message.content}
def _convert_content(
content: conversation.UserContent
| conversation.AssistantContent
| conversation.SystemContent,
) -> genai_types.ContentDict:
"""Convert HA content to Google content."""
if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr]
role = "model" if content.role == "assistant" else content.role
return {"role": role, "parts": content.content}
# Handle the Assistant content with tool calls.
assert type(content) is conversation.AssistantContent
parts = []
if content.content:
parts.append(protos.Part(text=content.content))
if content.tool_calls:
parts.extend(
[
protos.Part(
function_call=protos.FunctionCall(
name=tool_call.tool_name,
args=_escape_decode(tool_call.tool_args),
)
)
for tool_call in content.tool_calls
]
)
return protos.Content({"role": "model", "parts": parts})
class GoogleGenerativeAIConversationEntity(
@ -220,7 +258,7 @@ class GoogleGenerativeAIConversationEntity(
async def _async_handle_message(
self,
user_input: conversation.ConversationInput,
session: conversation.ChatLog[genai_types.ContentDict],
chat_log: conversation.ChatLog,
) -> conversation.ConversationResult:
"""Call the API."""
@ -228,7 +266,7 @@ class GoogleGenerativeAIConversationEntity(
options = self.entry.options
try:
await session.async_update_llm_data(
await chat_log.async_update_llm_data(
DOMAIN,
user_input,
options.get(CONF_LLM_HASS_API),
@ -238,10 +276,10 @@ class GoogleGenerativeAIConversationEntity(
return err.as_conversation_result()
tools: list[dict[str, Any]] | None = None
if session.llm_api:
if chat_log.llm_api:
tools = [
_format_tool(tool, session.llm_api.custom_serializer)
for tool in session.llm_api.tools
_format_tool(tool, chat_log.llm_api.custom_serializer)
for tool in chat_log.llm_api.tools
]
model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
@ -252,9 +290,36 @@ class GoogleGenerativeAIConversationEntity(
"gemini-1.0" not in model_name and "gemini-pro" not in model_name
)
prompt, *messages = [
_chat_message_convert(message) for message in session.async_get_messages()
]
prompt = chat_log.content[0].content # type: ignore[union-attr]
messages: list[genai_types.ContentDict] = []
# Google groups tool results, we do not. Group them before sending.
tool_results: list[conversation.ToolResultContent] = []
for chat_content in chat_log.content[1:]:
if chat_content.role == "tool_result":
# mypy doesn't like picking a type based on checking shared property 'role'
tool_results.append(cast(conversation.ToolResultContent, chat_content))
continue
if tool_results:
messages.append(_create_google_tool_response_content(tool_results))
tool_results.clear()
messages.append(
_convert_content(
cast(
conversation.UserContent
| conversation.SystemContent
| conversation.AssistantContent,
chat_content,
)
)
)
if tool_results:
messages.append(_create_google_tool_response_content(tool_results))
model = genai.GenerativeModel(
model_name=model_name,
generation_config={
@ -282,12 +347,12 @@ class GoogleGenerativeAIConversationEntity(
),
},
tools=tools or None,
system_instruction=prompt["parts"] if supports_system_instruction else None,
system_instruction=prompt if supports_system_instruction else None,
)
if not supports_system_instruction:
messages = [
{"role": "user", "parts": prompt["parts"]},
{"role": "user", "parts": prompt},
{"role": "model", "parts": "Ok"},
*messages,
]
@ -325,50 +390,40 @@ class GoogleGenerativeAIConversationEntity(
content = " ".join(
[part.text.strip() for part in chat_response.parts if part.text]
)
if content:
session.async_add_message(
conversation.Content(
role="assistant",
agent_id=user_input.agent_id,
content=content,
)
)
function_calls = [
part.function_call for part in chat_response.parts if part.function_call
]
if not function_calls or not session.llm_api:
break
tool_responses = []
for function_call in function_calls:
tool_call = MessageToDict(function_call._pb) # noqa: SLF001
tool_calls = []
for part in chat_response.parts:
if not part.function_call:
continue
tool_call = MessageToDict(part.function_call._pb) # noqa: SLF001
tool_name = tool_call["name"]
tool_args = _escape_decode(tool_call["args"])
tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
function_response = await session.async_call_tool(tool_input)
tool_responses.append(
protos.Part(
function_response=protos.FunctionResponse(
name=tool_name, response=function_response
tool_calls.append(
llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
)
chat_request = _create_google_tool_response_content(
[
tool_response
async for tool_response in chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=user_input.agent_id,
content=content,
tool_calls=tool_calls or None,
)
)
)
chat_request = protos.Content(parts=tool_responses)
session.async_add_message(
conversation.NativeContent(
agent_id=user_input.agent_id,
content=chat_request,
)
]
)
if not tool_calls:
break
response = intent.IntentResponse(language=user_input.language)
response.async_set_speech(
" ".join([part.text.strip() for part in chat_response.parts if part.text])
)
return conversation.ConversationResult(
response=response, conversation_id=session.conversation_id
response=response, conversation_id=chat_log.conversation_id
)
async def _async_entry_update_listener(

View File

@ -517,17 +517,22 @@ class SupervisorBackupReaderWriter(BackupReaderWriter):
raise HomeAssistantError(message) from err
restore_complete = asyncio.Event()
restore_errors: list[dict[str, str]] = []
@callback
def on_job_progress(data: Mapping[str, Any]) -> None:
"""Handle backup restore progress."""
if data.get("done") is True:
restore_complete.set()
restore_errors.extend(data.get("errors", []))
unsub = self._async_listen_job_events(job.job_id, on_job_progress)
try:
await self._get_job_state(job.job_id, on_job_progress)
await restore_complete.wait()
if restore_errors:
# We should add more specific error handling here in the future
raise BackupReaderWriterError(f"Restore failed: {restore_errors}")
finally:
unsub()
@ -554,11 +559,23 @@ class SupervisorBackupReaderWriter(BackupReaderWriter):
)
return
on_progress(
RestoreBackupEvent(
reason="", stage=None, state=RestoreBackupState.COMPLETED
restore_errors = data.get("errors", [])
if restore_errors:
_LOGGER.warning("Restore backup failed: %s", restore_errors)
# We should add more specific error handling here in the future
on_progress(
RestoreBackupEvent(
reason="unknown_error",
stage=None,
state=RestoreBackupState.FAILED,
)
)
else:
on_progress(
RestoreBackupEvent(
reason="", stage=None, state=RestoreBackupState.COMPLETED
)
)
)
on_progress(IdleEvent())
unsub()

View File

@ -1,6 +1,7 @@
"""Constants for the homee integration."""
from homeassistant.const import (
DEGREE,
LIGHT_LUX,
PERCENTAGE,
REVOLUTIONS_PER_MINUTE,
@ -32,6 +33,7 @@ HOMEE_UNIT_TO_HA_UNIT = {
"W": UnitOfPower.WATT,
"m/s": UnitOfSpeed.METERS_PER_SECOND,
"km/h": UnitOfSpeed.KILOMETERS_PER_HOUR,
"°": DEGREE,
"°F": UnitOfTemperature.FAHRENHEIT,
"°C": UnitOfTemperature.CELSIUS,
"K": UnitOfTemperature.KELVIN,
@ -51,7 +53,7 @@ OPEN_CLOSE_MAP_REVERSED = {
0.0: "closed",
1.0: "open",
2.0: "partial",
3.0: "cosing",
3.0: "closing",
4.0: "opening",
}
WINDOW_MAP = {

View File

@ -8,5 +8,5 @@
"documentation": "https://www.home-assistant.io/integrations/lcn",
"iot_class": "local_push",
"loggers": ["pypck"],
"requirements": ["pypck==0.8.3", "lcn-frontend==0.2.3"]
"requirements": ["pypck==0.8.5", "lcn-frontend==0.2.3"]
}

View File

@ -30,7 +30,7 @@ class MotionMountEntity(Entity):
self.config_entry = config_entry
# We store the pin, as we might need it during reconnect
self.pin = config_entry.data[CONF_PIN]
self.pin = config_entry.data.get(CONF_PIN)
mac = format_mac(mm.mac.hex())

View File

@ -5,34 +5,33 @@ from __future__ import annotations
from dataclasses import dataclass
import logging
from kiota_abstractions.api_error import APIError
from kiota_abstractions.authentication import BaseBearerTokenAuthenticationProvider
from msgraph import GraphRequestAdapter, GraphServiceClient
from msgraph.generated.drives.item.items.items_request_builder import (
ItemsRequestBuilder,
from onedrive_personal_sdk import OneDriveClient
from onedrive_personal_sdk.exceptions import (
AuthenticationError,
HttpRequestException,
OneDriveException,
)
from msgraph.generated.models.drive_item import DriveItem
from msgraph.generated.models.folder import Folder
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.config_entry_oauth2_flow import (
OAuth2Session,
async_get_config_entry_implementation,
)
from homeassistant.helpers.httpx_client import create_async_httpx_client
from homeassistant.helpers.instance_id import async_get as async_get_instance_id
from .api import OneDriveConfigEntryAccessTokenProvider
from .const import DATA_BACKUP_AGENT_LISTENERS, DOMAIN, OAUTH_SCOPES
from .const import DATA_BACKUP_AGENT_LISTENERS, DOMAIN
@dataclass
class OneDriveRuntimeData:
"""Runtime data for the OneDrive integration."""
items: ItemsRequestBuilder
client: OneDriveClient
token_provider: OneDriveConfigEntryAccessTokenProvider
backup_folder_id: str
@ -47,29 +46,18 @@ async def async_setup_entry(hass: HomeAssistant, entry: OneDriveConfigEntry) ->
session = OAuth2Session(hass, entry, implementation)
auth_provider = BaseBearerTokenAuthenticationProvider(
access_token_provider=OneDriveConfigEntryAccessTokenProvider(session)
)
adapter = GraphRequestAdapter(
auth_provider=auth_provider,
client=create_async_httpx_client(hass, follow_redirects=True),
)
token_provider = OneDriveConfigEntryAccessTokenProvider(session)
graph_client = GraphServiceClient(
request_adapter=adapter,
scopes=OAUTH_SCOPES,
)
assert entry.unique_id
drive_item = graph_client.drives.by_drive_id(entry.unique_id)
client = OneDriveClient(token_provider, async_get_clientsession(hass))
# get approot, will be created automatically if it does not exist
try:
approot = await drive_item.special.by_drive_item_id("approot").get()
except APIError as err:
if err.response_status_code == 403:
raise ConfigEntryAuthFailed(
translation_domain=DOMAIN, translation_key="authentication_failed"
) from err
approot = await client.get_approot()
except AuthenticationError as err:
raise ConfigEntryAuthFailed(
translation_domain=DOMAIN, translation_key="authentication_failed"
) from err
except (HttpRequestException, OneDriveException, TimeoutError) as err:
_LOGGER.debug("Failed to get approot", exc_info=True)
raise ConfigEntryNotReady(
translation_domain=DOMAIN,
@ -77,24 +65,24 @@ async def async_setup_entry(hass: HomeAssistant, entry: OneDriveConfigEntry) ->
translation_placeholders={"folder": "approot"},
) from err
if approot is None or not approot.id:
_LOGGER.debug("Failed to get approot, was None")
instance_id = await async_get_instance_id(hass)
backup_folder_name = f"backups_{instance_id[:8]}"
try:
backup_folder = await client.create_folder(
parent_id=approot.id, name=backup_folder_name
)
except (HttpRequestException, OneDriveException, TimeoutError) as err:
_LOGGER.debug("Failed to create backup folder", exc_info=True)
raise ConfigEntryNotReady(
translation_domain=DOMAIN,
translation_key="failed_to_get_folder",
translation_placeholders={"folder": "approot"},
)
instance_id = await async_get_instance_id(hass)
backup_folder_id = await _async_create_folder_if_not_exists(
items=drive_item.items,
base_folder_id=approot.id,
folder=f"backups_{instance_id[:8]}",
)
translation_placeholders={"folder": backup_folder_name},
) from err
entry.runtime_data = OneDriveRuntimeData(
items=drive_item.items,
backup_folder_id=backup_folder_id,
client=client,
token_provider=token_provider,
backup_folder_id=backup_folder.id,
)
_async_notify_backup_listeners_soon(hass)
@ -116,54 +104,3 @@ def _async_notify_backup_listeners(hass: HomeAssistant) -> None:
@callback
def _async_notify_backup_listeners_soon(hass: HomeAssistant) -> None:
hass.loop.call_soon(_async_notify_backup_listeners, hass)
async def _async_create_folder_if_not_exists(
items: ItemsRequestBuilder,
base_folder_id: str,
folder: str,
) -> str:
"""Check if a folder exists and create it if it does not exist."""
folder_item: DriveItem | None = None
try:
folder_item = await items.by_drive_item_id(f"{base_folder_id}:/{folder}:").get()
except APIError as err:
if err.response_status_code != 404:
_LOGGER.debug("Failed to get folder %s", folder, exc_info=True)
raise ConfigEntryNotReady(
translation_domain=DOMAIN,
translation_key="failed_to_get_folder",
translation_placeholders={"folder": folder},
) from err
# is 404 not found, create folder
_LOGGER.debug("Creating folder %s", folder)
request_body = DriveItem(
name=folder,
folder=Folder(),
additional_data={
"@microsoft_graph_conflict_behavior": "fail",
},
)
try:
folder_item = await items.by_drive_item_id(base_folder_id).children.post(
request_body
)
except APIError as create_err:
_LOGGER.debug("Failed to create folder %s", folder, exc_info=True)
raise ConfigEntryNotReady(
translation_domain=DOMAIN,
translation_key="failed_to_create_folder",
translation_placeholders={"folder": folder},
) from create_err
_LOGGER.debug("Created folder %s", folder)
else:
_LOGGER.debug("Found folder %s", folder)
if folder_item is None or not folder_item.id:
_LOGGER.debug("Failed to get folder %s, was None", folder)
raise ConfigEntryNotReady(
translation_domain=DOMAIN,
translation_key="failed_to_get_folder",
translation_placeholders={"folder": folder},
)
return folder_item.id

View File

@ -1,28 +1,14 @@
"""API for OneDrive bound to Home Assistant OAuth."""
from typing import Any, cast
from typing import cast
from kiota_abstractions.authentication import AccessTokenProvider, AllowedHostsValidator
from onedrive_personal_sdk import TokenProvider
from homeassistant.const import CONF_ACCESS_TOKEN
from homeassistant.helpers import config_entry_oauth2_flow
class OneDriveAccessTokenProvider(AccessTokenProvider):
"""Provide OneDrive authentication tied to an OAuth2 based config entry."""
def __init__(self) -> None:
"""Initialize OneDrive auth."""
super().__init__()
# currently allowing all hosts
self._allowed_hosts_validator = AllowedHostsValidator(allowed_hosts=[])
def get_allowed_hosts_validator(self) -> AllowedHostsValidator:
"""Retrieve the allowed hosts validator."""
return self._allowed_hosts_validator
class OneDriveConfigFlowAccessTokenProvider(OneDriveAccessTokenProvider):
class OneDriveConfigFlowAccessTokenProvider(TokenProvider):
"""Provide OneDrive authentication tied to an OAuth2 based config entry."""
def __init__(self, token: str) -> None:
@ -30,14 +16,12 @@ class OneDriveConfigFlowAccessTokenProvider(OneDriveAccessTokenProvider):
super().__init__()
self._token = token
async def get_authorization_token( # pylint: disable=dangerous-default-value
self, uri: str, additional_authentication_context: dict[str, Any] = {}
) -> str:
"""Return a valid authorization token."""
def async_get_access_token(self) -> str:
"""Return a valid access token."""
return self._token
class OneDriveConfigEntryAccessTokenProvider(OneDriveAccessTokenProvider):
class OneDriveConfigEntryAccessTokenProvider(TokenProvider):
"""Provide OneDrive authentication tied to an OAuth2 based config entry."""
def __init__(self, oauth_session: config_entry_oauth2_flow.OAuth2Session) -> None:
@ -45,9 +29,6 @@ class OneDriveConfigEntryAccessTokenProvider(OneDriveAccessTokenProvider):
super().__init__()
self._oauth_session = oauth_session
async def get_authorization_token( # pylint: disable=dangerous-default-value
self, uri: str, additional_authentication_context: dict[str, Any] = {}
) -> str:
"""Return a valid authorization token."""
await self._oauth_session.async_ensure_token_valid()
def async_get_access_token(self) -> str:
"""Return a valid access token."""
return cast(str, self._oauth_session.token[CONF_ACCESS_TOKEN])

View File

@ -2,37 +2,22 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator, Callable, Coroutine
from functools import wraps
import html
import json
import logging
from typing import Any, Concatenate, cast
from typing import Any, Concatenate
from httpx import Response, TimeoutException
from kiota_abstractions.api_error import APIError
from kiota_abstractions.authentication import AnonymousAuthenticationProvider
from kiota_abstractions.headers_collection import HeadersCollection
from kiota_abstractions.method import Method
from kiota_abstractions.native_response_handler import NativeResponseHandler
from kiota_abstractions.request_information import RequestInformation
from kiota_http.middleware.options import ResponseHandlerOption
from msgraph import GraphRequestAdapter
from msgraph.generated.drives.item.items.item.content.content_request_builder import (
ContentRequestBuilder,
from aiohttp import ClientTimeout
from onedrive_personal_sdk.clients.large_file_upload import LargeFileUploadClient
from onedrive_personal_sdk.exceptions import (
AuthenticationError,
HashMismatchError,
OneDriveException,
)
from msgraph.generated.drives.item.items.item.create_upload_session.create_upload_session_post_request_body import (
CreateUploadSessionPostRequestBody,
)
from msgraph.generated.drives.item.items.item.drive_item_item_request_builder import (
DriveItemItemRequestBuilder,
)
from msgraph.generated.models.drive_item import DriveItem
from msgraph.generated.models.drive_item_uploadable_properties import (
DriveItemUploadableProperties,
)
from msgraph_core.models import LargeFileUploadSession
from onedrive_personal_sdk.models.items import File, Folder, ItemUpdate
from onedrive_personal_sdk.models.upload import FileInfo
from homeassistant.components.backup import (
AgentBackup,
@ -41,14 +26,14 @@ from homeassistant.components.backup import (
suggested_filename,
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.httpx_client import get_async_client
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from . import OneDriveConfigEntry
from .const import DATA_BACKUP_AGENT_LISTENERS, DOMAIN
_LOGGER = logging.getLogger(__name__)
UPLOAD_CHUNK_SIZE = 16 * 320 * 1024 # 5.2MB
MAX_RETRIES = 5
TIMEOUT = ClientTimeout(connect=10, total=43200) # 12 hours
async def async_get_backup_agents(
@ -92,18 +77,18 @@ def handle_backup_errors[_R, **P](
) -> _R:
try:
return await func(self, *args, **kwargs)
except APIError as err:
if err.response_status_code == 403:
self._entry.async_start_reauth(self._hass)
except AuthenticationError as err:
self._entry.async_start_reauth(self._hass)
raise BackupAgentError("Authentication error") from err
except OneDriveException as err:
_LOGGER.error(
"Error during backup in %s: Status %s, message %s",
"Error during backup in %s:, message %s",
func.__name__,
err.response_status_code,
err.message,
err,
)
_LOGGER.debug("Full error: %s", err, exc_info=True)
raise BackupAgentError("Backup operation failed") from err
except TimeoutException as err:
except TimeoutError as err:
_LOGGER.error(
"Error during backup in %s: Timeout",
func.__name__,
@ -123,7 +108,8 @@ class OneDriveBackupAgent(BackupAgent):
super().__init__()
self._hass = hass
self._entry = entry
self._items = entry.runtime_data.items
self._client = entry.runtime_data.client
self._token_provider = entry.runtime_data.token_provider
self._folder_id = entry.runtime_data.backup_folder_id
self.name = entry.title
assert entry.unique_id
@ -134,24 +120,12 @@ class OneDriveBackupAgent(BackupAgent):
self, backup_id: str, **kwargs: Any
) -> AsyncIterator[bytes]:
"""Download a backup file."""
# this forces the query to return a raw httpx response, but breaks typing
backup = await self._find_item_by_backup_id(backup_id)
if backup is None or backup.id is None:
item = await self._find_item_by_backup_id(backup_id)
if item is None:
raise BackupAgentError("Backup not found")
request_config = (
ContentRequestBuilder.ContentRequestBuilderGetRequestConfiguration(
options=[ResponseHandlerOption(NativeResponseHandler())],
)
)
response = cast(
Response,
await self._items.by_drive_item_id(backup.id).content.get(
request_configuration=request_config
),
)
return response.aiter_bytes(chunk_size=1024)
stream = await self._client.download_drive_item(item.id, timeout=TIMEOUT)
return stream.iter_chunked(1024)
@handle_backup_errors
async def async_upload_backup(
@ -163,27 +137,20 @@ class OneDriveBackupAgent(BackupAgent):
) -> None:
"""Upload a backup."""
# upload file in chunks to support large files
upload_session_request_body = CreateUploadSessionPostRequestBody(
item=DriveItemUploadableProperties(
additional_data={
"@microsoft.graph.conflictBehavior": "fail",
},
file = FileInfo(
suggested_filename(backup),
backup.size,
self._folder_id,
await open_stream(),
)
try:
item = await LargeFileUploadClient.upload(
self._token_provider, file, session=async_get_clientsession(self._hass)
)
)
file_item = self._get_backup_file_item(suggested_filename(backup))
upload_session = await file_item.create_upload_session.post(
upload_session_request_body
)
if upload_session is None or upload_session.upload_url is None:
except HashMismatchError as err:
raise BackupAgentError(
translation_domain=DOMAIN, translation_key="backup_no_upload_session"
)
await self._upload_file(
upload_session.upload_url, await open_stream(), backup.size
)
"Hash validation failed, backup file might be corrupt"
) from err
# store metadata in description
backup_dict = backup.as_dict()
@ -191,7 +158,10 @@ class OneDriveBackupAgent(BackupAgent):
description = json.dumps(backup_dict)
_LOGGER.debug("Creating metadata: %s", description)
await file_item.patch(DriveItem(description=description))
await self._client.update_drive_item(
path_or_id=item.id,
data=ItemUpdate(description=description),
)
@handle_backup_errors
async def async_delete_backup(
@ -200,35 +170,31 @@ class OneDriveBackupAgent(BackupAgent):
**kwargs: Any,
) -> None:
"""Delete a backup file."""
backup = await self._find_item_by_backup_id(backup_id)
if backup is None or backup.id is None:
item = await self._find_item_by_backup_id(backup_id)
if item is None:
return
await self._items.by_drive_item_id(backup.id).delete()
await self._client.delete_drive_item(item.id)
@handle_backup_errors
async def async_list_backups(self, **kwargs: Any) -> list[AgentBackup]:
"""List backups."""
backups: list[AgentBackup] = []
items = await self._items.by_drive_item_id(f"{self._folder_id}").children.get()
if items and (values := items.value):
for item in values:
if (description := item.description) is None:
continue
if "homeassistant_version" in description:
backups.append(self._backup_from_description(description))
return backups
return [
self._backup_from_description(item.description)
for item in await self._client.list_drive_items(self._folder_id)
if item.description and "homeassistant_version" in item.description
]
@handle_backup_errors
async def async_get_backup(
self, backup_id: str, **kwargs: Any
) -> AgentBackup | None:
"""Return a backup."""
backup = await self._find_item_by_backup_id(backup_id)
if backup is None:
return None
assert backup.description # already checked in _find_item_by_backup_id
return self._backup_from_description(backup.description)
item = await self._find_item_by_backup_id(backup_id)
return (
self._backup_from_description(item.description)
if item and item.description
else None
)
def _backup_from_description(self, description: str) -> AgentBackup:
"""Create a backup object from a description."""
@ -237,91 +203,13 @@ class OneDriveBackupAgent(BackupAgent):
) # OneDrive encodes the description on save automatically
return AgentBackup.from_dict(json.loads(description))
async def _find_item_by_backup_id(self, backup_id: str) -> DriveItem | None:
"""Find a backup item by its backup ID."""
items = await self._items.by_drive_item_id(f"{self._folder_id}").children.get()
if items and (values := items.value):
for item in values:
if (description := item.description) is None:
continue
if backup_id in description:
return item
return None
def _get_backup_file_item(self, backup_id: str) -> DriveItemItemRequestBuilder:
return self._items.by_drive_item_id(f"{self._folder_id}:/{backup_id}:")
async def _upload_file(
self, upload_url: str, stream: AsyncIterator[bytes], total_size: int
) -> None:
"""Use custom large file upload; SDK does not support stream."""
adapter = GraphRequestAdapter(
auth_provider=AnonymousAuthenticationProvider(),
client=get_async_client(self._hass),
async def _find_item_by_backup_id(self, backup_id: str) -> File | Folder | None:
"""Find an item by backup ID."""
return next(
(
item
for item in await self._client.list_drive_items(self._folder_id)
if item.description and backup_id in item.description
),
None,
)
async def async_upload(
start: int, end: int, chunk_data: bytes
) -> LargeFileUploadSession:
info = RequestInformation()
info.url = upload_url
info.http_method = Method.PUT
info.headers = HeadersCollection()
info.headers.try_add("Content-Range", f"bytes {start}-{end}/{total_size}")
info.headers.try_add("Content-Length", str(len(chunk_data)))
info.headers.try_add("Content-Type", "application/octet-stream")
_LOGGER.debug(info.headers.get_all())
info.set_stream_content(chunk_data)
result = await adapter.send_async(info, LargeFileUploadSession, {})
_LOGGER.debug("Next expected range: %s", result.next_expected_ranges)
return result
start = 0
buffer: list[bytes] = []
buffer_size = 0
retries = 0
async for chunk in stream:
buffer.append(chunk)
buffer_size += len(chunk)
if buffer_size >= UPLOAD_CHUNK_SIZE:
chunk_data = b"".join(buffer)
uploaded_chunks = 0
while (
buffer_size > UPLOAD_CHUNK_SIZE
): # Loop in case the buffer is >= UPLOAD_CHUNK_SIZE * 2
slice_start = uploaded_chunks * UPLOAD_CHUNK_SIZE
try:
await async_upload(
start,
start + UPLOAD_CHUNK_SIZE - 1,
chunk_data[slice_start : slice_start + UPLOAD_CHUNK_SIZE],
)
except APIError as err:
if (
err.response_status_code and err.response_status_code < 500
): # no retry on 4xx errors
raise
if retries < MAX_RETRIES:
await asyncio.sleep(2**retries)
retries += 1
continue
raise
except TimeoutException:
if retries < MAX_RETRIES:
retries += 1
continue
raise
retries = 0
start += UPLOAD_CHUNK_SIZE
uploaded_chunks += 1
buffer_size -= UPLOAD_CHUNK_SIZE
buffer = [chunk_data[UPLOAD_CHUNK_SIZE * uploaded_chunks :]]
# upload the remaining bytes
if buffer:
_LOGGER.debug("Last chunk")
chunk_data = b"".join(buffer)
await async_upload(start, start + len(chunk_data) - 1, chunk_data)

View File

@ -4,16 +4,13 @@ from collections.abc import Mapping
import logging
from typing import Any, cast
from kiota_abstractions.api_error import APIError
from kiota_abstractions.authentication import BaseBearerTokenAuthenticationProvider
from kiota_abstractions.method import Method
from kiota_abstractions.request_information import RequestInformation
from msgraph import GraphRequestAdapter, GraphServiceClient
from onedrive_personal_sdk.clients.client import OneDriveClient
from onedrive_personal_sdk.exceptions import OneDriveException
from homeassistant.config_entries import SOURCE_REAUTH, ConfigFlowResult
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_TOKEN
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.config_entry_oauth2_flow import AbstractOAuth2FlowHandler
from homeassistant.helpers.httpx_client import get_async_client
from .api import OneDriveConfigFlowAccessTokenProvider
from .const import DOMAIN, OAUTH_SCOPES
@ -39,48 +36,24 @@ class OneDriveConfigFlow(AbstractOAuth2FlowHandler, domain=DOMAIN):
data: dict[str, Any],
) -> ConfigFlowResult:
"""Handle the initial step."""
auth_provider = BaseBearerTokenAuthenticationProvider(
access_token_provider=OneDriveConfigFlowAccessTokenProvider(
cast(str, data[CONF_TOKEN][CONF_ACCESS_TOKEN])
)
)
adapter = GraphRequestAdapter(
auth_provider=auth_provider,
client=get_async_client(self.hass),
token_provider = OneDriveConfigFlowAccessTokenProvider(
cast(str, data[CONF_TOKEN][CONF_ACCESS_TOKEN])
)
graph_client = GraphServiceClient(
request_adapter=adapter,
scopes=OAUTH_SCOPES,
graph_client = OneDriveClient(
token_provider, async_get_clientsession(self.hass)
)
# need to get adapter from client, as client changes it
request_adapter = cast(GraphRequestAdapter, graph_client.request_adapter)
request_info = RequestInformation(
method=Method.GET,
url_template="{+baseurl}/me/drive/special/approot",
path_parameters={},
)
parent_span = request_adapter.start_tracing_span(request_info, "get_approot")
# get the OneDrive id
# use low level methods, to avoid files.read permissions
# which would be required by drives.me.get()
try:
response = await request_adapter.get_http_response_message(
request_info=request_info, parent_span=parent_span
)
except APIError:
approot = await graph_client.get_approot()
except OneDriveException:
self.logger.exception("Failed to connect to OneDrive")
return self.async_abort(reason="connection_error")
except Exception:
self.logger.exception("Unknown error")
return self.async_abort(reason="unknown")
drive: dict = response.json()
await self.async_set_unique_id(drive["parentReference"]["driveId"])
await self.async_set_unique_id(approot.parent_reference.drive_id)
if self.source == SOURCE_REAUTH:
reauth_entry = self._get_reauth_entry()
@ -94,10 +67,11 @@ class OneDriveConfigFlow(AbstractOAuth2FlowHandler, domain=DOMAIN):
self._abort_if_unique_id_configured()
user = drive.get("createdBy", {}).get("user", {}).get("displayName")
title = f"{user}'s OneDrive" if user else "OneDrive"
title = (
f"{approot.created_by.user.display_name}'s OneDrive"
if approot.created_by.user and approot.created_by.user.display_name
else "OneDrive"
)
return self.async_create_entry(title=title, data=data)
async def async_step_reauth(

View File

@ -7,7 +7,7 @@
"documentation": "https://www.home-assistant.io/integrations/onedrive",
"integration_type": "service",
"iot_class": "cloud_polling",
"loggers": ["msgraph", "msgraph-core", "kiota"],
"loggers": ["onedrive_personal_sdk"],
"quality_scale": "bronze",
"requirements": ["msgraph-sdk==1.16.0"]
"requirements": ["onedrive-personal-sdk==0.0.1"]
}

View File

@ -23,31 +23,18 @@
"connection_error": "Failed to connect to OneDrive.",
"wrong_drive": "New account does not contain previously configured OneDrive.",
"unknown": "[%key:common::config_flow::error::unknown%]",
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]",
"failed_to_create_folder": "Failed to create backup folder"
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]"
},
"create_entry": {
"default": "[%key:common::config_flow::create_entry::authenticated%]"
}
},
"exceptions": {
"backup_not_found": {
"message": "Backup not found"
},
"backup_no_content": {
"message": "Backup has no content"
},
"backup_no_upload_session": {
"message": "Failed to start backup upload"
},
"authentication_failed": {
"message": "Authentication failed"
},
"failed_to_get_folder": {
"message": "Failed to get {folder} folder"
},
"failed_to_create_folder": {
"message": "Failed to create {folder} folder"
}
}
}

View File

@ -24,6 +24,7 @@ from homeassistant.helpers.selector import (
SelectOptionDict,
SelectSelector,
SelectSelectorConfig,
SelectSelectorMode,
TemplateSelector,
)
from homeassistant.helpers.typing import VolDictType
@ -32,14 +33,17 @@ from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_REASONING_EFFORT,
CONF_RECOMMENDED,
CONF_TEMPERATURE,
CONF_TOP_P,
DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_REASONING_EFFORT,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_P,
UNSUPPORTED_MODELS,
)
_LOGGER = logging.getLogger(__name__)
@ -124,26 +128,32 @@ class OpenAIOptionsFlow(OptionsFlow):
) -> ConfigFlowResult:
"""Manage the options."""
options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options
errors: dict[str, str] = {}
if user_input is not None:
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
if user_input[CONF_LLM_HASS_API] == "none":
user_input.pop(CONF_LLM_HASS_API)
return self.async_create_entry(title="", data=user_input)
# Re-render the options again, now with the recommended options shown/hidden
self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
if user_input.get(CONF_CHAT_MODEL) in UNSUPPORTED_MODELS:
errors[CONF_CHAT_MODEL] = "model_not_supported"
else:
return self.async_create_entry(title="", data=user_input)
else:
# Re-render the options again, now with the recommended options shown/hidden
self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
options = {
CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
CONF_PROMPT: user_input[CONF_PROMPT],
CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
}
options = {
CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
CONF_PROMPT: user_input[CONF_PROMPT],
CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
}
schema = openai_config_option_schema(self.hass, options)
return self.async_show_form(
step_id="init",
data_schema=vol.Schema(schema),
errors=errors,
)
@ -210,6 +220,17 @@ def openai_config_option_schema(
description={"suggested_value": options.get(CONF_TEMPERATURE)},
default=RECOMMENDED_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
vol.Optional(
CONF_REASONING_EFFORT,
description={"suggested_value": options.get(CONF_REASONING_EFFORT)},
default=RECOMMENDED_REASONING_EFFORT,
): SelectSelector(
SelectSelectorConfig(
options=["low", "medium", "high"],
translation_key="reasoning_effort",
mode=SelectSelectorMode.DROPDOWN,
)
),
}
)
return schema

View File

@ -15,3 +15,17 @@ CONF_TOP_P = "top_p"
RECOMMENDED_TOP_P = 1.0
CONF_TEMPERATURE = "temperature"
RECOMMENDED_TEMPERATURE = 1.0
CONF_REASONING_EFFORT = "reasoning_effort"
RECOMMENDED_REASONING_EFFORT = "low"
UNSUPPORTED_MODELS = [
"o1-mini",
"o1-mini-2024-09-12",
"o1-preview",
"o1-preview-2024-09-12",
"gpt-4o-realtime-preview",
"gpt-4o-realtime-preview-2024-12-17",
"gpt-4o-realtime-preview-2024-10-01",
"gpt-4o-mini-realtime-preview",
"gpt-4o-mini-realtime-preview-2024-12-17",
]

View File

@ -31,12 +31,14 @@ from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_REASONING_EFFORT,
CONF_TEMPERATURE,
CONF_TOP_P,
DOMAIN,
LOGGER,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_REASONING_EFFORT,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_P,
)
@ -68,7 +70,9 @@ def _format_tool(
return ChatCompletionToolParam(type="function", function=tool_spec)
def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
def _convert_message_to_param(
message: ChatCompletionMessage,
) -> ChatCompletionMessageParam:
"""Convert from class to TypedDict."""
tool_calls: list[ChatCompletionMessageToolCallParam] = []
if message.tool_calls:
@ -92,17 +96,42 @@ def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessagePar
return param
def _chat_message_convert(
message: conversation.Content
| conversation.NativeContent[ChatCompletionMessageParam],
def _convert_content_to_param(
content: conversation.Content,
) -> ChatCompletionMessageParam:
"""Convert any native chat message for this agent to the native format."""
if message.role == "native":
# mypy doesn't understand that checking role ensures content type
return message.content # type: ignore[return-value]
return cast(
ChatCompletionMessageParam,
{"role": message.role, "content": message.content},
if content.role == "tool_result":
assert type(content) is conversation.ToolResultContent
return ChatCompletionToolMessageParam(
role="tool",
tool_call_id=content.tool_call_id,
content=json.dumps(content.tool_result),
)
if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr]
role = content.role
if role == "system":
role = "developer"
return cast(
ChatCompletionMessageParam,
{"role": content.role, "content": content.content}, # type: ignore[union-attr]
)
# Handle the Assistant content including tool calls.
assert type(content) is conversation.AssistantContent
return ChatCompletionAssistantMessageParam(
role="assistant",
content=content.content,
tool_calls=[
ChatCompletionMessageToolCallParam(
id=tool_call.id,
function=Function(
arguments=json.dumps(tool_call.tool_args),
name=tool_call.tool_name,
),
type="function",
)
for tool_call in content.tool_calls
],
)
@ -166,14 +195,14 @@ class OpenAIConversationEntity(
async def _async_handle_message(
self,
user_input: conversation.ConversationInput,
session: conversation.ChatLog[ChatCompletionMessageParam],
chat_log: conversation.ChatLog,
) -> conversation.ConversationResult:
"""Call the API."""
assert user_input.agent_id
options = self.entry.options
try:
await session.async_update_llm_data(
await chat_log.async_update_llm_data(
DOMAIN,
user_input,
options.get(CONF_LLM_HASS_API),
@ -183,73 +212,77 @@ class OpenAIConversationEntity(
return err.as_conversation_result()
tools: list[ChatCompletionToolParam] | None = None
if session.llm_api:
if chat_log.llm_api:
tools = [
_format_tool(tool, session.llm_api.custom_serializer)
for tool in session.llm_api.tools
_format_tool(tool, chat_log.llm_api.custom_serializer)
for tool in chat_log.llm_api.tools
]
messages = [
_chat_message_convert(message) for message in session.async_get_messages()
]
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
messages = [_convert_content_to_param(content) for content in chat_log.content]
client = self.entry.runtime_data
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
try:
result = await client.chat.completions.create(
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
messages=messages,
tools=tools or NOT_GIVEN,
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
user=session.conversation_id,
model_args = {
"model": model,
"messages": messages,
"tools": tools or NOT_GIVEN,
"max_completion_tokens": options.get(
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
),
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
"user": chat_log.conversation_id,
}
if model.startswith("o"):
model_args["reasoning_effort"] = options.get(
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
)
try:
result = await client.chat.completions.create(**model_args)
except openai.OpenAIError as err:
LOGGER.error("Error talking to OpenAI: %s", err)
raise HomeAssistantError("Error talking to OpenAI") from err
LOGGER.debug("Response %s", result)
response = result.choices[0].message
messages.append(_message_convert(response))
messages.append(_convert_message_to_param(response))
session.async_add_message(
conversation.Content(
role=response.role,
agent_id=user_input.agent_id,
content=response.content or "",
),
tool_calls: list[llm.ToolInput] | None = None
if response.tool_calls:
tool_calls = [
llm.ToolInput(
id=tool_call.id,
tool_name=tool_call.function.name,
tool_args=json.loads(tool_call.function.arguments),
)
for tool_call in response.tool_calls
]
messages.extend(
[
_convert_content_to_param(tool_response)
async for tool_response in chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=user_input.agent_id,
content=response.content or "",
tool_calls=tool_calls,
)
)
]
)
if not response.tool_calls or not session.llm_api:
if not tool_calls:
break
for tool_call in response.tool_calls:
tool_input = llm.ToolInput(
tool_name=tool_call.function.name,
tool_args=json.loads(tool_call.function.arguments),
)
tool_response = await session.async_call_tool(tool_input)
messages.append(
ChatCompletionToolMessageParam(
role="tool",
tool_call_id=tool_call.id,
content=json.dumps(tool_response),
)
)
session.async_add_message(
conversation.NativeContent(
agent_id=user_input.agent_id,
content=messages[-1],
)
)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response.content or "")
return conversation.ConversationResult(
response=intent_response, conversation_id=session.conversation_id
response=intent_response, conversation_id=chat_log.conversation_id
)
async def _async_entry_update_listener(

View File

@ -23,12 +23,26 @@
"temperature": "Temperature",
"top_p": "Top P",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
"recommended": "Recommended model settings"
"recommended": "Recommended model settings",
"reasoning_effort": "Reasoning effort"
},
"data_description": {
"prompt": "Instruct how the LLM should respond. This can be a template."
"prompt": "Instruct how the LLM should respond. This can be a template.",
"reasoning_effort": "How many reasoning tokens the model should generate before creating a response to the prompt (for certain reasoning models)"
}
}
},
"error": {
"model_not_supported": "This model is not supported, please select a different model"
}
},
"selector": {
"reasoning_effort": {
"options": {
"low": "Low",
"medium": "Medium",
"high": "High"
}
}
},
"services": {

View File

@ -18,7 +18,13 @@ from homeassistant.const import (
STATE_OFF,
STATE_ON,
)
from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.core import (
HomeAssistant,
ServiceCall,
ServiceResponse,
SupportsResponse,
callback,
)
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.collection import (
CollectionEntity,
@ -44,6 +50,7 @@ from .const import (
CONF_TO,
DOMAIN,
LOGGER,
SERVICE_GET,
WEEKDAY_TO_CONF,
)
@ -205,6 +212,14 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
reload_service_handler,
)
component.async_register_entity_service(
SERVICE_GET,
{},
async_get_schedule_service,
supports_response=SupportsResponse.ONLY,
)
await component.async_setup(config)
return True
@ -296,6 +311,10 @@ class Schedule(CollectionEntity):
self.async_on_remove(self._clean_up_listener)
self._update()
def get_schedule(self) -> ConfigType:
"""Return the schedule."""
return {d: self._config[d] for d in WEEKDAY_TO_CONF.values()}
@callback
def _update(self, _: datetime | None = None) -> None:
"""Update the states of the schedule."""
@ -390,3 +409,10 @@ class Schedule(CollectionEntity):
data_keys.update(time_range_custom_data.keys())
return frozenset(data_keys)
async def async_get_schedule_service(
schedule: Schedule, service_call: ServiceCall
) -> ServiceResponse:
"""Return the schedule configuration."""
return schedule.get_schedule()

View File

@ -37,3 +37,5 @@ WEEKDAY_TO_CONF: Final = {
5: CONF_SATURDAY,
6: CONF_SUNDAY,
}
SERVICE_GET: Final = "get_schedule"

View File

@ -2,6 +2,9 @@
"services": {
"reload": {
"service": "mdi:reload"
},
"get_schedule": {
"service": "mdi:calendar-export"
}
}
}

View File

@ -1 +1,5 @@
reload:
get_schedule:
target:
entity:
domain: schedule

View File

@ -25,6 +25,10 @@
"reload": {
"name": "[%key:common::action::reload%]",
"description": "Reloads schedules from the YAML-configuration."
},
"get_schedule": {
"name": "Get schedule",
"description": "Retrieve one or multiple schedules."
}
}
}

View File

@ -1,16 +1,16 @@
{
"config": {
"flow_title": "Add Shark IQ Account",
"flow_title": "Add Shark IQ account",
"step": {
"user": {
"description": "Sign into your Shark Clean account to control your devices.",
"description": "Sign into your SharkClean account to control your devices.",
"data": {
"username": "[%key:common::config_flow::data::username%]",
"password": "[%key:common::config_flow::data::password%]",
"region": "Region"
},
"data_description": {
"region": "Shark IQ uses different services in the EU. Select your region to connect to the correct service for your account."
"region": "Shark IQ uses different services in the EU. Select your region to connect to the correct service for your account."
}
},
"reauth_confirm": {
@ -37,18 +37,18 @@
"region": {
"options": {
"europe": "Europe",
"elsewhere": "Everywhere Else"
"elsewhere": "Everywhere else"
}
}
},
"exceptions": {
"invalid_room": {
"message": "The room {room} is unavailable to your vacuum. Make sure all rooms match the Shark App, including capitalization."
"message": "The room {room} is unavailable to your vacuum. Make sure all rooms match the SharkClean app, including capitalization."
}
},
"services": {
"clean_room": {
"name": "Clean Room",
"name": "Clean room",
"description": "Cleans a specific user-defined room or set of rooms.",
"fields": {
"rooms": {

View File

@ -272,6 +272,18 @@ RPC_SENSORS: Final = {
entity_category=EntityCategory.DIAGNOSTIC,
entity_class=RpcBluTrvBinarySensor,
),
"flood": RpcBinarySensorDescription(
key="flood",
sub_key="alarm",
name="Flood",
device_class=BinarySensorDeviceClass.MOISTURE,
),
"mute": RpcBinarySensorDescription(
key="flood",
sub_key="mute",
name="Mute",
entity_category=EntityCategory.DIAGNOSTIC,
),
}

View File

@ -14,6 +14,7 @@ from homeassistant.config_entries import SOURCE_USER, ConfigFlow, ConfigFlowResu
from homeassistant.const import CONF_HOST, CONF_NAME, CONF_PASSWORD, CONF_USERNAME
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.device_registry import format_mac
from homeassistant.helpers.service_info.dhcp import DhcpServiceInfo
from homeassistant.helpers.service_info.zeroconf import ZeroconfServiceInfo
from .const import DOMAIN
@ -35,7 +36,8 @@ STEP_AUTH_DATA_SCHEMA = vol.Schema(
class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for SMLIGHT Zigbee."""
host: str
_host: str
_device_name: str
client: Api2
async def async_step_user(
@ -45,11 +47,13 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
errors: dict[str, str] = {}
if user_input is not None:
self.host = user_input[CONF_HOST]
self.client = Api2(self.host, session=async_get_clientsession(self.hass))
self._host = user_input[CONF_HOST]
self.client = Api2(self._host, session=async_get_clientsession(self.hass))
try:
info = await self.client.get_info()
self._host = str(info.device_ip)
self._device_name = str(info.hostname)
if info.model not in Devices:
return self.async_abort(reason="unsupported_device")
@ -93,15 +97,14 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
self, discovery_info: ZeroconfServiceInfo
) -> ConfigFlowResult:
"""Handle a discovered Lan coordinator."""
local_name = discovery_info.hostname[:-1]
node_name = local_name.removesuffix(".local")
mac: str | None = discovery_info.properties.get("mac")
self._device_name = discovery_info.hostname.removesuffix(".local.")
self._host = discovery_info.host
self.host = local_name
self.context["title_placeholders"] = {CONF_NAME: node_name}
self.client = Api2(self.host, session=async_get_clientsession(self.hass))
self.context["title_placeholders"] = {CONF_NAME: self._device_name}
self.client = Api2(self._host, session=async_get_clientsession(self.hass))
mac = discovery_info.properties.get("mac")
# fallback for legacy firmware
# fallback for legacy firmware older than v2.3.x
if mac is None:
try:
info = await self.client.get_info()
@ -111,7 +114,7 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
mac = info.MAC
await self.async_set_unique_id(format_mac(mac))
self._abort_if_unique_id_configured()
self._abort_if_unique_id_configured(updates={CONF_HOST: self._host})
return await self.async_step_confirm_discovery()
@ -122,7 +125,6 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
errors: dict[str, str] = {}
if user_input is not None:
user_input[CONF_HOST] = self.host
try:
info = await self.client.get_info()
@ -142,7 +144,7 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
return self.async_show_form(
step_id="confirm_discovery",
description_placeholders={"host": self.host},
description_placeholders={"host": self._device_name},
errors=errors,
)
@ -151,8 +153,8 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
) -> ConfigFlowResult:
"""Handle reauth when API Authentication failed."""
self.host = entry_data[CONF_HOST]
self.client = Api2(self.host, session=async_get_clientsession(self.hass))
self._host = entry_data[CONF_HOST]
self.client = Api2(self._host, session=async_get_clientsession(self.hass))
return await self.async_step_reauth_confirm()
@ -182,6 +184,16 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
errors=errors,
)
async def async_step_dhcp(
self, discovery_info: DhcpServiceInfo
) -> ConfigFlowResult:
"""Handle DHCP discovery."""
await self.async_set_unique_id(format_mac(discovery_info.macaddress))
self._abort_if_unique_id_configured(updates={CONF_HOST: discovery_info.ip})
# This should never happen since we only listen to DHCP requests
# for configured devices.
return self.async_abort(reason="already_configured")
async def _async_check_auth_required(self, user_input: dict[str, Any]) -> bool:
"""Check if auth required and attempt to authenticate."""
if await self.client.check_auth_needed():
@ -200,11 +212,10 @@ class SmlightConfigFlow(ConfigFlow, domain=DOMAIN):
await self.async_set_unique_id(
format_mac(info.MAC), raise_on_progress=self.source != SOURCE_USER
)
self._abort_if_unique_id_configured()
self._abort_if_unique_id_configured(updates={CONF_HOST: self._host})
if user_input.get(CONF_HOST) is None:
user_input[CONF_HOST] = self.host
user_input[CONF_HOST] = self._host
assert info.model is not None
title = self.context.get("title_placeholders", {}).get(CONF_NAME) or info.model
title = self._device_name or info.model
return self.async_create_entry(title=title, data=user_input)

View File

@ -3,10 +3,15 @@
"name": "SMLIGHT SLZB",
"codeowners": ["@tl-sl"],
"config_flow": true,
"dhcp": [
{
"registered_devices": true
}
],
"documentation": "https://www.home-assistant.io/integrations/smlight",
"integration_type": "device",
"iot_class": "local_push",
"requirements": ["pysmlight==0.2.1"],
"requirements": ["pysmlight==0.2.2"],
"zeroconf": [
{
"type": "_slzb-06._tcp.local."

View File

@ -65,6 +65,7 @@ BINARY_SENSORS = [
key="currently_obstructed",
translation_key="currently_obstructed",
device_class=BinarySensorDeviceClass.PROBLEM,
entity_category=EntityCategory.DIAGNOSTIC,
value_fn=lambda data: data.status["currently_obstructed"],
),
StarlinkBinarySensorEntityDescription(
@ -114,4 +115,9 @@ BINARY_SENSORS = [
entity_category=EntityCategory.DIAGNOSTIC,
value_fn=lambda data: data.alert["alert_unexpected_location"],
),
StarlinkBinarySensorEntityDescription(
key="connection",
device_class=BinarySensorDeviceClass.CONNECTIVITY,
value_fn=lambda data: data.status["state"] == "CONNECTED",
),
]

View File

@ -14,7 +14,7 @@
},
"reconfigure": {
"title": "Reconfigure your Tado",
"description": "Reconfigure the entry, for your account: `{username}`.",
"description": "Reconfigure the entry for your account: `{username}`.",
"data": {
"password": "[%key:common::config_flow::data::password%]"
},
@ -25,7 +25,7 @@
},
"error": {
"unknown": "[%key:common::config_flow::error::unknown%]",
"no_homes": "There are no homes linked to this tado account.",
"no_homes": "There are no homes linked to this Tado account.",
"invalid_auth": "[%key:common::config_flow::error::invalid_auth%]",
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]"
}
@ -33,7 +33,7 @@
"options": {
"step": {
"init": {
"description": "Fallback mode lets you choose when to fallback to Smart Schedule from your manual zone overlay. (NEXT_TIME_BLOCK:= Change at next Smart Schedule change; MANUAL:= Dont change until you cancel; TADO_DEFAULT:= Change based on your setting in Tado App).",
"description": "Fallback mode lets you choose when to fallback to Smart Schedule from your manual zone overlay. (NEXT_TIME_BLOCK:= Change at next Smart Schedule change; MANUAL:= Don't change until you cancel; TADO_DEFAULT:= Change based on your setting in the Tado app).",
"data": {
"fallback": "Choose fallback mode."
},
@ -102,11 +102,11 @@
},
"time_period": {
"name": "Time period",
"description": "Choose this or Overlay. Set the time period for the change if you want to be specific. Alternatively use Overlay."
"description": "Choose this or 'Overlay'. Set the time period for the change if you want to be specific."
},
"requested_overlay": {
"name": "Overlay",
"description": "Choose this or Time Period. Allows you to choose an overlay. MANUAL:=Overlay until user removes; NEXT_TIME_BLOCK:=Overlay until next timeblock; TADO_DEFAULT:=Overlay based on tado app setting."
"description": "Choose this or 'Time period'. Allows you to choose an overlay. MANUAL:=Overlay until user removes; NEXT_TIME_BLOCK:=Overlay until next timeblock; TADO_DEFAULT:=Overlay based on Tado app setting."
}
}
},
@ -151,8 +151,8 @@
},
"issues": {
"water_heater_fallback": {
"title": "Tado Water Heater entities now support fallback options",
"description": "Due to added support for water heaters entities, these entities may use different overlay. Please configure integration entity and tado app water heater zone overlay options. Otherwise, please configure the integration entity and Tado app water heater zone overlay options (under Settings -> Rooms & Devices -> Hot Water)."
"title": "Tado water heater entities now support fallback options",
"description": "Due to added support for water heaters entities, these entities may use a different overlay. Please configure the integration entity and Tado app water heater zone overlay options (under Settings -> Rooms & Devices -> Hot Water)."
}
}
}

View File

@ -7,6 +7,7 @@ from pyvesync import VeSync
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_PASSWORD, CONF_USERNAME, Platform
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.dispatcher import async_dispatcher_send
from .common import async_generate_device_list
@ -91,3 +92,37 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass.data.pop(DOMAIN)
return unload_ok
async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Migrate old entry."""
_LOGGER.debug(
"Migrating VeSync config entry: %s minor version: %s",
config_entry.version,
config_entry.minor_version,
)
if config_entry.minor_version == 1:
# Migrate switch/outlets entity to a new unique ID
_LOGGER.debug("Migrating VeSync config entry from version 1 to version 2")
entity_registry = er.async_get(hass)
registry_entries = er.async_entries_for_config_entry(
entity_registry, config_entry.entry_id
)
for reg_entry in registry_entries:
if "-" not in reg_entry.unique_id and reg_entry.entity_id.startswith(
Platform.SWITCH
):
_LOGGER.debug(
"Migrating switch/outlet entity from unique_id: %s to unique_id: %s",
reg_entry.unique_id,
reg_entry.unique_id + "-device_status",
)
entity_registry.async_update_entity(
reg_entry.entity_id,
new_unique_id=reg_entry.unique_id + "-device_status",
)
else:
_LOGGER.debug("Skipping entity with unique_id: %s", reg_entry.unique_id)
hass.config_entries.async_update_entry(config_entry, minor_version=2)
return True

View File

@ -24,6 +24,7 @@ class VeSyncFlowHandler(ConfigFlow, domain=DOMAIN):
"""Handle a config flow."""
VERSION = 1
MINOR_VERSION = 2
@callback
def _show_form(self, errors: dict[str, str] | None = None) -> ConfigFlowResult:

View File

@ -12,5 +12,5 @@
"documentation": "https://www.home-assistant.io/integrations/vesync",
"iot_class": "cloud_polling",
"loggers": ["pyvesync"],
"requirements": ["pyvesync==2.1.16"]
"requirements": ["pyvesync==2.1.17"]
}

View File

@ -83,6 +83,7 @@ class VeSyncSwitchHA(VeSyncBaseSwitch, SwitchEntity):
) -> None:
"""Initialize the VeSync switch device."""
super().__init__(plug, coordinator)
self._attr_unique_id = f"{super().unique_id}-device_status"
self.smartplug = plug
@ -94,4 +95,5 @@ class VeSyncLightSwitch(VeSyncBaseSwitch, SwitchEntity):
) -> None:
"""Initialize Light Switch device class."""
super().__init__(switch, coordinator)
self._attr_unique_id = f"{super().unique_id}-device_status"
self.switch = switch

View File

@ -616,6 +616,10 @@ DHCP: Final[list[dict[str, str | bool]]] = [
"hostname": "hub*",
"macaddress": "286D97*",
},
{
"domain": "smlight",
"registered_devices": True,
},
{
"domain": "solaredge",
"hostname": "target",

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from dataclasses import dataclass, field as dc_field
from datetime import timedelta
from decimal import Decimal
from enum import Enum
@ -36,6 +36,7 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import dt as dt_util, yaml as yaml_util
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JsonObjectType
from homeassistant.util.ulid import ulid_now
from . import (
area_registry as ar,
@ -139,6 +140,8 @@ class ToolInput:
tool_name: str
tool_args: dict[str, Any]
# Using lambda for default to allow patching in tests
id: str = dc_field(default_factory=lambda: ulid_now()) # pylint: disable=unnecessary-lambda
class Tool:

View File

@ -29,7 +29,7 @@ certifi>=2021.5.30
ciso8601==2.3.2
cronsim==2.6
cryptography==44.0.0
dbus-fast==2.31.0
dbus-fast==2.32.0
fnv-hash-fast==1.2.2
go2rtc-client==0.1.2
ha-ffmpeg==3.2.2
@ -53,7 +53,7 @@ psutil-home-assistant==0.0.1
PyJWT==2.10.1
pymicro-vad==1.0.1
PyNaCl==1.5.0
pyOpenSSL==24.3.0
pyOpenSSL==25.0.0
pyserial==3.5
pyspeex-noise==1.0.2
python-slugify==8.0.4

View File

@ -59,7 +59,7 @@ dependencies = [
"cryptography==44.0.0",
"Pillow==11.1.0",
"propcache==0.2.1",
"pyOpenSSL==24.3.0",
"pyOpenSSL==25.0.0",
"orjson==3.10.12",
"packaging>=23.1",
"psutil-home-assistant==0.0.1",

2
requirements.txt generated
View File

@ -31,7 +31,7 @@ PyJWT==2.10.1
cryptography==44.0.0
Pillow==11.1.0
propcache==0.2.1
pyOpenSSL==24.3.0
pyOpenSSL==25.0.0
orjson==3.10.12
packaging>=23.1
psutil-home-assistant==0.0.1

18
requirements_all.txt generated
View File

@ -600,7 +600,7 @@ bizkaibus==0.1.1
# homeassistant.components.eq3btsmart
# homeassistant.components.esphome
bleak-esphome==2.6.0
bleak-esphome==2.7.0
# homeassistant.components.bluetooth
bleak-retry-connector==3.8.0
@ -741,7 +741,7 @@ datadog==0.15.0
datapoint==0.9.9
# homeassistant.components.bluetooth
dbus-fast==2.31.0
dbus-fast==2.32.0
# homeassistant.components.debugpy
debugpy==1.8.11
@ -1437,9 +1437,6 @@ motioneye-client==0.3.14
# homeassistant.components.bang_olufsen
mozart-api==4.1.1.116.4
# homeassistant.components.onedrive
msgraph-sdk==1.16.0
# homeassistant.components.mullvad
mullvad-api==1.0.0
@ -1561,6 +1558,9 @@ omnilogic==0.4.5
# homeassistant.components.ondilo_ico
ondilo==0.5.0
# homeassistant.components.onedrive
onedrive-personal-sdk==0.0.1
# homeassistant.components.onvif
onvif-zeep-async==3.2.5
@ -2205,7 +2205,7 @@ pypalazzetti==0.1.19
pypca==0.0.7
# homeassistant.components.lcn
pypck==0.8.3
pypck==0.8.5
# homeassistant.components.pjlink
pypjlink2==1.2.1
@ -2313,7 +2313,7 @@ pysmarty2==0.10.1
pysml==0.0.12
# homeassistant.components.smlight
pysmlight==0.2.1
pysmlight==0.2.2
# homeassistant.components.snmp
pysnmp==6.2.6
@ -2391,7 +2391,7 @@ python-gitlab==1.6.0
python-google-drive-api==0.0.2
# homeassistant.components.analytics_insights
python-homeassistant-analytics==0.8.1
python-homeassistant-analytics==0.9.0
# homeassistant.components.homewizard
python-homewizard-energy==v8.3.2
@ -2516,7 +2516,7 @@ pyvera==0.3.15
pyversasense==0.0.6
# homeassistant.components.vesync
pyvesync==2.1.16
pyvesync==2.1.17
# homeassistant.components.vizio
pyvizio==0.1.61

View File

@ -8,31 +8,31 @@
-c homeassistant/package_constraints.txt
-r requirements_test_pre_commit.txt
astroid==3.3.8
coverage==7.6.8
coverage==7.6.10
freezegun==1.5.1
license-expression==30.4.0
license-expression==30.4.1
mock-open==1.4.0
mypy-dev==1.16.0a1
pre-commit==4.0.0
pydantic==2.10.6
pylint==3.3.3
pylint-per-file-ignores==1.3.2
pipdeptree==2.23.4
pytest-asyncio==0.24.0
pylint==3.3.4
pylint-per-file-ignores==1.4.0
pipdeptree==2.25.0
pytest-asyncio==0.25.3
pytest-aiohttp==1.0.5
pytest-cov==6.0.0
pytest-freezer==0.4.8
pytest-github-actions-annotate-failures==0.2.0
pytest-freezer==0.4.9
pytest-github-actions-annotate-failures==0.3.0
pytest-socket==0.7.0
pytest-sugar==1.0.0
pytest-timeout==2.3.1
pytest-unordered==0.6.1
pytest-picked==0.5.0
pytest-picked==0.5.1
pytest-xdist==3.6.1
pytest==8.3.4
requests-mock==1.12.1
respx==0.22.0
syrupy==4.8.0
syrupy==4.8.1
tqdm==4.66.5
types-aiofiles==24.1.0.20241221
types-atomicwrites==1.4.5.1

View File

@ -528,7 +528,7 @@ bimmer-connected[china]==0.17.2
# homeassistant.components.eq3btsmart
# homeassistant.components.esphome
bleak-esphome==2.6.0
bleak-esphome==2.7.0
# homeassistant.components.bluetooth
bleak-retry-connector==3.8.0
@ -634,7 +634,7 @@ datadog==0.15.0
datapoint==0.9.9
# homeassistant.components.bluetooth
dbus-fast==2.31.0
dbus-fast==2.32.0
# homeassistant.components.debugpy
debugpy==1.8.11
@ -1206,9 +1206,6 @@ motioneye-client==0.3.14
# homeassistant.components.bang_olufsen
mozart-api==4.1.1.116.4
# homeassistant.components.onedrive
msgraph-sdk==1.16.0
# homeassistant.components.mullvad
mullvad-api==1.0.0
@ -1306,6 +1303,9 @@ omnilogic==0.4.5
# homeassistant.components.ondilo_ico
ondilo==0.5.0
# homeassistant.components.onedrive
onedrive-personal-sdk==0.0.1
# homeassistant.components.onvif
onvif-zeep-async==3.2.5
@ -1795,7 +1795,7 @@ pyownet==0.10.0.post1
pypalazzetti==0.1.19
# homeassistant.components.lcn
pypck==0.8.3
pypck==0.8.5
# homeassistant.components.pjlink
pypjlink2==1.2.1
@ -1882,7 +1882,7 @@ pysmarty2==0.10.1
pysml==0.0.12
# homeassistant.components.smlight
pysmlight==0.2.1
pysmlight==0.2.2
# homeassistant.components.snmp
pysnmp==6.2.6
@ -1933,7 +1933,7 @@ python-fullykiosk==0.0.14
python-google-drive-api==0.0.2
# homeassistant.components.analytics_insights
python-homeassistant-analytics==0.8.1
python-homeassistant-analytics==0.9.0
# homeassistant.components.homewizard
python-homewizard-energy==v8.3.2
@ -2031,7 +2031,7 @@ pyuptimerobot==22.2.0
pyvera==0.3.15
# homeassistant.components.vesync
pyvesync==2.1.16
pyvesync==2.1.17
# homeassistant.components.vizio
pyvizio==0.1.61

View File

@ -24,7 +24,7 @@ RUN --mount=from=ghcr.io/astral-sh/uv:0.5.21,source=/uv,target=/bin/uv \
--no-cache \
-c /usr/src/homeassistant/homeassistant/package_constraints.txt \
-r /usr/src/homeassistant/requirements.txt \
stdlib-list==0.10.0 pipdeptree==2.23.4 tqdm==4.66.5 ruff==0.9.1 \
stdlib-list==0.10.0 pipdeptree==2.25.0 tqdm==4.66.5 ruff==0.9.1 \
PyTurboJPEG==1.7.5 go2rtc-client==0.1.2 ha-ffmpeg==3.2.2 hassil==2.2.0 home-assistant-intents==2025.1.28 mutagen==1.47.0 pymicro-vad==1.0.1 pyspeex-noise==1.0.2
LABEL "name"="hassfest"

View File

@ -236,6 +236,7 @@ async def test_function_call(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="toolu_0123456789AbCdEfGhIjKlM",
tool_name="test_tool",
tool_args={"param1": "test_value"},
),
@ -373,6 +374,7 @@ async def test_function_exception(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="toolu_0123456789AbCdEfGhIjKlM",
tool_name="test_tool",
tool_args={"param1": "test_value"},
),

View File

@ -3,6 +3,7 @@
list([
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
}),
@ -32,7 +33,7 @@
}),
dict({
'data': dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
@ -94,6 +95,7 @@
list([
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
}),
@ -123,7 +125,7 @@
}),
dict({
'data': dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
@ -185,6 +187,7 @@
list([
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
}),
@ -214,7 +217,7 @@
}),
dict({
'data': dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
@ -276,6 +279,7 @@
list([
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
}),
@ -329,7 +333,7 @@
}),
dict({
'data': dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
@ -391,6 +395,7 @@
list([
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
}),
@ -427,6 +432,7 @@
list([
dict({
'data': dict({
'conversation_id': 'mock-conversation-id',
'language': 'en',
'pipeline': <ANY>,
}),
@ -434,7 +440,7 @@
}),
dict({
'data': dict({
'conversation_id': None,
'conversation_id': 'mock-conversation-id',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test input',
@ -478,6 +484,7 @@
list([
dict({
'data': dict({
'conversation_id': 'mock-conversation-id',
'language': 'en',
'pipeline': <ANY>,
}),
@ -485,7 +492,7 @@
}),
dict({
'data': dict({
'conversation_id': None,
'conversation_id': 'mock-conversation-id',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test input',
@ -529,6 +536,7 @@
list([
dict({
'data': dict({
'conversation_id': 'mock-conversation-id',
'language': 'en',
'pipeline': <ANY>,
}),
@ -536,7 +544,7 @@
}),
dict({
'data': dict({
'conversation_id': None,
'conversation_id': 'mock-conversation-id',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test input',
@ -580,6 +588,7 @@
list([
dict({
'data': dict({
'conversation_id': 'mock-conversation-id',
'language': 'en',
'pipeline': <ANY>,
}),

View File

@ -1,6 +1,7 @@
# serializer version: 1
# name: test_audio_pipeline
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -31,7 +32,7 @@
# ---
# name: test_audio_pipeline.3
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
@ -84,6 +85,7 @@
# ---
# name: test_audio_pipeline_debug
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -114,7 +116,7 @@
# ---
# name: test_audio_pipeline_debug.3
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
@ -179,6 +181,7 @@
# ---
# name: test_audio_pipeline_with_enhancements
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -209,7 +212,7 @@
# ---
# name: test_audio_pipeline_with_enhancements.3
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
@ -262,6 +265,7 @@
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -314,7 +318,7 @@
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.5
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test transcript',
@ -367,6 +371,7 @@
# ---
# name: test_audio_pipeline_with_wake_word_timeout
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -399,6 +404,7 @@
# ---
# name: test_device_capture
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -425,6 +431,7 @@
# ---
# name: test_device_capture_override
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -473,6 +480,7 @@
# ---
# name: test_device_capture_queue_full
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -512,6 +520,7 @@
# ---
# name: test_intent_failed
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -522,7 +531,7 @@
# ---
# name: test_intent_failed.1
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'Are the lights on?',
@ -535,6 +544,7 @@
# ---
# name: test_intent_timeout
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -545,7 +555,7 @@
# ---
# name: test_intent_timeout.1
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'Are the lights on?',
@ -564,6 +574,7 @@
# ---
# name: test_pipeline_empty_tts_output
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -574,7 +585,7 @@
# ---
# name: test_pipeline_empty_tts_output.1
dict({
'conversation_id': None,
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'never mind',
@ -611,6 +622,7 @@
# ---
# name: test_stt_cooldown_different_ids
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -621,6 +633,7 @@
# ---
# name: test_stt_cooldown_different_ids.1
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -631,6 +644,7 @@
# ---
# name: test_stt_cooldown_same_id
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -641,6 +655,7 @@
# ---
# name: test_stt_cooldown_same_id.1
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -651,6 +666,7 @@
# ---
# name: test_stt_stream_failed
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -677,6 +693,7 @@
# ---
# name: test_text_only_pipeline[extra_msg0]
dict({
'conversation_id': 'mock-conversation-id',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -723,6 +740,7 @@
# ---
# name: test_text_only_pipeline[extra_msg1]
dict({
'conversation_id': 'mock-conversation-id',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -775,6 +793,7 @@
# ---
# name: test_tts_failed
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -796,6 +815,7 @@
# ---
# name: test_wake_word_cooldown_different_entities
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -806,6 +826,7 @@
# ---
# name: test_wake_word_cooldown_different_entities.1
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -857,6 +878,7 @@
# ---
# name: test_wake_word_cooldown_different_ids
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -867,6 +889,7 @@
# ---
# name: test_wake_word_cooldown_different_ids.1
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -921,6 +944,7 @@
# ---
# name: test_wake_word_cooldown_same_id
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
@ -931,6 +955,7 @@
# ---
# name: test_wake_word_cooldown_same_id.1
dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({

View File

@ -1,11 +1,12 @@
"""Test Voice Assistant init."""
import asyncio
from collections.abc import Generator
from dataclasses import asdict
import itertools as it
from pathlib import Path
import tempfile
from unittest.mock import ANY, patch
from unittest.mock import ANY, Mock, patch
import wave
import hass_nabucasa
@ -41,6 +42,14 @@ from .conftest import (
from tests.typing import ClientSessionGenerator, WebSocketGenerator
@pytest.fixture(autouse=True)
def mock_ulid() -> Generator[Mock]:
"""Mock the ulid of chat sessions."""
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-ulid"
yield mock_ulid_now
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
"""Process events to remove dynamic values."""
processed = []
@ -684,7 +693,7 @@ async def test_wake_word_detection_aborted(
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
pipeline_input = assist_pipeline.pipeline.PipelineInput(
conversation_id=None,
conversation_id="mock-conversation-id",
device_id=None,
stt_metadata=stt.SpeechMetadata(
language="",
@ -771,7 +780,7 @@ async def test_tts_audio_output(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.",
conversation_id=None,
conversation_id="mock-conversation-id",
device_id=None,
run=assist_pipeline.pipeline.PipelineRun(
hass,
@ -828,7 +837,7 @@ async def test_tts_wav_preferred_format(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.",
conversation_id=None,
conversation_id="mock-conversation-id",
device_id=None,
run=assist_pipeline.pipeline.PipelineRun(
hass,
@ -896,7 +905,7 @@ async def test_tts_dict_preferred_format(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.",
conversation_id=None,
conversation_id="mock-conversation-id",
device_id=None,
run=assist_pipeline.pipeline.PipelineRun(
hass,
@ -982,6 +991,7 @@ async def test_sentence_trigger_overrides_conversation_agent(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test trigger sentence",
conversation_id="mock-conversation-id",
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
@ -1059,6 +1069,7 @@ async def test_prefer_local_intents(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="I'd like to order a stout please",
conversation_id="mock-conversation-id",
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
@ -1136,6 +1147,7 @@ async def test_stt_language_used_instead_of_conversation_language(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test input",
conversation_id="mock-conversation-id",
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
@ -1210,6 +1222,7 @@ async def test_tts_language_used_instead_of_conversation_language(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test input",
conversation_id="mock-conversation-id",
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
@ -1284,6 +1297,7 @@ async def test_pipeline_language_used_instead_of_conversation_language(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test input",
conversation_id="mock-conversation-id",
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),

View File

@ -2,8 +2,9 @@
import asyncio
import base64
from collections.abc import Generator
from typing import Any
from unittest.mock import ANY, patch
from unittest.mock import ANY, Mock, patch
import pytest
from syrupy.assertion import SnapshotAssertion
@ -35,6 +36,14 @@ from tests.common import MockConfigEntry
from tests.typing import WebSocketGenerator
@pytest.fixture(autouse=True)
def mock_ulid() -> Generator[Mock]:
"""Mock the ulid of chat sessions."""
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-ulid"
yield mock_ulid_now
@pytest.mark.parametrize(
"extra_msg",
[

View File

@ -94,7 +94,9 @@ class MockAssistSatellite(AssistSatelliteEntity):
self, start_announcement: AssistSatelliteConfiguration
) -> None:
"""Start a conversation from the satellite."""
self.start_conversations.append((self._extra_system_prompt, start_announcement))
self.start_conversations.append(
(self._conversation_id, self._extra_system_prompt, start_announcement)
)
@pytest.fixture

View File

@ -1,7 +1,8 @@
"""Test the Assist Satellite entity."""
import asyncio
from unittest.mock import patch
from collections.abc import Generator
from unittest.mock import Mock, patch
import pytest
@ -31,6 +32,14 @@ from . import ENTITY_ID
from .conftest import MockAssistSatellite
@pytest.fixture
def mock_chat_session_conversation_id() -> Generator[Mock]:
"""Mock the ulid library."""
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-conversation-id"
yield mock_ulid_now
@pytest.fixture(autouse=True)
async def set_pipeline_tts(hass: HomeAssistant, init_components: ConfigEntry) -> None:
"""Set up a pipeline with a TTS engine."""
@ -487,6 +496,7 @@ async def test_vad_sensitivity_entity_not_found(
"extra_system_prompt": "Better system prompt",
},
(
"mock-conversation-id",
"Better system prompt",
AssistSatelliteAnnouncement(
message="Hello",
@ -502,6 +512,7 @@ async def test_vad_sensitivity_entity_not_found(
"start_media_id": "media-source://given",
},
(
"mock-conversation-id",
"Hello",
AssistSatelliteAnnouncement(
message="Hello",
@ -514,6 +525,7 @@ async def test_vad_sensitivity_entity_not_found(
(
{"start_media_id": "http://example.com/given.mp3"},
(
"mock-conversation-id",
None,
AssistSatelliteAnnouncement(
message="",
@ -525,6 +537,7 @@ async def test_vad_sensitivity_entity_not_found(
),
],
)
@pytest.mark.usefixtures("mock_chat_session_conversation_id")
async def test_start_conversation(
hass: HomeAssistant,
init_components: ConfigEntry,

View File

@ -9,13 +9,13 @@ from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components.conversation import (
Content,
AssistantContent,
ConversationInput,
ConverseError,
NativeContent,
ToolResultContent,
async_get_chat_log,
)
from homeassistant.components.conversation.session import DATA_CHAT_HISTORY
from homeassistant.components.conversation.chat_log import DATA_CHAT_HISTORY
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session, llm
@ -40,7 +40,7 @@ def mock_conversation_input(hass: HomeAssistant) -> ConversationInput:
@pytest.fixture
def mock_ulid() -> Generator[Mock]:
"""Mock the ulid library."""
with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now:
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-ulid"
yield mock_ulid_now
@ -56,13 +56,13 @@ async def test_cleanup(
):
conversation_id = session.conversation_id
# Add message so it persists
chat_log.async_add_message(
Content(
role="assistant",
agent_id=mock_conversation_input.agent_id,
content="",
async for _tool_result in chat_log.async_add_assistant_content(
AssistantContent(
agent_id="mock-agent-id",
content="Hey!",
)
)
):
pytest.fail("should not reach here")
assert conversation_id in hass.data[DATA_CHAT_HISTORY]
@ -79,7 +79,7 @@ async def test_cleanup(
assert conversation_id not in hass.data[DATA_CHAT_HISTORY]
async def test_add_message(
async def test_default_content(
hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None:
"""Test filtering of messages."""
@ -87,95 +87,11 @@ async def test_add_message(
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
assert len(chat_log.messages) == 2
with pytest.raises(ValueError):
chat_log.async_add_message(
Content(role="system", agent_id=None, content="")
)
# No 2 user messages in a row
assert chat_log.messages[1].role == "user"
with pytest.raises(ValueError):
chat_log.async_add_message(Content(role="user", agent_id=None, content=""))
# No 2 assistant messages in a row
chat_log.async_add_message(Content(role="assistant", agent_id=None, content=""))
assert len(chat_log.messages) == 3
assert chat_log.messages[-1].role == "assistant"
with pytest.raises(ValueError):
chat_log.async_add_message(
Content(role="assistant", agent_id=None, content="")
)
async def test_message_filtering(
hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None:
"""Test filtering of messages."""
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
messages = chat_log.async_get_messages(agent_id=None)
assert len(messages) == 2
assert messages[0] == Content(
role="system",
agent_id=None,
content="",
)
assert messages[1] == Content(
role="user",
agent_id="mock-agent-id",
content=mock_conversation_input.text,
)
# Cannot add a second user message in a row
with pytest.raises(ValueError):
chat_log.async_add_message(
Content(
role="user",
agent_id="mock-agent-id",
content="Hey!",
)
)
chat_log.async_add_message(
Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
)
)
# Different agent, native messages will be filtered out.
chat_log.async_add_message(
NativeContent(agent_id="another-mock-agent-id", content=1)
)
chat_log.async_add_message(NativeContent(agent_id="mock-agent-id", content=1))
# A non-native message from another agent is not filtered out.
chat_log.async_add_message(
Content(
role="assistant",
agent_id="another-mock-agent-id",
content="Hi!",
)
)
assert len(chat_log.messages) == 6
messages = chat_log.async_get_messages(agent_id="mock-agent-id")
assert len(messages) == 5
assert messages[2] == Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
)
assert messages[3] == NativeContent(agent_id="mock-agent-id", content=1)
assert messages[4] == Content(
role="assistant", agent_id="another-mock-agent-id", content="Hi!"
)
assert len(chat_log.content) == 2
assert chat_log.content[0].role == "system"
assert chat_log.content[0].content == ""
assert chat_log.content[1].role == "user"
assert chat_log.content[1].content == mock_conversation_input.text
async def test_llm_api(
@ -268,12 +184,10 @@ async def test_template_variables(
),
)
assert chat_log.user_name == "Test User"
assert "The instance name is test home." in chat_log.messages[0].content
assert "The user name is Test User." in chat_log.messages[0].content
assert "The user id is 12345." in chat_log.messages[0].content
assert "The calling platform is test." in chat_log.messages[0].content
assert "The instance name is test home." in chat_log.content[0].content
assert "The user name is Test User." in chat_log.content[0].content
assert "The user id is 12345." in chat_log.content[0].content
assert "The calling platform is test." in chat_log.content[0].content
async def test_extra_systen_prompt(
@ -296,16 +210,16 @@ async def test_extra_systen_prompt(
user_llm_hass_api=None,
user_llm_prompt=None,
)
chat_log.async_add_message(
Content(
role="assistant",
async for _tool_result in chat_log.async_add_assistant_content(
AssistantContent(
agent_id="mock-agent-id",
content="Hey!",
)
)
):
pytest.fail("should not reach here")
assert chat_log.extra_system_prompt == extra_system_prompt
assert chat_log.messages[0].content.endswith(extra_system_prompt)
assert chat_log.content[0].content.endswith(extra_system_prompt)
# Verify that follow-up conversations with no system prompt take previous one
conversation_id = chat_log.conversation_id
@ -323,7 +237,7 @@ async def test_extra_systen_prompt(
)
assert chat_log.extra_system_prompt == extra_system_prompt
assert chat_log.messages[0].content.endswith(extra_system_prompt)
assert chat_log.content[0].content.endswith(extra_system_prompt)
# Verify that we take new system prompts
mock_conversation_input.extra_system_prompt = extra_system_prompt2
@ -338,17 +252,17 @@ async def test_extra_systen_prompt(
user_llm_hass_api=None,
user_llm_prompt=None,
)
chat_log.async_add_message(
Content(
role="assistant",
async for _tool_result in chat_log.async_add_assistant_content(
AssistantContent(
agent_id="mock-agent-id",
content="Hey!",
)
)
):
pytest.fail("should not reach here")
assert chat_log.extra_system_prompt == extra_system_prompt2
assert chat_log.messages[0].content.endswith(extra_system_prompt2)
assert extra_system_prompt not in chat_log.messages[0].content
assert chat_log.content[0].content.endswith(extra_system_prompt2)
assert extra_system_prompt not in chat_log.content[0].content
# Verify that follow-up conversations with no system prompt take previous one
mock_conversation_input.extra_system_prompt = None
@ -365,7 +279,7 @@ async def test_extra_systen_prompt(
)
assert chat_log.extra_system_prompt == extra_system_prompt2
assert chat_log.messages[0].content.endswith(extra_system_prompt2)
assert chat_log.content[0].content.endswith(extra_system_prompt2)
async def test_tool_call(
@ -383,8 +297,7 @@ async def test_tool_call(
mock_tool.async_call.return_value = "Test response"
with patch(
"homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools",
return_value=[],
"homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
) as mock_get_tools:
mock_get_tools.return_value = [mock_tool]
@ -398,14 +311,29 @@ async def test_tool_call(
user_llm_hass_api="assist",
user_llm_prompt=None,
)
result = await chat_log.async_call_tool(
llm.ToolInput(
tool_name="test_tool",
tool_args={"param1": "Test Param"},
result = None
async for tool_result_content in chat_log.async_add_assistant_content(
AssistantContent(
agent_id=mock_conversation_input.agent_id,
content="",
tool_calls=[
llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool",
tool_args={"param1": "Test Param"},
)
],
)
)
):
assert result is None
result = tool_result_content
assert result == "Test response"
assert result == ToolResultContent(
agent_id=mock_conversation_input.agent_id,
tool_call_id="mock-tool-call-id",
tool_result="Test response",
tool_name="test_tool",
)
async def test_tool_call_exception(
@ -423,8 +351,7 @@ async def test_tool_call_exception(
mock_tool.async_call.side_effect = HomeAssistantError("Test error")
with patch(
"homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools",
return_value=[],
"homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
) as mock_get_tools:
mock_get_tools.return_value = [mock_tool]
@ -438,11 +365,26 @@ async def test_tool_call_exception(
user_llm_hass_api="assist",
user_llm_prompt=None,
)
result = await chat_log.async_call_tool(
llm.ToolInput(
tool_name="test_tool",
tool_args={"param1": "Test Param"},
result = None
async for tool_result_content in chat_log.async_add_assistant_content(
AssistantContent(
agent_id=mock_conversation_input.agent_id,
content="",
tool_calls=[
llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool",
tool_args={"param1": "Test Param"},
)
],
)
)
):
assert result is None
result = tool_result_content
assert result == {"error": "HomeAssistantError", "error_text": "Test error"}
assert result == ToolResultContent(
agent_id=mock_conversation_input.agent_id,
tool_call_id="mock-tool-call-id",
tool_result={"error": "HomeAssistantError", "error_text": "Test error"},
tool_name="test_tool",
)

View File

@ -36,6 +36,13 @@ def freeze_the_time():
yield
@pytest.fixture(autouse=True)
def mock_ulid_tools():
"""Mock generated ULIDs for tool calls."""
with patch("homeassistant.helpers.llm.ulid_now", return_value="mock-tool-call"):
yield
@pytest.mark.parametrize(
"agent_id", [None, "conversation.google_generative_ai_conversation"]
)
@ -177,6 +184,7 @@ async def test_chat_history(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
)
@pytest.mark.usefixtures("mock_init_component")
@pytest.mark.usefixtures("mock_ulid_tools")
async def test_function_call(
mock_get_tools,
hass: HomeAssistant,
@ -256,6 +264,7 @@ async def test_function_call(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool",
tool_args={
"param1": ["test_value", "param1's value"],
@ -287,9 +296,7 @@ async def test_function_call(
detail_event = trace_events[1]
assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
assert [
p.function_response.name
for p in detail_event["data"]["messages"][2]["content"].parts
if p.function_response
p["tool_name"] for p in detail_event["data"]["messages"][2]["tool_calls"]
] == ["test_tool"]
@ -362,6 +369,7 @@ async def test_function_call_without_parameters(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool",
tool_args={},
),
@ -451,6 +459,7 @@ async def test_function_exception(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool",
tool_args={"param1": 1},
),
@ -605,6 +614,7 @@ async def test_template_variables(
mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock()
mock_part.text = "Model response"
mock_part.function_call = None
chat_response.parts = [mock_part]
result = await conversation.async_converse(
hass, "hello", None, context, agent_id=mock_config_entry.entry_id

View File

@ -324,6 +324,24 @@ TEST_JOB_DONE = supervisor_jobs.Job(
errors=[],
child_jobs=[],
)
TEST_RESTORE_JOB_DONE_WITH_ERROR = supervisor_jobs.Job(
name="backup_manager_partial_restore",
reference="1ef41507",
uuid=UUID(TEST_JOB_ID),
progress=0.0,
stage="copy_additional_locations",
done=True,
errors=[
supervisor_jobs.JobError(
type="BackupInvalidError",
message=(
"Backup was made on supervisor version 2025.02.2.dev3105, "
"can't restore on 2025.01.2.dev3105"
),
)
],
child_jobs=[],
)
@pytest.fixture(autouse=True)
@ -1946,6 +1964,97 @@ async def test_reader_writer_restore_error(
assert response["error"]["code"] == expected_error_code
@pytest.mark.usefixtures("hassio_client", "setup_integration")
async def test_reader_writer_restore_late_error(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
supervisor_client: AsyncMock,
) -> None:
"""Test restoring a backup with error."""
client = await hass_ws_client(hass)
supervisor_client.backups.partial_restore.return_value.job_id = TEST_JOB_ID
supervisor_client.backups.list.return_value = [TEST_BACKUP]
supervisor_client.backups.backup_info.return_value = TEST_BACKUP_DETAILS
supervisor_client.jobs.get_job.return_value = TEST_JOB_NOT_DONE
await client.send_json_auto_id({"type": "backup/subscribe_events"})
response = await client.receive_json()
assert response["event"] == {"manager_state": "idle"}
response = await client.receive_json()
assert response["success"]
await client.send_json_auto_id(
{"type": "backup/restore", "agent_id": "hassio.local", "backup_id": "abc123"}
)
response = await client.receive_json()
assert response["event"] == {
"manager_state": "restore_backup",
"reason": None,
"stage": None,
"state": "in_progress",
}
supervisor_client.backups.partial_restore.assert_called_once_with(
"abc123",
supervisor_backups.PartialRestoreOptions(
addons=None,
background=True,
folders=None,
homeassistant=True,
location=None,
password=None,
),
)
event = {
"event": "job",
"data": {
"name": "backup_manager_partial_restore",
"reference": "7c54aeed",
"uuid": TEST_JOB_ID,
"progress": 0,
"stage": None,
"done": True,
"parent_id": None,
"errors": [
{
"type": "BackupInvalidError",
"message": (
"Backup was made on supervisor version 2025.02.2.dev3105, can't"
" restore on 2025.01.2.dev3105. Must update supervisor first."
),
}
],
"created": "2025-02-03T08:27:49.297997+00:00",
},
}
await client.send_json_auto_id({"type": "supervisor/event", "data": event})
response = await client.receive_json()
assert response["success"]
response = await client.receive_json()
assert response["event"] == {
"manager_state": "restore_backup",
"reason": "backup_reader_writer_error",
"stage": None,
"state": "failed",
}
response = await client.receive_json()
assert response["event"] == {"manager_state": "idle"}
response = await client.receive_json()
assert not response["success"]
assert response["error"] == {
"code": "home_assistant_error",
"message": (
"Restore failed: [{'type': 'BackupInvalidError', 'message': \"Backup "
"was made on supervisor version 2025.02.2.dev3105, can't restore on "
'2025.01.2.dev3105. Must update supervisor first."}]'
),
}
@pytest.mark.parametrize(
("backup", "backup_details", "parameters", "expected_error"),
[
@ -1999,15 +2108,40 @@ async def test_reader_writer_restore_wrong_parameters(
}
@pytest.mark.parametrize(
("get_job_result", "last_non_idle_event"),
[
(
TEST_JOB_DONE,
{
"manager_state": "restore_backup",
"reason": "",
"stage": None,
"state": "completed",
},
),
(
TEST_RESTORE_JOB_DONE_WITH_ERROR,
{
"manager_state": "restore_backup",
"reason": "unknown_error",
"stage": None,
"state": "failed",
},
),
],
)
@pytest.mark.usefixtures("hassio_client")
async def test_restore_progress_after_restart(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
supervisor_client: AsyncMock,
get_job_result: supervisor_jobs.Job,
last_non_idle_event: dict[str, Any],
) -> None:
"""Test restore backup progress after restart."""
supervisor_client.jobs.get_job.return_value = TEST_JOB_DONE
supervisor_client.jobs.get_job.return_value = get_job_result
with patch.dict(os.environ, MOCK_ENVIRON | {RESTORE_JOB_ID_ENV: TEST_JOB_ID}):
assert await async_setup_component(hass, BACKUP_DOMAIN, {BACKUP_DOMAIN: {}})
@ -2018,12 +2152,7 @@ async def test_restore_progress_after_restart(
response = await client.receive_json()
assert response["success"]
assert response["result"]["last_non_idle_event"] == {
"manager_state": "restore_backup",
"reason": "",
"stage": None,
"state": "completed",
}
assert response["result"]["last_non_idle_event"] == last_non_idle_event
assert response["result"]["state"] == "idle"

View File

@ -18,6 +18,13 @@ from homeassistant.helpers import intent, llm
from tests.common import MockConfigEntry
@pytest.fixture(autouse=True)
def mock_ulid_tools():
"""Mock generated ULIDs for tool calls."""
with patch("homeassistant.helpers.llm.ulid_now", return_value="mock-tool-call"):
yield
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
async def test_chat(
hass: HomeAssistant,
@ -205,6 +212,7 @@ async def test_function_call(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool",
tool_args=expected_tool_args,
),
@ -285,6 +293,7 @@ async def test_function_exception(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool",
tool_args={"param1": "test_value"},
),

View File

@ -1,18 +1,9 @@
"""Fixtures for OneDrive tests."""
from collections.abc import AsyncIterator, Generator
from html import escape
from json import dumps
import time
from unittest.mock import AsyncMock, MagicMock, patch
from httpx import Response
from msgraph.generated.models.drive_item import DriveItem
from msgraph.generated.models.drive_item_collection_response import (
DriveItemCollectionResponse,
)
from msgraph.generated.models.upload_session import UploadSession
from msgraph_core.models import LargeFileUploadSession
import pytest
from homeassistant.components.application_credentials import (
@ -23,7 +14,13 @@ from homeassistant.components.onedrive.const import DOMAIN, OAUTH_SCOPES
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from .const import BACKUP_METADATA, CLIENT_ID, CLIENT_SECRET
from .const import (
CLIENT_ID,
CLIENT_SECRET,
MOCK_APPROOT,
MOCK_BACKUP_FILE,
MOCK_BACKUP_FOLDER,
)
from tests.common import MockConfigEntry
@ -70,96 +67,41 @@ def mock_config_entry(expires_at: int, scopes: list[str]) -> MockConfigEntry:
)
@pytest.fixture
def mock_adapter() -> Generator[MagicMock]:
"""Return a mocked GraphAdapter."""
with (
patch(
"homeassistant.components.onedrive.config_flow.GraphRequestAdapter",
autospec=True,
) as mock_adapter,
patch(
"homeassistant.components.onedrive.backup.GraphRequestAdapter",
new=mock_adapter,
),
):
adapter = mock_adapter.return_value
adapter.get_http_response_message.return_value = Response(
status_code=200,
json={
"parentReference": {"driveId": "mock_drive_id"},
"createdBy": {"user": {"displayName": "John Doe"}},
},
)
yield adapter
adapter.send_async.return_value = LargeFileUploadSession(
next_expected_ranges=["2-"]
)
@pytest.fixture(autouse=True)
def mock_graph_client(mock_adapter: MagicMock) -> Generator[MagicMock]:
def mock_onedrive_client() -> Generator[MagicMock]:
"""Return a mocked GraphServiceClient."""
with (
patch(
"homeassistant.components.onedrive.config_flow.GraphServiceClient",
"homeassistant.components.onedrive.config_flow.OneDriveClient",
autospec=True,
) as graph_client,
) as onedrive_client,
patch(
"homeassistant.components.onedrive.GraphServiceClient",
new=graph_client,
"homeassistant.components.onedrive.OneDriveClient",
new=onedrive_client,
),
):
client = graph_client.return_value
client = onedrive_client.return_value
client.get_approot.return_value = MOCK_APPROOT
client.create_folder.return_value = MOCK_BACKUP_FOLDER
client.list_drive_items.return_value = [MOCK_BACKUP_FILE]
client.get_drive_item.return_value = MOCK_BACKUP_FILE
client.request_adapter = mock_adapter
class MockStreamReader:
async def iter_chunked(self, chunk_size: int) -> AsyncIterator[bytes]:
yield b"backup data"
drives = client.drives.by_drive_id.return_value
drives.special.by_drive_item_id.return_value.get = AsyncMock(
return_value=DriveItem(id="approot")
)
drive_items = drives.items.by_drive_item_id.return_value
drive_items.get = AsyncMock(return_value=DriveItem(id="folder_id"))
drive_items.children.post = AsyncMock(return_value=DriveItem(id="folder_id"))
drive_items.children.get = AsyncMock(
return_value=DriveItemCollectionResponse(
value=[
DriveItem(
id=BACKUP_METADATA["backup_id"],
description=escape(dumps(BACKUP_METADATA)),
),
DriveItem(),
]
)
)
drive_items.delete = AsyncMock(return_value=None)
drive_items.create_upload_session.post = AsyncMock(
return_value=UploadSession(upload_url="https://test.tld")
)
drive_items.patch = AsyncMock(return_value=None)
async def generate_bytes() -> AsyncIterator[bytes]:
"""Asynchronous generator that yields bytes."""
yield b"backup data"
drive_items.content.get = AsyncMock(
return_value=Response(status_code=200, content=generate_bytes())
)
client.download_drive_item.return_value = MockStreamReader()
yield client
@pytest.fixture
def mock_drive_items(mock_graph_client: MagicMock) -> MagicMock:
"""Return a mocked DriveItems."""
return mock_graph_client.drives.by_drive_id.return_value.items.by_drive_item_id.return_value
@pytest.fixture
def mock_get_special_folder(mock_graph_client: MagicMock) -> MagicMock:
"""Mock the get special folder method."""
return mock_graph_client.drives.by_drive_id.return_value.special.by_drive_item_id.return_value.get
def mock_large_file_upload_client() -> Generator[AsyncMock]:
"""Return a mocked LargeFileUploadClient upload."""
with patch(
"homeassistant.components.onedrive.backup.LargeFileUploadClient.upload"
) as mock_upload:
yield mock_upload
@pytest.fixture
@ -179,10 +121,3 @@ def mock_instance_id() -> Generator[AsyncMock]:
return_value="9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0",
):
yield
@pytest.fixture(autouse=True)
def mock_asyncio_sleep() -> Generator[AsyncMock]:
"""Mock asyncio.sleep."""
with patch("homeassistant.components.onedrive.backup.asyncio.sleep", AsyncMock()):
yield

View File

@ -1,5 +1,18 @@
"""Consts for OneDrive tests."""
from html import escape
from json import dumps
from onedrive_personal_sdk.models.items import (
AppRoot,
Contributor,
File,
Folder,
Hashes,
ItemParentReference,
User,
)
CLIENT_ID = "1234"
CLIENT_SECRET = "5678"
@ -17,3 +30,48 @@ BACKUP_METADATA = {
"protected": False,
"size": 34519040,
}
CONTRIBUTOR = Contributor(
user=User(
display_name="John Doe",
id="id",
email="john@doe.com",
)
)
MOCK_APPROOT = AppRoot(
id="id",
child_count=0,
size=0,
name="name",
parent_reference=ItemParentReference(
drive_id="mock_drive_id", id="id", path="path"
),
created_by=CONTRIBUTOR,
)
MOCK_BACKUP_FOLDER = Folder(
id="id",
name="name",
size=0,
child_count=0,
parent_reference=ItemParentReference(
drive_id="mock_drive_id", id="id", path="path"
),
created_by=CONTRIBUTOR,
)
MOCK_BACKUP_FILE = File(
id="id",
name="23e64aec.tar",
size=34519040,
parent_reference=ItemParentReference(
drive_id="mock_drive_id", id="id", path="path"
),
hashes=Hashes(
quick_xor_hash="hash",
),
mime_type="application/x-tar",
description=escape(dumps(BACKUP_METADATA)),
created_by=CONTRIBUTOR,
)

View File

@ -3,15 +3,14 @@
from __future__ import annotations
from collections.abc import AsyncGenerator
from html import escape
from io import StringIO
from json import dumps
from unittest.mock import Mock, patch
from httpx import TimeoutException
from kiota_abstractions.api_error import APIError
from msgraph.generated.models.drive_item import DriveItem
from msgraph_core.models import LargeFileUploadSession
from onedrive_personal_sdk.exceptions import (
AuthenticationError,
HashMismatchError,
OneDriveException,
)
import pytest
from homeassistant.components.backup import DOMAIN as BACKUP_DOMAIN, AgentBackup
@ -102,14 +101,10 @@ async def test_agents_list_backups(
async def test_agents_get_backup(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_drive_items: MagicMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test agent get backup."""
mock_drive_items.get = AsyncMock(
return_value=DriveItem(description=escape(dumps(BACKUP_METADATA)))
)
backup_id = BACKUP_METADATA["backup_id"]
client = await hass_ws_client(hass)
await client.send_json_auto_id({"type": "backup/details", "backup_id": backup_id})
@ -140,7 +135,7 @@ async def test_agents_get_backup(
async def test_agents_delete(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_drive_items: MagicMock,
mock_onedrive_client: MagicMock,
) -> None:
"""Test agent delete backup."""
client = await hass_ws_client(hass)
@ -155,37 +150,15 @@ async def test_agents_delete(
assert response["success"]
assert response["result"] == {"agent_errors": {}}
mock_drive_items.delete.assert_called_once()
async def test_agents_delete_not_found_does_not_throw(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_drive_items: MagicMock,
) -> None:
"""Test agent delete backup."""
mock_drive_items.children.get = AsyncMock(return_value=[])
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "backup/delete",
"backup_id": BACKUP_METADATA["backup_id"],
}
)
response = await client.receive_json()
assert response["success"]
assert response["result"] == {"agent_errors": {}}
assert mock_drive_items.delete.call_count == 0
mock_onedrive_client.delete_drive_item.assert_called_once()
async def test_agents_upload(
hass_client: ClientSessionGenerator,
caplog: pytest.LogCaptureFixture,
mock_drive_items: MagicMock,
mock_onedrive_client: MagicMock,
mock_large_file_upload_client: AsyncMock,
mock_config_entry: MockConfigEntry,
mock_adapter: MagicMock,
) -> None:
"""Test agent upload backup."""
client = await hass_client()
@ -200,7 +173,6 @@ async def test_agents_upload(
return_value=test_backup,
),
patch("pathlib.Path.open") as mocked_open,
patch("homeassistant.components.onedrive.backup.UPLOAD_CHUNK_SIZE", 3),
):
mocked_open.return_value.read = Mock(side_effect=[b"test", b""])
fetch_backup.return_value = test_backup
@ -211,31 +183,22 @@ async def test_agents_upload(
assert resp.status == 201
assert f"Uploading backup {test_backup.backup_id}" in caplog.text
mock_drive_items.create_upload_session.post.assert_called_once()
mock_drive_items.patch.assert_called_once()
assert mock_adapter.send_async.call_count == 2
assert mock_adapter.method_calls[0].args[0].content == b"tes"
assert mock_adapter.method_calls[0].args[0].headers.get("Content-Range") == {
"bytes 0-2/34519040"
}
assert mock_adapter.method_calls[1].args[0].content == b"t"
assert mock_adapter.method_calls[1].args[0].headers.get("Content-Range") == {
"bytes 3-3/34519040"
}
mock_large_file_upload_client.assert_called_once()
mock_onedrive_client.update_drive_item.assert_called_once()
async def test_broken_upload_session(
async def test_agents_upload_corrupt_upload(
hass_client: ClientSessionGenerator,
caplog: pytest.LogCaptureFixture,
mock_drive_items: MagicMock,
mock_onedrive_client: MagicMock,
mock_large_file_upload_client: AsyncMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test broken upload session."""
"""Test hash validation fails."""
mock_large_file_upload_client.side_effect = HashMismatchError("test")
client = await hass_client()
test_backup = AgentBackup.from_dict(BACKUP_METADATA)
mock_drive_items.create_upload_session.post = AsyncMock(return_value=None)
with (
patch(
"homeassistant.components.backup.manager.BackupManager.async_get_backup",
@ -254,152 +217,18 @@ async def test_broken_upload_session(
)
assert resp.status == 201
assert "Failed to start backup upload" in caplog.text
@pytest.mark.parametrize(
"side_effect",
[
APIError(response_status_code=500),
TimeoutException("Timeout"),
],
)
async def test_agents_upload_errors_retried(
hass_client: ClientSessionGenerator,
caplog: pytest.LogCaptureFixture,
mock_drive_items: MagicMock,
mock_config_entry: MockConfigEntry,
mock_adapter: MagicMock,
side_effect: Exception,
) -> None:
"""Test agent upload backup."""
client = await hass_client()
test_backup = AgentBackup.from_dict(BACKUP_METADATA)
mock_adapter.send_async.side_effect = [
side_effect,
LargeFileUploadSession(next_expected_ranges=["2-"]),
LargeFileUploadSession(next_expected_ranges=["2-"]),
]
with (
patch(
"homeassistant.components.backup.manager.BackupManager.async_get_backup",
) as fetch_backup,
patch(
"homeassistant.components.backup.manager.read_backup",
return_value=test_backup,
),
patch("pathlib.Path.open") as mocked_open,
patch("homeassistant.components.onedrive.backup.UPLOAD_CHUNK_SIZE", 3),
):
mocked_open.return_value.read = Mock(side_effect=[b"test", b""])
fetch_backup.return_value = test_backup
resp = await client.post(
f"/api/backup/upload?agent_id={DOMAIN}.{mock_config_entry.unique_id}",
data={"file": StringIO("test")},
)
assert resp.status == 201
assert mock_adapter.send_async.call_count == 3
assert f"Uploading backup {test_backup.backup_id}" in caplog.text
mock_drive_items.patch.assert_called_once()
async def test_agents_upload_4xx_errors_not_retried(
hass_client: ClientSessionGenerator,
caplog: pytest.LogCaptureFixture,
mock_drive_items: MagicMock,
mock_config_entry: MockConfigEntry,
mock_adapter: MagicMock,
) -> None:
"""Test agent upload backup."""
client = await hass_client()
test_backup = AgentBackup.from_dict(BACKUP_METADATA)
mock_adapter.send_async.side_effect = APIError(response_status_code=404)
with (
patch(
"homeassistant.components.backup.manager.BackupManager.async_get_backup",
) as fetch_backup,
patch(
"homeassistant.components.backup.manager.read_backup",
return_value=test_backup,
),
patch("pathlib.Path.open") as mocked_open,
patch("homeassistant.components.onedrive.backup.UPLOAD_CHUNK_SIZE", 3),
):
mocked_open.return_value.read = Mock(side_effect=[b"test", b""])
fetch_backup.return_value = test_backup
resp = await client.post(
f"/api/backup/upload?agent_id={DOMAIN}.{mock_config_entry.unique_id}",
data={"file": StringIO("test")},
)
assert resp.status == 201
assert mock_adapter.send_async.call_count == 1
assert f"Uploading backup {test_backup.backup_id}" in caplog.text
assert mock_drive_items.patch.call_count == 0
assert "Backup operation failed" in caplog.text
@pytest.mark.parametrize(
("side_effect", "error"),
[
(APIError(response_status_code=500), "Backup operation failed"),
(TimeoutException("Timeout"), "Backup operation timed out"),
],
)
async def test_agents_upload_fails_after_max_retries(
hass_client: ClientSessionGenerator,
caplog: pytest.LogCaptureFixture,
mock_drive_items: MagicMock,
mock_config_entry: MockConfigEntry,
mock_adapter: MagicMock,
side_effect: Exception,
error: str,
) -> None:
"""Test agent upload backup."""
client = await hass_client()
test_backup = AgentBackup.from_dict(BACKUP_METADATA)
mock_adapter.send_async.side_effect = side_effect
with (
patch(
"homeassistant.components.backup.manager.BackupManager.async_get_backup",
) as fetch_backup,
patch(
"homeassistant.components.backup.manager.read_backup",
return_value=test_backup,
),
patch("pathlib.Path.open") as mocked_open,
patch("homeassistant.components.onedrive.backup.UPLOAD_CHUNK_SIZE", 3),
):
mocked_open.return_value.read = Mock(side_effect=[b"test", b""])
fetch_backup.return_value = test_backup
resp = await client.post(
f"/api/backup/upload?agent_id={DOMAIN}.{mock_config_entry.unique_id}",
data={"file": StringIO("test")},
)
assert resp.status == 201
assert mock_adapter.send_async.call_count == 6
assert f"Uploading backup {test_backup.backup_id}" in caplog.text
assert mock_drive_items.patch.call_count == 0
assert error in caplog.text
mock_large_file_upload_client.assert_called_once()
assert mock_onedrive_client.update_drive_item.call_count == 0
assert "Hash validation failed, backup file might be corrupt" in caplog.text
async def test_agents_download(
hass_client: ClientSessionGenerator,
mock_drive_items: MagicMock,
mock_onedrive_client: MagicMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test agent download backup."""
mock_drive_items.get = AsyncMock(
return_value=DriveItem(description=escape(dumps(BACKUP_METADATA)))
)
client = await hass_client()
backup_id = BACKUP_METADATA["backup_id"]
@ -408,29 +237,30 @@ async def test_agents_download(
)
assert resp.status == 200
assert await resp.content.read() == b"backup data"
mock_drive_items.content.get.assert_called_once()
@pytest.mark.parametrize(
("side_effect", "error"),
[
(
APIError(response_status_code=500),
OneDriveException(),
"Backup operation failed",
),
(TimeoutException("Timeout"), "Backup operation timed out"),
(TimeoutError(), "Backup operation timed out"),
],
)
async def test_delete_error(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_drive_items: MagicMock,
mock_onedrive_client: MagicMock,
mock_config_entry: MockConfigEntry,
side_effect: Exception,
error: str,
) -> None:
"""Test error during delete."""
mock_drive_items.delete = AsyncMock(side_effect=side_effect)
mock_onedrive_client.delete_drive_item.side_effect = AsyncMock(
side_effect=side_effect
)
client = await hass_ws_client(hass)
@ -448,14 +278,35 @@ async def test_delete_error(
}
async def test_agents_delete_not_found_does_not_throw(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_onedrive_client: MagicMock,
) -> None:
"""Test agent delete backup."""
mock_onedrive_client.list_drive_items.return_value = []
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "backup/delete",
"backup_id": BACKUP_METADATA["backup_id"],
}
)
response = await client.receive_json()
assert response["success"]
assert response["result"] == {"agent_errors": {}}
async def test_agents_backup_not_found(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_drive_items: MagicMock,
mock_onedrive_client: MagicMock,
) -> None:
"""Test backup not found."""
mock_drive_items.children.get = AsyncMock(return_value=[])
mock_onedrive_client.list_drive_items.return_value = []
backup_id = BACKUP_METADATA["backup_id"]
client = await hass_ws_client(hass)
await client.send_json_auto_id({"type": "backup/details", "backup_id": backup_id})
@ -468,13 +319,13 @@ async def test_agents_backup_not_found(
async def test_reauth_on_403(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_drive_items: MagicMock,
mock_onedrive_client: MagicMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test we re-authenticate on 403."""
mock_drive_items.children.get = AsyncMock(
side_effect=APIError(response_status_code=403)
mock_onedrive_client.list_drive_items.side_effect = AuthenticationError(
403, "Auth failed"
)
backup_id = BACKUP_METADATA["backup_id"]
client = await hass_ws_client(hass)
@ -483,7 +334,7 @@ async def test_reauth_on_403(
assert response["success"]
assert response["result"]["agent_errors"] == {
f"{DOMAIN}.{mock_config_entry.unique_id}": "Backup operation failed"
f"{DOMAIN}.{mock_config_entry.unique_id}": "Authentication error"
}
await hass.async_block_till_done()

View File

@ -3,8 +3,7 @@
from http import HTTPStatus
from unittest.mock import AsyncMock, MagicMock
from httpx import Response
from kiota_abstractions.api_error import APIError
from onedrive_personal_sdk.exceptions import OneDriveException
import pytest
from homeassistant import config_entries
@ -20,7 +19,7 @@ from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers import config_entry_oauth2_flow
from . import setup_integration
from .const import CLIENT_ID
from .const import CLIENT_ID, MOCK_APPROOT
from tests.common import MockConfigEntry
from tests.test_util.aiohttp import AiohttpClientMocker
@ -89,25 +88,52 @@ async def test_full_flow(
assert result["data"][CONF_TOKEN]["refresh_token"] == "mock-refresh-token"
@pytest.mark.usefixtures("current_request_with_host")
async def test_full_flow_with_owner_not_found(
hass: HomeAssistant,
hass_client_no_auth: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
mock_setup_entry: AsyncMock,
mock_onedrive_client: MagicMock,
) -> None:
"""Ensure we get a default title if the drive's owner can't be read."""
mock_onedrive_client.get_approot.return_value.created_by.user = None
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
await _do_get_token(hass, result, hass_client_no_auth, aioclient_mock)
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] is FlowResultType.CREATE_ENTRY
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
assert len(mock_setup_entry.mock_calls) == 1
assert result["title"] == "OneDrive"
assert result["result"].unique_id == "mock_drive_id"
assert result["data"][CONF_TOKEN][CONF_ACCESS_TOKEN] == "mock-access-token"
assert result["data"][CONF_TOKEN]["refresh_token"] == "mock-refresh-token"
@pytest.mark.usefixtures("current_request_with_host")
@pytest.mark.parametrize(
("exception", "error"),
[
(Exception, "unknown"),
(APIError, "connection_error"),
(OneDriveException, "connection_error"),
],
)
async def test_flow_errors(
hass: HomeAssistant,
hass_client_no_auth: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
mock_adapter: MagicMock,
mock_onedrive_client: MagicMock,
exception: Exception,
error: str,
) -> None:
"""Test errors during flow."""
mock_adapter.get_http_response_message.side_effect = exception
mock_onedrive_client.get_approot.side_effect = exception
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
@ -172,15 +198,12 @@ async def test_reauth_flow_id_changed(
aioclient_mock: AiohttpClientMocker,
mock_setup_entry: AsyncMock,
mock_config_entry: MockConfigEntry,
mock_adapter: MagicMock,
mock_onedrive_client: MagicMock,
) -> None:
"""Test that the reauth flow fails on a different drive id."""
mock_adapter.get_http_response_message.return_value = Response(
status_code=200,
json={
"parentReference": {"driveId": "other_drive_id"},
},
)
app_root = MOCK_APPROOT
app_root.parent_reference.drive_id = "other_drive_id"
mock_onedrive_client.get_approot.return_value = app_root
await setup_integration(hass, mock_config_entry)

View File

@ -2,7 +2,7 @@
from unittest.mock import MagicMock
from kiota_abstractions.api_error import APIError
from onedrive_personal_sdk.exceptions import AuthenticationError, OneDriveException
import pytest
from homeassistant.config_entries import ConfigEntryState
@ -31,82 +31,31 @@ async def test_load_unload_config_entry(
@pytest.mark.parametrize(
("side_effect", "state"),
[
(APIError(response_status_code=403), ConfigEntryState.SETUP_ERROR),
(APIError(response_status_code=500), ConfigEntryState.SETUP_RETRY),
(AuthenticationError(403, "Auth failed"), ConfigEntryState.SETUP_ERROR),
(OneDriveException(), ConfigEntryState.SETUP_RETRY),
],
)
async def test_approot_errors(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_get_special_folder: MagicMock,
mock_onedrive_client: MagicMock,
side_effect: Exception,
state: ConfigEntryState,
) -> None:
"""Test errors during approot retrieval."""
mock_get_special_folder.side_effect = side_effect
mock_onedrive_client.get_approot.side_effect = side_effect
await setup_integration(hass, mock_config_entry)
assert mock_config_entry.state is state
async def test_faulty_approot(
async def test_get_integration_folder_error(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_get_special_folder: MagicMock,
mock_onedrive_client: MagicMock,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test faulty approot retrieval."""
mock_get_special_folder.return_value = None
await setup_integration(hass, mock_config_entry)
assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY
assert "Failed to get approot folder" in caplog.text
async def test_faulty_integration_folder(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_drive_items: MagicMock,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test faulty approot retrieval."""
mock_drive_items.get.return_value = None
mock_onedrive_client.create_folder.side_effect = OneDriveException()
await setup_integration(hass, mock_config_entry)
assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY
assert "Failed to get backups_9f86d081 folder" in caplog.text
async def test_500_error_during_backup_folder_get(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_drive_items: MagicMock,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test error during backup folder creation."""
mock_drive_items.get.side_effect = APIError(response_status_code=500)
await setup_integration(hass, mock_config_entry)
assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY
assert "Failed to get backups_9f86d081 folder" in caplog.text
async def test_error_during_backup_folder_creation(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_drive_items: MagicMock,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test error during backup folder creation."""
mock_drive_items.get.side_effect = APIError(response_status_code=404)
mock_drive_items.children.post.side_effect = APIError()
await setup_integration(hass, mock_config_entry)
assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY
assert "Failed to create backups_9f86d081 folder" in caplog.text
async def test_successful_backup_folder_creation(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_drive_items: MagicMock,
) -> None:
"""Test successful backup folder creation."""
mock_drive_items.get.side_effect = APIError(response_status_code=404)
await setup_integration(hass, mock_config_entry)
assert mock_config_entry.state is ConfigEntryState.LOADED

View File

@ -12,12 +12,14 @@ from homeassistant.components.openai_conversation.const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_REASONING_EFFORT,
CONF_RECOMMENDED,
CONF_TEMPERATURE,
CONF_TOP_P,
DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_REASONING_EFFORT,
RECOMMENDED_TOP_P,
)
from homeassistant.const import CONF_LLM_HASS_API
@ -88,6 +90,27 @@ async def test_options(
assert options["data"][CONF_CHAT_MODEL] == RECOMMENDED_CHAT_MODEL
async def test_options_unsupported_model(
hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None:
"""Test the options form giving error about models not supported."""
options_flow = await hass.config_entries.options.async_init(
mock_config_entry.entry_id
)
result = await hass.config_entries.options.async_configure(
options_flow["flow_id"],
{
CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate",
CONF_CHAT_MODEL: "o1-mini",
CONF_LLM_HASS_API: "assist",
},
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.FORM
assert result["errors"] == {"chat_model": "model_not_supported"}
@pytest.mark.parametrize(
("side_effect", "error"),
[
@ -148,6 +171,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TOP_P: RECOMMENDED_TOP_P,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
},
),
(
@ -158,6 +182,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TOP_P: RECOMMENDED_TOP_P,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
},
{
CONF_RECOMMENDED: True,

View File

@ -195,6 +195,7 @@ async def test_function_call(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="call_AbCdEfGhIjKlMnOpQrStUvWx",
tool_name="test_tool",
tool_args={"param1": "test_value"},
),
@ -359,6 +360,7 @@ async def test_function_exception(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="call_AbCdEfGhIjKlMnOpQrStUvWx",
tool_name="test_tool",
tool_args={"param1": "test_value"},
),

View File

@ -0,0 +1,59 @@
# serializer version: 1
# name: test_service_get[schedule.from_storage-get-after-update]
dict({
'friday': list([
]),
'monday': list([
]),
'saturday': list([
]),
'sunday': list([
]),
'thursday': list([
]),
'tuesday': list([
]),
'wednesday': list([
dict({
'from': datetime.time(17, 0),
'to': datetime.time(19, 0),
}),
]),
})
# ---
# name: test_service_get[schedule.from_storage-get]
dict({
'friday': list([
dict({
'data': dict({
'party_level': 'epic',
}),
'from': datetime.time(17, 0),
'to': datetime.time(23, 59, 59),
}),
]),
'monday': list([
]),
'saturday': list([
dict({
'from': datetime.time(0, 0),
'to': datetime.time(23, 59, 59),
}),
]),
'sunday': list([
dict({
'data': dict({
'entry': 'VIPs only',
}),
'from': datetime.time(0, 0),
'to': datetime.time(23, 59, 59, 999999),
}),
]),
'thursday': list([
]),
'tuesday': list([
]),
'wednesday': list([
]),
})
# ---

View File

@ -8,10 +8,12 @@ from unittest.mock import patch
from freezegun.api import FrozenDateTimeFactory
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components.schedule import STORAGE_VERSION, STORAGE_VERSION_MINOR
from homeassistant.components.schedule.const import (
ATTR_NEXT_EVENT,
CONF_ALL_DAYS,
CONF_DATA,
CONF_FRIDAY,
CONF_FROM,
@ -23,12 +25,14 @@ from homeassistant.components.schedule.const import (
CONF_TUESDAY,
CONF_WEDNESDAY,
DOMAIN,
SERVICE_GET,
)
from homeassistant.const import (
ATTR_EDITABLE,
ATTR_FRIENDLY_NAME,
ATTR_ICON,
ATTR_NAME,
CONF_ENTITY_ID,
CONF_ICON,
CONF_ID,
CONF_NAME,
@ -754,3 +758,66 @@ async def test_ws_create(
assert result["party_mode"][CONF_MONDAY] == [
{CONF_FROM: "12:00:00", CONF_TO: saved_to}
]
async def test_service_get(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
schedule_setup: Callable[..., Coroutine[Any, Any, bool]],
) -> None:
"""Test getting a single schedule via service."""
assert await schedule_setup()
entity_id = "schedule.from_storage"
# Test retrieving a single schedule via service call
service_result = await hass.services.async_call(
DOMAIN,
SERVICE_GET,
{
CONF_ENTITY_ID: entity_id,
},
blocking=True,
return_response=True,
)
result = service_result.get(entity_id)
assert set(result) == CONF_ALL_DAYS
assert result == snapshot(name=f"{entity_id}-get")
# Now we update the schedule via WS
client = await hass_ws_client(hass)
await client.send_json(
{
"id": 1,
"type": f"{DOMAIN}/update",
f"{DOMAIN}_id": entity_id.rsplit(".", maxsplit=1)[-1],
CONF_NAME: "Party pooper",
CONF_ICON: "mdi:party-pooper",
CONF_MONDAY: [],
CONF_TUESDAY: [],
CONF_WEDNESDAY: [{CONF_FROM: "17:00:00", CONF_TO: "19:00:00"}],
CONF_THURSDAY: [],
CONF_FRIDAY: [],
CONF_SATURDAY: [],
CONF_SUNDAY: [],
}
)
resp = await client.receive_json()
assert resp["success"]
# Test retrieving the schedule via service call after WS update
service_result = await hass.services.async_call(
DOMAIN,
SERVICE_GET,
{
CONF_ENTITY_ID: entity_id,
},
blocking=True,
return_response=True,
)
result = service_result.get(entity_id)
assert set(result) == CONF_ALL_DAYS
assert result == snapshot(name=f"{entity_id}-get-after-update")

View File

@ -180,6 +180,7 @@ MOCK_CONFIG = {
"xcounts": {"expr": None, "unit": None},
"xfreq": {"expr": None, "unit": None},
},
"flood:0": {"id": 0, "name": "Test name"},
"light:0": {"name": "test light_0"},
"light:1": {"name": "test light_1"},
"light:2": {"name": "test light_2"},
@ -326,6 +327,7 @@ MOCK_STATUS_RPC = {
"em1:1": {"act_power": 123.3},
"em1data:0": {"total_act_energy": 123456.4},
"em1data:1": {"total_act_energy": 987654.3},
"flood:0": {"id": 0, "alarm": False, "mute": False},
"thermostat:0": {
"id": 0,
"enable": True,

View File

@ -46,3 +46,96 @@
'state': 'off',
})
# ---
# name: test_rpc_flood_entities[binary_sensor.test_name_flood-entry]
EntityRegistryEntrySnapshot({
'aliases': set({
}),
'area_id': None,
'capabilities': None,
'config_entry_id': <ANY>,
'device_class': None,
'device_id': <ANY>,
'disabled_by': None,
'domain': 'binary_sensor',
'entity_category': None,
'entity_id': 'binary_sensor.test_name_flood',
'has_entity_name': False,
'hidden_by': None,
'icon': None,
'id': <ANY>,
'labels': set({
}),
'name': None,
'options': dict({
}),
'original_device_class': <BinarySensorDeviceClass.MOISTURE: 'moisture'>,
'original_icon': None,
'original_name': 'Test name flood',
'platform': 'shelly',
'previous_unique_id': None,
'supported_features': 0,
'translation_key': None,
'unique_id': '123456789ABC-flood:0-flood',
'unit_of_measurement': None,
})
# ---
# name: test_rpc_flood_entities[binary_sensor.test_name_flood-state]
StateSnapshot({
'attributes': ReadOnlyDict({
'device_class': 'moisture',
'friendly_name': 'Test name flood',
}),
'context': <ANY>,
'entity_id': 'binary_sensor.test_name_flood',
'last_changed': <ANY>,
'last_reported': <ANY>,
'last_updated': <ANY>,
'state': 'off',
})
# ---
# name: test_rpc_flood_entities[binary_sensor.test_name_mute-entry]
EntityRegistryEntrySnapshot({
'aliases': set({
}),
'area_id': None,
'capabilities': None,
'config_entry_id': <ANY>,
'device_class': None,
'device_id': <ANY>,
'disabled_by': None,
'domain': 'binary_sensor',
'entity_category': <EntityCategory.DIAGNOSTIC: 'diagnostic'>,
'entity_id': 'binary_sensor.test_name_mute',
'has_entity_name': False,
'hidden_by': None,
'icon': None,
'id': <ANY>,
'labels': set({
}),
'name': None,
'options': dict({
}),
'original_device_class': None,
'original_icon': None,
'original_name': 'Test name mute',
'platform': 'shelly',
'previous_unique_id': None,
'supported_features': 0,
'translation_key': None,
'unique_id': '123456789ABC-flood:0-mute',
'unit_of_measurement': None,
})
# ---
# name: test_rpc_flood_entities[binary_sensor.test_name_mute-state]
StateSnapshot({
'attributes': ReadOnlyDict({
'friendly_name': 'Test name mute',
}),
'context': <ANY>,
'entity_id': 'binary_sensor.test_name_mute',
'last_changed': <ANY>,
'last_reported': <ANY>,
'last_updated': <ANY>,
'state': 'off',
})
# ---

View File

@ -496,3 +496,22 @@ async def test_blu_trv_binary_sensor_entity(
entry = entity_registry.async_get(entity_id)
assert entry == snapshot(name=f"{entity_id}-entry")
async def test_rpc_flood_entities(
hass: HomeAssistant,
mock_rpc_device: Mock,
entity_registry: EntityRegistry,
snapshot: SnapshotAssertion,
) -> None:
"""Test RPC flood sensor entities."""
await init_integration(hass, 4)
for entity in ("flood", "mute"):
entity_id = f"{BINARY_SENSOR_DOMAIN}.test_name_{entity}"
state = hass.states.get(entity_id)
assert state == snapshot(name=f"{entity_id}-state")
entry = entity_registry.async_get(entity_id)
assert entry == snapshot(name=f"{entity_id}-entry")

View File

@ -18,7 +18,8 @@ from tests.common import (
load_json_object_fixture,
)
MOCK_HOST = "slzb-06.local"
MOCK_DEVICE_NAME = "slzb-06"
MOCK_HOST = "192.168.1.161"
MOCK_USERNAME = "test-user"
MOCK_PASSWORD = "test-pass"

View File

@ -3,7 +3,7 @@
DeviceRegistryEntrySnapshot({
'area_id': None,
'config_entries': <ANY>,
'configuration_url': 'http://slzb-06.local',
'configuration_url': 'http://192.168.1.161',
'connections': set({
tuple(
'mac',

View File

@ -8,19 +8,20 @@ from pysmlight.exceptions import SmlightAuthError, SmlightConnectionError
import pytest
from homeassistant.components.smlight.const import DOMAIN
from homeassistant.config_entries import SOURCE_USER, SOURCE_ZEROCONF
from homeassistant.config_entries import SOURCE_DHCP, SOURCE_USER, SOURCE_ZEROCONF
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers.service_info.dhcp import DhcpServiceInfo
from homeassistant.helpers.service_info.zeroconf import ZeroconfServiceInfo
from .conftest import MOCK_HOST, MOCK_PASSWORD, MOCK_USERNAME
from .conftest import MOCK_DEVICE_NAME, MOCK_HOST, MOCK_PASSWORD, MOCK_USERNAME
from tests.common import MockConfigEntry
DISCOVERY_INFO = ZeroconfServiceInfo(
ip_address=ip_address("127.0.0.1"),
ip_addresses=[ip_address("127.0.0.1")],
ip_address=ip_address("192.168.1.161"),
ip_addresses=[ip_address("192.168.1.161")],
hostname="slzb-06.local.",
name="mock_name",
port=6638,
@ -29,8 +30,8 @@ DISCOVERY_INFO = ZeroconfServiceInfo(
)
DISCOVERY_INFO_LEGACY = ZeroconfServiceInfo(
ip_address=ip_address("127.0.0.1"),
ip_addresses=[ip_address("127.0.0.1")],
ip_address=ip_address("192.168.1.161"),
ip_addresses=[ip_address("192.168.1.161")],
hostname="slzb-06.local.",
name="mock_name",
port=6638,
@ -52,7 +53,7 @@ async def test_user_flow(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> No
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_HOST: MOCK_HOST,
CONF_HOST: "slzb-06p7.local",
},
)
@ -76,7 +77,7 @@ async def test_zeroconf_flow(
DOMAIN, context={"source": SOURCE_ZEROCONF}, data=DISCOVERY_INFO
)
assert result["description_placeholders"] == {"host": MOCK_HOST}
assert result["description_placeholders"] == {"host": MOCK_DEVICE_NAME}
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "confirm_discovery"
@ -113,7 +114,7 @@ async def test_zeroconf_flow_auth(
DOMAIN, context={"source": SOURCE_ZEROCONF}, data=DISCOVERY_INFO
)
assert result["description_placeholders"] == {"host": MOCK_HOST}
assert result["description_placeholders"] == {"host": MOCK_DEVICE_NAME}
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "confirm_discovery"
@ -167,7 +168,7 @@ async def test_zeroconf_unsupported_abort(
DOMAIN, context={"source": SOURCE_ZEROCONF}, data=DISCOVERY_INFO
)
assert result["description_placeholders"] == {"host": MOCK_HOST}
assert result["description_placeholders"] == {"host": MOCK_DEVICE_NAME}
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "confirm_discovery"
@ -489,7 +490,7 @@ async def test_zeroconf_legacy_mac(
data=DISCOVERY_INFO_LEGACY,
)
assert result["description_placeholders"] == {"host": MOCK_HOST}
assert result["description_placeholders"] == {"host": MOCK_DEVICE_NAME}
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input={}
@ -507,6 +508,76 @@ async def test_zeroconf_legacy_mac(
assert len(mock_smlight_client.get_info.mock_calls) == 3
@pytest.mark.usefixtures("mock_smlight_client")
async def test_zeroconf_updates_host(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test zeroconf discovery updates host ip."""
mock_config_entry.add_to_hass(hass)
service_info = DISCOVERY_INFO
service_info.ip_address = ip_address("192.168.1.164")
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_ZEROCONF}, data=service_info
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "already_configured"
assert mock_config_entry.data[CONF_HOST] == "192.168.1.164"
@pytest.mark.usefixtures("mock_smlight_client")
async def test_dhcp_discovery_updates_host(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test dhcp discovery updates host ip."""
mock_config_entry.add_to_hass(hass)
service_info = DhcpServiceInfo(
ip="192.168.1.164",
hostname="slzb-06",
macaddress="aabbccddeeff",
)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_DHCP}, data=service_info
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "already_configured"
assert mock_config_entry.data[CONF_HOST] == "192.168.1.164"
@pytest.mark.usefixtures("mock_smlight_client")
async def test_dhcp_discovery_aborts(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test dhcp discovery updates host ip."""
mock_config_entry.add_to_hass(hass)
service_info = DhcpServiceInfo(
ip="192.168.1.161",
hostname="slzb-06",
macaddress="000000000000",
)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_DHCP}, data=service_info
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "already_configured"
assert mock_config_entry.data[CONF_HOST] == "192.168.1.161"
async def test_reauth_flow(
hass: HomeAssistant,
mock_smlight_client: MagicMock,

View File

@ -153,3 +153,25 @@ async def humidifier_config_entry(
await hass.async_block_till_done()
return entry
@pytest.fixture(name="switch_old_id_config_entry")
async def switch_old_id_config_entry(
hass: HomeAssistant, requests_mock: requests_mock.Mocker, config
) -> MockConfigEntry:
"""Create a mock VeSync config entry for `switch` with the old unique ID approach."""
entry = MockConfigEntry(
title="VeSync",
domain=DOMAIN,
data=config[DOMAIN],
version=1,
minor_version=1,
)
entry.add_to_hass(hass)
wall_switch = "Wall Switch"
humidifer = "Humidifier 200s"
mock_multiple_device_responses(requests_mock, [wall_switch, humidifer])
return entry

View File

@ -128,7 +128,7 @@
'sleep',
'manual',
]),
'mode': None,
'mode': 'humidity',
'night_light': True,
'pid': None,
'speed': None,
@ -160,6 +160,30 @@
# ---
# name: test_async_get_device_diagnostics__single_fan
dict({
'_config_dict': dict({
'features': list([
'air_quality',
]),
'levels': list([
1,
2,
]),
'models': list([
'LV-PUR131S',
'LV-RH131S',
]),
'modes': list([
'manual',
'auto',
'sleep',
'off',
]),
'module': 'VeSyncAir131',
}),
'_features': list([
'air_quality',
]),
'air_quality_feature': True,
'cid': 'abcdefghabcdefghabcdefghabcdefgh',
'config': dict({
}),
@ -180,6 +204,7 @@
'device_region': 'US',
'device_status': 'unknown',
'device_type': 'LV-PUR131S',
'enabled': True,
'extension': None,
'home_assistant': dict({
'disabled': False,
@ -271,6 +296,12 @@
'mac_id': '**REDACTED**',
'manager': '**REDACTED**',
'mode': None,
'modes': list([
'manual',
'auto',
'sleep',
'off',
]),
'pid': None,
'speed': None,
'sub_device_no': None,

View File

@ -367,7 +367,7 @@
'previous_unique_id': None,
'supported_features': 0,
'translation_key': None,
'unique_id': 'outlet',
'unique_id': 'outlet-device_status',
'unit_of_measurement': None,
}),
])
@ -525,7 +525,7 @@
'previous_unique_id': None,
'supported_features': 0,
'translation_key': None,
'unique_id': 'switch',
'unique_id': 'switch-device_status',
'unit_of_measurement': None,
}),
])

View File

@ -10,6 +10,9 @@ from homeassistant.components.vesync.const import DOMAIN, VS_DEVICES, VS_MANAGER
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from tests.common import MockConfigEntry
async def test_async_setup_entry__not_login(
@ -125,3 +128,55 @@ async def test_async_new_device_discovery(
assert manager.login.call_count == 1
assert hass.data[DOMAIN][VS_MANAGER] == manager
assert hass.data[DOMAIN][VS_DEVICES] == [fan, humidifier]
async def test_migrate_config_entry(
hass: HomeAssistant,
switch_old_id_config_entry: MockConfigEntry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test migration of config entry. Only migrates switches to a new unique_id."""
switch: er.RegistryEntry = entity_registry.async_get_or_create(
domain="switch",
platform="vesync",
unique_id="switch",
config_entry=switch_old_id_config_entry,
suggested_object_id="switch",
)
humidifer: er.RegistryEntry = entity_registry.async_get_or_create(
domain="humidifer",
platform="vesync",
unique_id="humidifer",
config_entry=switch_old_id_config_entry,
suggested_object_id="humidifer",
)
assert switch.unique_id == "switch"
assert switch_old_id_config_entry.minor_version == 1
assert humidifer.unique_id == "humidifer"
await hass.config_entries.async_setup(switch_old_id_config_entry.entry_id)
await hass.async_block_till_done()
assert switch_old_id_config_entry.minor_version == 2
migrated_switch = entity_registry.async_get(switch.entity_id)
assert migrated_switch is not None
assert migrated_switch.entity_id.startswith("switch")
assert migrated_switch.unique_id == "switch-device_status"
# Confirm humidifer was not impacted
migrated_humidifer = entity_registry.async_get(humidifer.entity_id)
assert migrated_humidifer is not None
assert migrated_humidifer.unique_id == "humidifer"
# Assert that only one entity exists in the switch domain
switch_entities = [
e for e in entity_registry.entities.values() if e.domain == "switch"
]
assert len(switch_entities) == 1
humidifer_entities = [
e for e in entity_registry.entities.values() if e.domain == "humidifer"
]
assert len(humidifer_entities) == 1