Remove GenTextTaskType

This commit is contained in:
Paulus Schoutsen 2025-06-17 12:54:49 -04:00
parent 2be6acec03
commit a8d4caab01
7 changed files with 6 additions and 33 deletions

View File

@ -11,14 +11,13 @@ from homeassistant.helpers.typing import ConfigType
from .const import DATA_COMPONENT, DOMAIN from .const import DATA_COMPONENT, DOMAIN
from .entity import AITaskEntity from .entity import AITaskEntity
from .http import async_setup as async_setup_conversation_http from .http import async_setup as async_setup_conversation_http
from .task import GenTextTask, GenTextTaskResult, GenTextTaskType, async_generate_text from .task import GenTextTask, GenTextTaskResult, async_generate_text
__all__ = [ __all__ = [
"DOMAIN", "DOMAIN",
"AITaskEntity", "AITaskEntity",
"GenTextTask", "GenTextTask",
"GenTextTaskResult", "GenTextTaskResult",
"GenTextTaskType",
"async_generate_text", "async_generate_text",
"async_setup", "async_setup",
"async_setup_entry", "async_setup_entry",

View File

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
from enum import StrEnum
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from homeassistant.util.hass_dict import HassKey from homeassistant.util.hass_dict import HassKey
@ -18,17 +17,3 @@ DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN)
DEFAULT_SYSTEM_PROMPT = ( DEFAULT_SYSTEM_PROMPT = (
"You are a Home Assistant expert and help users with their tasks." "You are a Home Assistant expert and help users with their tasks."
) )
class GenTextTaskType(StrEnum):
"""Generate text task types.
A task type describes the intent of the request in order to
match the right model for balance of cost and quality.
"""
GENERATE = "generate"
"""Generate content, which may target a higher quality result."""
SUMMARY = "summary"
"""Summarize existing content, which be able to use a more cost effective model."""

View File

@ -7,7 +7,7 @@ import voluptuous as vol
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from .task import GenTextTaskType, async_generate_text from .task import async_generate_text
@callback @callback
@ -21,7 +21,6 @@ def async_setup(hass: HomeAssistant) -> None:
vol.Required("type"): "ai_task/generate_text", vol.Required("type"): "ai_task/generate_text",
vol.Required("task_name"): str, vol.Required("task_name"): str,
vol.Required("entity_id"): str, vol.Required("entity_id"): str,
vol.Required("task_type"): (lambda v: GenTextTaskType(v)), # pylint: disable=unnecessary-lambda
vol.Required("instructions"): str, vol.Required("instructions"): str,
} }
) )

View File

@ -6,7 +6,7 @@ from dataclasses import dataclass
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from .const import DATA_COMPONENT, GenTextTaskType from .const import DATA_COMPONENT
async def async_generate_text( async def async_generate_text(
@ -14,7 +14,6 @@ async def async_generate_text(
*, *,
task_name: str, task_name: str,
entity_id: str, entity_id: str,
task_type: GenTextTaskType,
instructions: str, instructions: str,
) -> GenTextTaskResult: ) -> GenTextTaskResult:
"""Run a task in the AI Task integration.""" """Run a task in the AI Task integration."""
@ -25,7 +24,6 @@ async def async_generate_text(
return await entity.internal_async_generate_text( return await entity.internal_async_generate_text(
GenTextTask( GenTextTask(
name=task_name, name=task_name,
type=task_type,
instructions=instructions, instructions=instructions,
) )
) )
@ -38,15 +36,12 @@ class GenTextTask:
name: str name: str
"""Name of the task.""" """Name of the task."""
type: GenTextTaskType
"""Type of the task."""
instructions: str instructions: str
"""Instructions on what needs to be done.""" """Instructions on what needs to be done."""
def __str__(self) -> str: def __str__(self) -> str:
"""Return task as a string.""" """Return task as a string."""
return f"<GenTextTask {self.type}: {id(self)}>" return f"<GenTextTask {self.name}: {id(self)}>"
@dataclass(slots=True) @dataclass(slots=True)

View File

@ -2,7 +2,7 @@
from freezegun import freeze_time from freezegun import freeze_time
from homeassistant.components.ai_task import GenTextTaskType, async_generate_text from homeassistant.components.ai_task import async_generate_text
from homeassistant.const import STATE_UNKNOWN from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -27,7 +27,6 @@ async def test_state_generate_text(
hass, hass,
task_name="Test task", task_name="Test task",
entity_id=TEST_ENTITY_ID, entity_id=TEST_ENTITY_ID,
task_type=GenTextTaskType.SUMMARY,
instructions="Test prompt", instructions="Test prompt",
) )
assert result.result == "Mock result" assert result.result == "Mock result"
@ -37,5 +36,4 @@ async def test_state_generate_text(
assert mock_ai_task_entity.mock_generate_text_tasks assert mock_ai_task_entity.mock_generate_text_tasks
task = mock_ai_task_entity.mock_generate_text_tasks[0] task = mock_ai_task_entity.mock_generate_text_tasks[0]
assert task.type == GenTextTaskType.SUMMARY
assert task.instructions == "Test prompt" assert task.instructions == "Test prompt"

View File

@ -25,7 +25,6 @@ async def test_ws_generate_text(
"type": "ai_task/generate_text", "type": "ai_task/generate_text",
"task_name": "Test Task", "task_name": "Test Task",
"entity_id": TEST_ENTITY_ID, "entity_id": TEST_ENTITY_ID,
"task_type": "summary",
"instructions": "Test prompt", "instructions": "Test prompt",
} }
) )

View File

@ -4,7 +4,7 @@ from freezegun import freeze_time
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components.ai_task import GenTextTaskType, async_generate_text from homeassistant.components.ai_task import async_generate_text
from homeassistant.components.conversation import async_get_chat_log from homeassistant.components.conversation import async_get_chat_log
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import chat_session from homeassistant.helpers import chat_session
@ -25,7 +25,6 @@ async def test_run_text_task_unknown_entity(
hass, hass,
task_name="Test Task", task_name="Test Task",
entity_id="ai_task.unknown_entity", entity_id="ai_task.unknown_entity",
task_type="summary",
instructions="Test prompt", instructions="Test prompt",
) )
@ -41,7 +40,6 @@ async def test_run_text_task_updates_chat_log(
hass, hass,
task_name="Test Task", task_name="Test Task",
entity_id=TEST_ENTITY_ID, entity_id=TEST_ENTITY_ID,
task_type=GenTextTaskType.SUMMARY,
instructions="Test prompt", instructions="Test prompt",
) )
assert result.result == "Mock result" assert result.result == "Mock result"