From 72fed878b4278643f78df6a4bac7886681b7c0ac Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Tue, 26 Mar 2024 16:15:20 -0500 Subject: [PATCH] Add Ollama conversation agent (#113962) * Add ollama conversation agent * Change iot class * Much better default template * Slight adjustment to prompt * Make casing consistent * Switch to ollama Python fork * Add prompt to tests * Rename to "ollama" * Download models in config flow * Update homeassistant/components/ollama/config_flow.py --------- Co-authored-by: Paulus Schoutsen --- CODEOWNERS | 2 + homeassistant/components/ollama/__init__.py | 266 +++++++++++++ .../components/ollama/config_flow.py | 245 ++++++++++++ homeassistant/components/ollama/const.py | 114 ++++++ homeassistant/components/ollama/manifest.json | 11 + homeassistant/components/ollama/models.py | 47 +++ homeassistant/components/ollama/strings.json | 33 ++ homeassistant/generated/config_flows.py | 1 + homeassistant/generated/integrations.json | 6 + requirements_all.txt | 3 + requirements_test_all.txt | 3 + tests/components/ollama/__init__.py | 14 + tests/components/ollama/conftest.py | 37 ++ tests/components/ollama/test_config_flow.py | 234 +++++++++++ tests/components/ollama/test_init.py | 366 ++++++++++++++++++ 15 files changed, 1382 insertions(+) create mode 100644 homeassistant/components/ollama/__init__.py create mode 100644 homeassistant/components/ollama/config_flow.py create mode 100644 homeassistant/components/ollama/const.py create mode 100644 homeassistant/components/ollama/manifest.json create mode 100644 homeassistant/components/ollama/models.py create mode 100644 homeassistant/components/ollama/strings.json create mode 100644 tests/components/ollama/__init__.py create mode 100644 tests/components/ollama/conftest.py create mode 100644 tests/components/ollama/test_config_flow.py create mode 100644 tests/components/ollama/test_init.py diff --git a/CODEOWNERS b/CODEOWNERS index 7ba24210f96..85603250b7c 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -933,6 +933,8 @@ build.json @home-assistant/supervisor /homeassistant/components/octoprint/ @rfleming71 /tests/components/octoprint/ @rfleming71 /homeassistant/components/ohmconnect/ @robbiet480 +/homeassistant/components/ollama/ @synesthesiam +/tests/components/ollama/ @synesthesiam /homeassistant/components/ombi/ @larssont /homeassistant/components/omnilogic/ @oliver84 @djtimca @gentoosu /tests/components/omnilogic/ @oliver84 @djtimca @gentoosu diff --git a/homeassistant/components/ollama/__init__.py b/homeassistant/components/ollama/__init__.py new file mode 100644 index 00000000000..8c9b00f3c9c --- /dev/null +++ b/homeassistant/components/ollama/__init__.py @@ -0,0 +1,266 @@ +"""The Ollama integration.""" + +from __future__ import annotations + +import asyncio +import logging +import time +from typing import Literal + +import httpx +import ollama + +from homeassistant.components import conversation +from homeassistant.components.homeassistant.exposed_entities import async_should_expose +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import CONF_URL, MATCH_ALL +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ConfigEntryNotReady, TemplateError +from homeassistant.helpers import ( + area_registry as ar, + config_validation as cv, + device_registry as dr, + entity_registry as er, + intent, + template, +) +from homeassistant.util import ulid + +from .const import ( + CONF_MAX_HISTORY, + CONF_MODEL, + CONF_PROMPT, + DEFAULT_MAX_HISTORY, + DEFAULT_PROMPT, + 
DEFAULT_TIMEOUT,
+    DOMAIN,
+    KEEP_ALIVE_FOREVER,
+    MAX_HISTORY_SECONDS,
+)
+from .models import ExposedEntity, MessageHistory, MessageRole
+
+_LOGGER = logging.getLogger(__name__)
+
+__all__ = [
+    "CONF_URL",
+    "CONF_PROMPT",
+    "CONF_MODEL",
+    "CONF_MAX_HISTORY",
+    "MAX_HISTORY_SECONDS",
+    "DOMAIN",
+]
+
+CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
+
+
+async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
+    """Set up Ollama from a config entry."""
+    settings = {**entry.data, **entry.options}
+    client = ollama.AsyncClient(host=settings[CONF_URL])
+    try:
+        async with asyncio.timeout(DEFAULT_TIMEOUT):
+            await client.list()
+    except (TimeoutError, httpx.ConnectError) as err:
+        raise ConfigEntryNotReady(err) from err
+
+    hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client
+
+    conversation.async_set_agent(hass, entry, OllamaAgent(hass, entry))
+    return True
+
+
+async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
+    """Unload Ollama."""
+    hass.data[DOMAIN].pop(entry.entry_id)
+    conversation.async_unset_agent(hass, entry)
+    return True
+
+
+class OllamaAgent(conversation.AbstractConversationAgent):
+    """Ollama conversation agent."""
+
+    def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
+        """Initialize the agent."""
+        self.hass = hass
+        self.entry = entry
+
+        # conversation id -> message history
+        self._history: dict[str, MessageHistory] = {}
+
+    @property
+    def supported_languages(self) -> list[str] | Literal["*"]:
+        """Return a list of supported languages."""
+        return MATCH_ALL
+
+    async def async_process(
+        self, user_input: conversation.ConversationInput
+    ) -> conversation.ConversationResult:
+        """Process a sentence."""
+        settings = {**self.entry.data, **self.entry.options}
+
+        client = self.hass.data[DOMAIN][self.entry.entry_id]
+        conversation_id = user_input.conversation_id or ulid.ulid_now()
+        model = settings[CONF_MODEL]
+
+        # Look up message history
+        message_history: MessageHistory | None = None
+        message_history = self._history.get(conversation_id)
+        if message_history is None:
+            # New history
+            #
+            # Render prompt and error out early if there's a problem
+            raw_prompt = settings.get(CONF_PROMPT, DEFAULT_PROMPT)
+            try:
+                prompt = self._generate_prompt(raw_prompt)
+                _LOGGER.debug("Prompt: %s", prompt)
+            except TemplateError as err:
+                _LOGGER.error("Error rendering prompt: %s", err)
+                intent_response = intent.IntentResponse(language=user_input.language)
+                intent_response.async_set_error(
+                    intent.IntentResponseErrorCode.UNKNOWN,
+                    f"Sorry, I had a problem generating my prompt: {err}",
+                )
+                return conversation.ConversationResult(
+                    response=intent_response, conversation_id=conversation_id
+                )
+
+            message_history = MessageHistory(
+                timestamp=time.monotonic(),
+                messages=[
+                    ollama.Message(role=MessageRole.SYSTEM.value, content=prompt)
+                ],
+            )
+            self._history[conversation_id] = message_history
+        else:
+            # Bump timestamp so this conversation won't get cleaned up
+            message_history.timestamp = time.monotonic()
+
+        # Clean up old histories
+        self._prune_old_histories()
+
+        # Trim this message history to keep a maximum number of *user* messages
+        max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
+        self._trim_history(message_history, max_messages)
+
+        # Add new user message
+        message_history.messages.append(
+            ollama.Message(role=MessageRole.USER.value, content=user_input.text)
+        )
+
+        # Get response
+        try:
+            response = await client.chat(
+                model=model,
+                # Make a copy of the
messages because we mutate the list later + messages=list(message_history.messages), + stream=False, + keep_alive=KEEP_ALIVE_FOREVER, + ) + except (ollama.RequestError, ollama.ResponseError) as err: + _LOGGER.error("Unexpected error talking to Ollama server: %s", err) + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"Sorry, I had a problem talking to the Ollama server: {err}", + ) + return conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + + response_message = response["message"] + message_history.messages.append( + ollama.Message( + role=response_message["role"], content=response_message["content"] + ) + ) + + # Create intent response + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_speech(response_message["content"]) + return conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + + def _prune_old_histories(self) -> None: + """Remove old message histories.""" + now = time.monotonic() + self._history = { + conversation_id: message_history + for conversation_id, message_history in self._history.items() + if (now - message_history.timestamp) <= MAX_HISTORY_SECONDS + } + + def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None: + """Trims excess messages from a single history.""" + if max_messages < 1: + # Keep all messages + return + + if message_history.num_user_messages >= max_messages: + # Trim history but keep system prompt (first message). + # Every other message should be an assistant message, so keep 2x + # message objects. + num_keep = 2 * max_messages + drop_index = len(message_history.messages) - num_keep + message_history.messages = [ + message_history.messages[0] + ] + message_history.messages[drop_index:] + + def _generate_prompt(self, raw_prompt: str) -> str: + """Generate a prompt for the user.""" + return template.Template(raw_prompt, self.hass).async_render( + { + "ha_name": self.hass.config.location_name, + "ha_language": self.hass.config.language, + "exposed_entities": self._get_exposed_entities(), + }, + parse_result=False, + ) + + def _get_exposed_entities(self) -> list[ExposedEntity]: + """Get state list of exposed entities.""" + area_registry = ar.async_get(self.hass) + entity_registry = er.async_get(self.hass) + device_registry = dr.async_get(self.hass) + + exposed_entities = [] + exposed_states = [ + state + for state in self.hass.states.async_all() + if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id) + ] + + for state in exposed_states: + entity = entity_registry.async_get(state.entity_id) + names = [state.name] + area_names = [] + + if entity is not None: + # Add aliases + names.extend(entity.aliases) + if entity.area_id and ( + area := area_registry.async_get_area(entity.area_id) + ): + # Entity is in area + area_names.append(area.name) + area_names.extend(area.aliases) + elif entity.device_id and ( + device := device_registry.async_get(entity.device_id) + ): + # Check device area + if device.area_id and ( + area := area_registry.async_get_area(device.area_id) + ): + area_names.append(area.name) + area_names.extend(area.aliases) + + exposed_entities.append( + ExposedEntity( + entity_id=state.entity_id, + state=state, + names=names, + area_names=area_names, + ) + ) + + return exposed_entities diff --git a/homeassistant/components/ollama/config_flow.py 
b/homeassistant/components/ollama/config_flow.py new file mode 100644 index 00000000000..50d0667803f --- /dev/null +++ b/homeassistant/components/ollama/config_flow.py @@ -0,0 +1,245 @@ +"""Config flow for Ollama integration.""" + +from __future__ import annotations + +import asyncio +import logging +import sys +from types import MappingProxyType +from typing import Any + +import httpx +import ollama +import voluptuous as vol + +from homeassistant.config_entries import ( + ConfigEntry, + ConfigFlow, + ConfigFlowResult, + OptionsFlow, +) +from homeassistant.const import CONF_URL +from homeassistant.helpers.selector import ( + NumberSelector, + NumberSelectorConfig, + NumberSelectorMode, + SelectOptionDict, + SelectSelector, + SelectSelectorConfig, + TemplateSelector, + TextSelector, + TextSelectorConfig, + TextSelectorType, +) + +from .const import ( + CONF_MAX_HISTORY, + CONF_MODEL, + CONF_PROMPT, + DEFAULT_MAX_HISTORY, + DEFAULT_MODEL, + DEFAULT_PROMPT, + DEFAULT_TIMEOUT, + DOMAIN, + MODEL_NAMES, +) + +_LOGGER = logging.getLogger(__name__) + + +STEP_USER_DATA_SCHEMA = vol.Schema( + { + vol.Required(CONF_URL): TextSelector( + TextSelectorConfig(type=TextSelectorType.URL) + ), + } +) + + +class OllamaConfigFlow(ConfigFlow, domain=DOMAIN): + """Handle a config flow for Ollama.""" + + VERSION = 1 + + def __init__(self) -> None: + """Initialize config flow.""" + self.url: str | None = None + self.model: str | None = None + self.client: ollama.AsyncClient | None = None + self.download_task: asyncio.Task | None = None + + async def async_step_user( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Handle the initial step.""" + user_input = user_input or {} + self.url = user_input.get(CONF_URL, self.url) + self.model = user_input.get(CONF_MODEL, self.model) + + if self.url is None: + return self.async_show_form( + step_id="user", data_schema=STEP_USER_DATA_SCHEMA, last_step=False + ) + + errors = {} + + try: + self.client = ollama.AsyncClient(host=self.url) + async with asyncio.timeout(DEFAULT_TIMEOUT): + response = await self.client.list() + + downloaded_models: set[str] = { + model_info["model"] for model_info in response.get("models", []) + } + except (TimeoutError, httpx.ConnectError): + errors["base"] = "cannot_connect" + except Exception: # pylint: disable=broad-except + _LOGGER.exception("Unexpected exception") + errors["base"] = "unknown" + + if errors: + return self.async_show_form( + step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors + ) + + if self.model is None: + # Show models that have been downloaded first, followed by all known + # models (only latest tags). 
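+            # The server reports downloaded models with their tag (for example
+            # "llama2:latest"), while MODEL_NAMES holds bare names, so ":latest"
+            # is appended to the suggestions that still need to be downloaded.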
+ models_to_list = [ + SelectOptionDict(label=f"{m} (downloaded)", value=m) + for m in sorted(downloaded_models) + ] + [ + SelectOptionDict(label=m, value=f"{m}:latest") + for m in sorted(MODEL_NAMES) + if m not in downloaded_models + ] + model_step_schema = vol.Schema( + { + vol.Required( + CONF_MODEL, description={"suggested_value": DEFAULT_MODEL} + ): SelectSelector( + SelectSelectorConfig(options=models_to_list, custom_value=True) + ), + } + ) + + return self.async_show_form( + step_id="user", + data_schema=model_step_schema, + ) + + if self.model not in downloaded_models: + # Ollama server needs to download model first + return await self.async_step_download() + + return self.async_create_entry( + title=_get_title(self.model), + data={CONF_URL: self.url, CONF_MODEL: self.model}, + ) + + async def async_step_download( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Step to wait for Ollama server to download a model.""" + assert self.model is not None + assert self.client is not None + + if self.download_task is None: + # Tell Ollama server to pull the model. + # The task will block until the model and metadata are fully + # downloaded. + self.download_task = self.hass.async_create_background_task( + self.client.pull(self.model), f"Downloading {self.model}" + ) + + if self.download_task.done(): + if err := self.download_task.exception(): + _LOGGER.exception("Unexpected error while downloading model: %s", err) + return self.async_show_progress_done(next_step_id="failed") + + return self.async_show_progress_done(next_step_id="finish") + + return self.async_show_progress( + step_id="download", + progress_action="download", + progress_task=self.download_task, + ) + + async def async_step_finish( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Step after model downloading has succeeded.""" + assert self.url is not None + assert self.model is not None + + return self.async_create_entry( + title=_get_title(self.model), + data={CONF_URL: self.url, CONF_MODEL: self.model}, + ) + + async def async_step_failed( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Step after model downloading has failed.""" + return self.async_abort(reason="download_failed") + + @staticmethod + def async_get_options_flow( + config_entry: ConfigEntry, + ) -> OptionsFlow: + """Create the options flow.""" + return OllamaOptionsFlow(config_entry) + + +class OllamaOptionsFlow(OptionsFlow): + """Ollama options flow.""" + + def __init__(self, config_entry: ConfigEntry) -> None: + """Initialize options flow.""" + self.config_entry = config_entry + self.url: str = self.config_entry.data[CONF_URL] + self.model: str = self.config_entry.data[CONF_MODEL] + + async def async_step_init( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Manage the options.""" + if user_input is not None: + return self.async_create_entry( + title=_get_title(self.model), data=user_input + ) + + options = self.config_entry.options or MappingProxyType({}) + schema = ollama_config_option_schema(options) + return self.async_show_form( + step_id="init", + data_schema=vol.Schema(schema), + ) + + +def ollama_config_option_schema(options: MappingProxyType[str, Any]) -> dict: + """Ollama options schema.""" + return { + vol.Optional( + CONF_PROMPT, + description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)}, + ): TemplateSelector(), + vol.Optional( + CONF_MAX_HISTORY, + description={ + "suggested_value": 
options.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY) + }, + ): NumberSelector( + NumberSelectorConfig( + min=0, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX + ) + ), + } + + +def _get_title(model: str) -> str: + """Get title for config entry.""" + if model.endswith(":latest"): + model = model.split(":", maxsplit=1)[0] + + return model diff --git a/homeassistant/components/ollama/const.py b/homeassistant/components/ollama/const.py new file mode 100644 index 00000000000..59f1888cfc7 --- /dev/null +++ b/homeassistant/components/ollama/const.py @@ -0,0 +1,114 @@ +"""Constants for the Ollama integration.""" + +DOMAIN = "ollama" + +CONF_MODEL = "model" +CONF_PROMPT = "prompt" +DEFAULT_PROMPT = """{%- set used_domains = set([ + "binary_sensor", + "climate", + "cover", + "fan", + "light", + "lock", + "sensor", + "switch", + "weather", +]) %} +{%- set used_attributes = set([ + "temperature", + "current_temperature", + "temperature_unit", + "brightness", + "humidity", + "unit_of_measurement", + "device_class", + "current_position", + "percentage", +]) %} + +This smart home is controlled by Home Assistant. +The current time is {{ now().strftime("%X") }}. +Today's date is {{ now().strftime("%x") }}. + +An overview of the areas and the devices in this smart home: +```yaml +{%- for entity in exposed_entities: %} +{%- if entity.domain not in used_domains: %} + {%- continue %} +{%- endif %} + +- domain: {{ entity.domain }} +{%- if entity.names | length == 1: %} + name: {{ entity.names[0] }} +{%- else: %} + names: +{%- for name in entity.names: %} + - {{ name }} +{%- endfor %} +{%- endif %} +{%- if entity.area_names | length == 1: %} + area: {{ entity.area_names[0] }} +{%- elif entity.area_names: %} + areas: +{%- for area_name in entity.area_names: %} + - {{ area_name }} +{%- endfor %} +{%- endif %} + state: {{ entity.state.state }} + {%- set attributes_key_printed = False %} +{%- for attr_name, attr_value in entity.state.attributes.items(): %} + {%- if attr_name in used_attributes: %} + {%- if not attributes_key_printed: %} + attributes: + {%- set attributes_key_printed = True %} + {%- endif %} + {{ attr_name }}: {{ attr_value }} + {%- endif %} +{%- endfor %} +{%- endfor %} +``` + +Answer the user's questions using the information about this smart home. 
+Keep your answers brief and do not apologize."""
+
+KEEP_ALIVE_FOREVER = -1
+DEFAULT_TIMEOUT = 5.0  # seconds
+
+CONF_MAX_HISTORY = "max_history"
+DEFAULT_MAX_HISTORY = 20
+
+MAX_HISTORY_SECONDS = 60 * 60  # 1 hour
+
+MODEL_NAMES = [  # https://ollama.com/library
+    "gemma",
+    "llama2",
+    "mistral",
+    "mixtral",
+    "llava",
+    "neural-chat",
+    "codellama",
+    "dolphin-mixtral",
+    "qwen",
+    "llama2-uncensored",
+    "mistral-openorca",
+    "deepseek-coder",
+    "nous-hermes2",
+    "phi",
+    "orca-mini",
+    "dolphin-mistral",
+    "wizard-vicuna-uncensored",
+    "vicuna",
+    "tinydolphin",
+    "llama2-chinese",
+    "nomic-embed-text",
+    "openhermes",
+    "zephyr",
+    "tinyllama",
+    "openchat",
+    "wizardcoder",
+    "starcoder",
+    "phind-codellama",
+    "starcoder2",
+]
+DEFAULT_MODEL = "llama2:latest"
diff --git a/homeassistant/components/ollama/manifest.json b/homeassistant/components/ollama/manifest.json
new file mode 100644
index 00000000000..6b16ae667f1
--- /dev/null
+++ b/homeassistant/components/ollama/manifest.json
@@ -0,0 +1,11 @@
+{
+  "domain": "ollama",
+  "name": "Ollama",
+  "codeowners": ["@synesthesiam"],
+  "config_flow": true,
+  "dependencies": ["conversation"],
+  "documentation": "https://www.home-assistant.io/integrations/ollama",
+  "integration_type": "service",
+  "iot_class": "local_polling",
+  "requirements": ["ollama-hass==0.1.7"]
+}
diff --git a/homeassistant/components/ollama/models.py b/homeassistant/components/ollama/models.py
new file mode 100644
index 00000000000..ce0f858bb8c
--- /dev/null
+++ b/homeassistant/components/ollama/models.py
@@ -0,0 +1,47 @@
+"""Models for Ollama integration."""
+
+from dataclasses import dataclass
+from enum import StrEnum
+from functools import cached_property
+
+import ollama
+
+from homeassistant.core import State
+
+
+class MessageRole(StrEnum):
+    """Role of a chat message."""
+
+    SYSTEM = "system"  # prompt
+    USER = "user"
+
+
+@dataclass
+class MessageHistory:
+    """Chat message history."""
+
+    timestamp: float
+    """Timestamp of last use in seconds."""
+
+    messages: list[ollama.Message]
+    """List of message history, including system prompt and assistant responses."""
+
+    @property
+    def num_user_messages(self) -> int:
+        """Return a count of user messages."""
+        return sum(m["role"] == MessageRole.USER for m in self.messages)
+
+
+@dataclass(frozen=True)
+class ExposedEntity:
+    """Relevant information about an exposed entity."""
+
+    entity_id: str
+    state: State
+    names: list[str]
+    area_names: list[str]
+
+    @cached_property
+    def domain(self) -> str:
+        """Get domain from entity id."""
+        return self.entity_id.split(".", maxsplit=1)[0]
diff --git a/homeassistant/components/ollama/strings.json b/homeassistant/components/ollama/strings.json
new file mode 100644
index 00000000000..59f48929681
--- /dev/null
+++ b/homeassistant/components/ollama/strings.json
@@ -0,0 +1,35 @@
+{
+  "config": {
+    "step": {
+      "user": {
+        "data": {
+          "url": "[%key:common::config_flow::data::url%]",
+          "model": "Model"
+        }
+      },
+      "download": {
+        "title": "Downloading model"
+      }
+    },
+    "error": {
+      "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
+      "unknown": "[%key:common::config_flow::error::unknown%]"
+    },
+    "abort": {
+      "download_failed": "Model downloading failed"
+    },
+    "progress": {
+      "download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details."
+ } + }, + "options": { + "step": { + "init": { + "data": { + "prompt": "Prompt template", + "max_history": "Max history messages" + } + } + } + } +} diff --git a/homeassistant/generated/config_flows.py b/homeassistant/generated/config_flows.py index d779fbead64..8d46c8be240 100644 --- a/homeassistant/generated/config_flows.py +++ b/homeassistant/generated/config_flows.py @@ -360,6 +360,7 @@ FLOWS = { "nzbget", "obihai", "octoprint", + "ollama", "omnilogic", "oncue", "ondilo_ico", diff --git a/homeassistant/generated/integrations.json b/homeassistant/generated/integrations.json index 2b4a637dacc..6cba84431f3 100644 --- a/homeassistant/generated/integrations.json +++ b/homeassistant/generated/integrations.json @@ -4136,6 +4136,12 @@ "config_flow": false, "iot_class": "cloud_polling" }, + "ollama": { + "name": "Ollama", + "integration_type": "service", + "config_flow": true, + "iot_class": "local_polling" + }, "ombi": { "name": "Ombi", "integration_type": "hub", diff --git a/requirements_all.txt b/requirements_all.txt index 3487723b62a..4fef44d80b2 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -1436,6 +1436,9 @@ odp-amsterdam==6.0.1 # homeassistant.components.oem oemthermostat==1.1.1 +# homeassistant.components.ollama +ollama-hass==0.1.7 + # homeassistant.components.omnilogic omnilogic==0.4.5 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 2f8aaafee0a..75ba113891e 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -1148,6 +1148,9 @@ objgraph==3.5.0 # homeassistant.components.garages_amsterdam odp-amsterdam==6.0.1 +# homeassistant.components.ollama +ollama-hass==0.1.7 + # homeassistant.components.omnilogic omnilogic==0.4.5 diff --git a/tests/components/ollama/__init__.py b/tests/components/ollama/__init__.py new file mode 100644 index 00000000000..22a576e94a4 --- /dev/null +++ b/tests/components/ollama/__init__.py @@ -0,0 +1,14 @@ +"""Tests for the Ollama integration.""" + +from homeassistant.components import ollama +from homeassistant.components.ollama.const import DEFAULT_PROMPT + +TEST_USER_DATA = { + ollama.CONF_URL: "http://localhost:11434", + ollama.CONF_MODEL: "test model", +} + +TEST_OPTIONS = { + ollama.CONF_PROMPT: DEFAULT_PROMPT, + ollama.CONF_MAX_HISTORY: 2, +} diff --git a/tests/components/ollama/conftest.py b/tests/components/ollama/conftest.py new file mode 100644 index 00000000000..78ecf0766d7 --- /dev/null +++ b/tests/components/ollama/conftest.py @@ -0,0 +1,37 @@ +"""Tests Ollama integration.""" + +from unittest.mock import patch + +import pytest + +from homeassistant.components import ollama +from homeassistant.core import HomeAssistant +from homeassistant.setup import async_setup_component + +from . 
import TEST_OPTIONS, TEST_USER_DATA + +from tests.common import MockConfigEntry + + +@pytest.fixture +def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry: + """Mock a config entry.""" + entry = MockConfigEntry( + domain=ollama.DOMAIN, + data=TEST_USER_DATA, + options=TEST_OPTIONS, + ) + entry.add_to_hass(hass) + return entry + + +@pytest.fixture +async def mock_init_component(hass: HomeAssistant, mock_config_entry: MockConfigEntry): + """Initialize integration.""" + assert await async_setup_component(hass, "homeassistant", {}) + + with patch( + "ollama.AsyncClient.list", + ): + assert await async_setup_component(hass, ollama.DOMAIN, {}) + await hass.async_block_till_done() diff --git a/tests/components/ollama/test_config_flow.py b/tests/components/ollama/test_config_flow.py new file mode 100644 index 00000000000..825f3eac436 --- /dev/null +++ b/tests/components/ollama/test_config_flow.py @@ -0,0 +1,234 @@ +"""Test the Ollama config flow.""" + +import asyncio +from unittest.mock import patch + +from httpx import ConnectError +import pytest + +from homeassistant import config_entries +from homeassistant.components import ollama +from homeassistant.core import HomeAssistant +from homeassistant.data_entry_flow import FlowResultType + +from tests.common import MockConfigEntry + +TEST_MODEL = "test_model:latest" + + +async def test_form(hass: HomeAssistant) -> None: + """Test flow when the model is already downloaded.""" + # Pretend we already set up a config entry. + hass.config.components.add(ollama.DOMAIN) + MockConfigEntry( + domain=ollama.DOMAIN, + state=config_entries.ConfigEntryState.LOADED, + ).add_to_hass(hass) + + result = await hass.config_entries.flow.async_init( + ollama.DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] == FlowResultType.FORM + assert result["errors"] is None + + with ( + patch( + "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list", + # test model is already "downloaded" + return_value={"models": [{"model": TEST_MODEL}]}, + ), + patch( + "homeassistant.components.ollama.async_setup_entry", + return_value=True, + ) as mock_setup_entry, + ): + # Step 1: URL + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"} + ) + await hass.async_block_till_done() + + # Step 2: model + assert result2["type"] == FlowResultType.FORM + result3 = await hass.config_entries.flow.async_configure( + result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL} + ) + await hass.async_block_till_done() + + assert result3["type"] == FlowResultType.CREATE_ENTRY + assert result3["data"] == { + ollama.CONF_URL: "http://localhost:11434", + ollama.CONF_MODEL: TEST_MODEL, + } + assert len(mock_setup_entry.mock_calls) == 1 + + +async def test_form_need_download(hass: HomeAssistant) -> None: + """Test flow when a model needs to be downloaded.""" + # Pretend we already set up a config entry. 
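+    # An Ollama entry already being loaded must not block this flow, since
+    # multiple servers can be configured side by side.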
+ hass.config.components.add(ollama.DOMAIN) + MockConfigEntry( + domain=ollama.DOMAIN, + state=config_entries.ConfigEntryState.LOADED, + ).add_to_hass(hass) + + result = await hass.config_entries.flow.async_init( + ollama.DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] == FlowResultType.FORM + assert result["errors"] is None + + pull_ready = asyncio.Event() + pull_called = asyncio.Event() + pull_model: str | None = None + + async def pull(self, model: str, *args, **kwargs) -> None: + nonlocal pull_model + + async with asyncio.timeout(1): + await pull_ready.wait() + + pull_model = model + pull_called.set() + + with ( + patch( + "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list", + # No models are downloaded + return_value={}, + ), + patch( + "homeassistant.components.ollama.config_flow.ollama.AsyncClient.pull", + pull, + ), + patch( + "homeassistant.components.ollama.async_setup_entry", + return_value=True, + ) as mock_setup_entry, + ): + # Step 1: URL + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"} + ) + await hass.async_block_till_done() + + # Step 2: model + assert result2["type"] == FlowResultType.FORM + result3 = await hass.config_entries.flow.async_configure( + result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL} + ) + await hass.async_block_till_done() + + # Step 3: download + assert result3["type"] == FlowResultType.SHOW_PROGRESS + result4 = await hass.config_entries.flow.async_configure( + result3["flow_id"], + ) + await hass.async_block_till_done() + + # Run again without the task finishing. + # We should still be downloading. + assert result4["type"] == FlowResultType.SHOW_PROGRESS + result4 = await hass.config_entries.flow.async_configure( + result4["flow_id"], + ) + await hass.async_block_till_done() + assert result4["type"] == FlowResultType.SHOW_PROGRESS + + # Signal fake pull method to complete + pull_ready.set() + async with asyncio.timeout(1): + await pull_called.wait() + + assert pull_model == TEST_MODEL + + # Step 4: finish + result5 = await hass.config_entries.flow.async_configure( + result4["flow_id"], + ) + + assert result5["type"] == FlowResultType.CREATE_ENTRY + assert result5["data"] == { + ollama.CONF_URL: "http://localhost:11434", + ollama.CONF_MODEL: TEST_MODEL, + } + assert len(mock_setup_entry.mock_calls) == 1 + + +async def test_options( + hass: HomeAssistant, mock_config_entry, mock_init_component +) -> None: + """Test the options form.""" + options_flow = await hass.config_entries.options.async_init( + mock_config_entry.entry_id + ) + options = await hass.config_entries.options.async_configure( + options_flow["flow_id"], + {ollama.CONF_PROMPT: "test prompt", ollama.CONF_MAX_HISTORY: 100}, + ) + await hass.async_block_till_done() + assert options["type"] == FlowResultType.CREATE_ENTRY + assert options["data"] == { + ollama.CONF_PROMPT: "test prompt", + ollama.CONF_MAX_HISTORY: 100, + } + + +@pytest.mark.parametrize( + ("side_effect", "error"), + [ + (ConnectError(message=""), "cannot_connect"), + (RuntimeError(), "unknown"), + ], +) +async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None: + """Test we handle errors.""" + result = await hass.config_entries.flow.async_init( + ollama.DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + + with patch( + "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list", + side_effect=side_effect, + ): + result2 = await 
hass.config_entries.flow.async_configure( + result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"} + ) + + assert result2["type"] == FlowResultType.FORM + assert result2["errors"] == {"base": error} + + +async def test_download_error(hass: HomeAssistant) -> None: + """Test we handle errors while downloading a model.""" + result = await hass.config_entries.flow.async_init( + ollama.DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + + with ( + patch( + "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list", + return_value={}, + ), + patch( + "homeassistant.components.ollama.config_flow.ollama.AsyncClient.pull", + side_effect=RuntimeError(), + ), + ): + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"} + ) + await hass.async_block_till_done() + + assert result2["type"] == FlowResultType.FORM + result3 = await hass.config_entries.flow.async_configure( + result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL} + ) + await hass.async_block_till_done() + + assert result3["type"] == FlowResultType.SHOW_PROGRESS + result4 = await hass.config_entries.flow.async_configure(result3["flow_id"]) + await hass.async_block_till_done() + + assert result4["type"] == FlowResultType.ABORT + assert result4["reason"] == "download_failed" diff --git a/tests/components/ollama/test_init.py b/tests/components/ollama/test_init.py new file mode 100644 index 00000000000..ffe69ca4628 --- /dev/null +++ b/tests/components/ollama/test_init.py @@ -0,0 +1,366 @@ +"""Tests for the Ollama integration.""" + +from unittest.mock import AsyncMock, patch + +from httpx import ConnectError +from ollama import Message, ResponseError +import pytest + +from homeassistant.components import conversation, ollama +from homeassistant.components.homeassistant.exposed_entities import async_expose_entity +from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL +from homeassistant.core import Context, HomeAssistant +from homeassistant.helpers import ( + area_registry as ar, + device_registry as dr, + entity_registry as er, + intent, +) +from homeassistant.setup import async_setup_component + +from tests.common import MockConfigEntry + + +async def test_chat( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, + area_registry: ar.AreaRegistry, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, +) -> None: + """Test that the chat function is called with the appropriate arguments.""" + + # Create some areas, devices, and entities + area_kitchen = area_registry.async_get_or_create("kitchen_id") + area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen") + area_bedroom = area_registry.async_get_or_create("bedroom_id") + area_bedroom = area_registry.async_update(area_bedroom.id, name="bedroom") + area_office = area_registry.async_get_or_create("office_id") + area_office = area_registry.async_update(area_office.id, name="office") + + entry = MockConfigEntry() + entry.add_to_hass(hass) + kitchen_device = device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections=set(), + identifiers={("demo", "id-1234")}, + ) + device_registry.async_update_device(kitchen_device.id, area_id=area_kitchen.id) + + kitchen_light = entity_registry.async_get_or_create("light", "demo", "1234") + kitchen_light = entity_registry.async_update_entity( + kitchen_light.entity_id, device_id=kitchen_device.id + ) + hass.states.async_set( + kitchen_light.entity_id, "on", 
attributes={ATTR_FRIENDLY_NAME: "kitchen light"} + ) + + bedroom_light = entity_registry.async_get_or_create("light", "demo", "5678") + bedroom_light = entity_registry.async_update_entity( + bedroom_light.entity_id, area_id=area_bedroom.id + ) + hass.states.async_set( + bedroom_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "bedroom light"} + ) + + # Hide the office light + office_light = entity_registry.async_get_or_create("light", "demo", "ABCD") + office_light = entity_registry.async_update_entity( + office_light.entity_id, area_id=area_office.id + ) + hass.states.async_set( + office_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "office light"} + ) + async_expose_entity(hass, conversation.DOMAIN, office_light.entity_id, False) + + with patch( + "ollama.AsyncClient.chat", + return_value={"message": {"role": "assistant", "content": "test response"}}, + ) as mock_chat: + result = await conversation.async_converse( + hass, + "test message", + None, + Context(), + agent_id=mock_config_entry.entry_id, + ) + + assert mock_chat.call_count == 1 + args = mock_chat.call_args.kwargs + prompt = args["messages"][0]["content"] + + assert args["model"] == "test model" + assert args["messages"] == [ + Message({"role": "system", "content": prompt}), + Message({"role": "user", "content": "test message"}), + ] + + # Verify only exposed devices/areas are in prompt + assert "kitchen light" in prompt + assert "bedroom light" in prompt + assert "office light" not in prompt + assert "office" not in prompt + + assert ( + result.response.response_type == intent.IntentResponseType.ACTION_DONE + ), result + assert result.response.speech["plain"]["speech"] == "test response" + + +async def test_message_history_trimming( + hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component +) -> None: + """Test that a single message history is trimmed according to the config.""" + response_idx = 0 + + def response(*args, **kwargs) -> dict: + nonlocal response_idx + response_idx += 1 + return {"message": {"role": "assistant", "content": f"response {response_idx}"}} + + with patch( + "ollama.AsyncClient.chat", + side_effect=response, + ) as mock_chat: + # mock_init_component sets "max_history" to 2 + for i in range(5): + result = await conversation.async_converse( + hass, + f"message {i+1}", + conversation_id="1234", + context=Context(), + agent_id=mock_config_entry.entry_id, + ) + assert ( + result.response.response_type == intent.IntentResponseType.ACTION_DONE + ), result + + assert mock_chat.call_count == 5 + args = mock_chat.call_args_list + prompt = args[0].kwargs["messages"][0]["content"] + + # system + user-1 + assert len(args[0].kwargs["messages"]) == 2 + assert args[0].kwargs["messages"][1]["content"] == "message 1" + + # Full history + # system + user-1 + assistant-1 + user-2 + assert len(args[1].kwargs["messages"]) == 4 + assert args[1].kwargs["messages"][0]["role"] == "system" + assert args[1].kwargs["messages"][0]["content"] == prompt + assert args[1].kwargs["messages"][1]["role"] == "user" + assert args[1].kwargs["messages"][1]["content"] == "message 1" + assert args[1].kwargs["messages"][2]["role"] == "assistant" + assert args[1].kwargs["messages"][2]["content"] == "response 1" + assert args[1].kwargs["messages"][3]["role"] == "user" + assert args[1].kwargs["messages"][3]["content"] == "message 2" + + # Full history + # system + user-1 + assistant-1 + user-2 + assistant-2 + user-3 + assert len(args[2].kwargs["messages"]) == 6 + assert args[2].kwargs["messages"][0]["role"] == 
"system" + assert args[2].kwargs["messages"][0]["content"] == prompt + assert args[2].kwargs["messages"][1]["role"] == "user" + assert args[2].kwargs["messages"][1]["content"] == "message 1" + assert args[2].kwargs["messages"][2]["role"] == "assistant" + assert args[2].kwargs["messages"][2]["content"] == "response 1" + assert args[2].kwargs["messages"][3]["role"] == "user" + assert args[2].kwargs["messages"][3]["content"] == "message 2" + assert args[2].kwargs["messages"][4]["role"] == "assistant" + assert args[2].kwargs["messages"][4]["content"] == "response 2" + assert args[2].kwargs["messages"][5]["role"] == "user" + assert args[2].kwargs["messages"][5]["content"] == "message 3" + + # Trimmed down to two user messages. + # system + user-2 + assistant-2 + user-3 + assistant-3 + user-4 + assert len(args[3].kwargs["messages"]) == 6 + assert args[3].kwargs["messages"][0]["role"] == "system" + assert args[3].kwargs["messages"][0]["content"] == prompt + assert args[3].kwargs["messages"][1]["role"] == "user" + assert args[3].kwargs["messages"][1]["content"] == "message 2" + assert args[3].kwargs["messages"][2]["role"] == "assistant" + assert args[3].kwargs["messages"][2]["content"] == "response 2" + assert args[3].kwargs["messages"][3]["role"] == "user" + assert args[3].kwargs["messages"][3]["content"] == "message 3" + assert args[3].kwargs["messages"][4]["role"] == "assistant" + assert args[3].kwargs["messages"][4]["content"] == "response 3" + assert args[3].kwargs["messages"][5]["role"] == "user" + assert args[3].kwargs["messages"][5]["content"] == "message 4" + + # Trimmed down to two user messages. + # system + user-3 + assistant-3 + user-4 + assistant-4 + user-5 + assert len(args[3].kwargs["messages"]) == 6 + assert args[4].kwargs["messages"][0]["role"] == "system" + assert args[4].kwargs["messages"][0]["content"] == prompt + assert args[4].kwargs["messages"][1]["role"] == "user" + assert args[4].kwargs["messages"][1]["content"] == "message 3" + assert args[4].kwargs["messages"][2]["role"] == "assistant" + assert args[4].kwargs["messages"][2]["content"] == "response 3" + assert args[4].kwargs["messages"][3]["role"] == "user" + assert args[4].kwargs["messages"][3]["content"] == "message 4" + assert args[4].kwargs["messages"][4]["role"] == "assistant" + assert args[4].kwargs["messages"][4]["content"] == "response 4" + assert args[4].kwargs["messages"][5]["role"] == "user" + assert args[4].kwargs["messages"][5]["content"] == "message 5" + + +async def test_message_history_pruning( + hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component +) -> None: + """Test that old message histories are pruned.""" + with patch( + "ollama.AsyncClient.chat", + return_value={"message": {"role": "assistant", "content": "test response"}}, + ): + # Create 3 different message histories + conversation_ids: list[str] = [] + for i in range(3): + result = await conversation.async_converse( + hass, + f"message {i+1}", + conversation_id=None, + context=Context(), + agent_id=mock_config_entry.entry_id, + ) + assert ( + result.response.response_type == intent.IntentResponseType.ACTION_DONE + ), result + assert isinstance(result.conversation_id, str) + conversation_ids.append(result.conversation_id) + + agent = await conversation._get_agent_manager(hass).async_get_agent( + mock_config_entry.entry_id + ) + assert isinstance(agent, ollama.OllamaAgent) + assert len(agent._history) == 3 + assert agent._history.keys() == set(conversation_ids) + + # Modify the timestamps of the first 2 histories so they 
will be pruned + # on the next cycle. + for conversation_id in conversation_ids[:2]: + # Move back 2 hours + agent._history[conversation_id].timestamp -= 2 * 60 * 60 + + # Next cycle + result = await conversation.async_converse( + hass, + "test message", + conversation_id=None, + context=Context(), + agent_id=mock_config_entry.entry_id, + ) + assert ( + result.response.response_type == intent.IntentResponseType.ACTION_DONE + ), result + + # Only the most recent histories should remain + assert len(agent._history) == 2 + assert conversation_ids[-1] in agent._history + assert result.conversation_id in agent._history + + +async def test_message_history_unlimited( + hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component +) -> None: + """Test that message history is not trimmed when max_history = 0.""" + conversation_id = "1234" + with ( + patch( + "ollama.AsyncClient.chat", + return_value={"message": {"role": "assistant", "content": "test response"}}, + ), + patch.object(mock_config_entry, "options", {ollama.CONF_MAX_HISTORY: 0}), + ): + for i in range(100): + result = await conversation.async_converse( + hass, + f"message {i+1}", + conversation_id=conversation_id, + context=Context(), + agent_id=mock_config_entry.entry_id, + ) + assert ( + result.response.response_type == intent.IntentResponseType.ACTION_DONE + ), result + + agent = await conversation._get_agent_manager(hass).async_get_agent( + mock_config_entry.entry_id + ) + assert isinstance(agent, ollama.OllamaAgent) + + assert len(agent._history) == 1 + assert conversation_id in agent._history + assert agent._history[conversation_id].num_user_messages == 100 + + +async def test_error_handling( + hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component +) -> None: + """Test error handling during converse.""" + with patch( + "ollama.AsyncClient.chat", + new_callable=AsyncMock, + side_effect=ResponseError("test error"), + ): + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + ) + + assert result.response.response_type == intent.IntentResponseType.ERROR, result + assert result.response.error_code == "unknown", result + + +async def test_template_error( + hass: HomeAssistant, mock_config_entry: MockConfigEntry +) -> None: + """Test that template error handling works.""" + hass.config_entries.async_update_entry( + mock_config_entry, + options={ + "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", + }, + ) + with patch( + "ollama.AsyncClient.list", + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + ) + + assert result.response.response_type == intent.IntentResponseType.ERROR, result + assert result.response.error_code == "unknown", result + + +async def test_conversation_agent( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, +) -> None: + """Test OllamaAgent.""" + agent = await conversation._get_agent_manager(hass).async_get_agent( + mock_config_entry.entry_id + ) + assert agent.supported_languages == MATCH_ALL + + +@pytest.mark.parametrize( + ("side_effect", "error"), + [ + (ConnectError(message="Connect error"), "Connect error"), + (RuntimeError("Runtime error"), "Runtime error"), + ], +) +async def test_init_error( + hass: HomeAssistant, mock_config_entry: MockConfigEntry, caplog, side_effect, error +) -> 
None: + """Test initialization errors.""" + with patch( + "ollama.AsyncClient.list", + side_effect=side_effect, + ): + assert await async_setup_component(hass, ollama.DOMAIN, {}) + await hass.async_block_till_done() + assert error in caplog.text
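
For reference, a minimal standalone sketch (not part of the patch) of the server interaction this integration depends on. It mirrors the `ollama.AsyncClient` calls used in `__init__.py` and `config_flow.py`; the host URL, model name, and message contents below are placeholders:

```python
# Illustrative only. Mirrors how the integration talks to an Ollama server,
# using the same client library it imports: list() as a connectivity check
# (async_setup_entry), then chat() with a system prompt plus the user's
# message (OllamaAgent.async_process).
import asyncio

import ollama


async def main() -> None:
    client = ollama.AsyncClient(host="http://localhost:11434")  # placeholder URL

    # Connectivity check; raises if the server is unreachable.
    await client.list()

    response = await client.chat(
        model="llama2:latest",  # placeholder; must already be pulled
        messages=[
            {"role": "system", "content": "This smart home is controlled by Home Assistant."},
            {"role": "user", "content": "Which lights are on?"},
        ],
        stream=False,
        keep_alive=-1,  # KEEP_ALIVE_FOREVER: keep the model loaded between requests
    )
    print(response["message"]["content"])


asyncio.run(main())
```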