Refactor services setup in Habitica integration (#128186)

This commit is contained in:
Manu 2024-10-25 11:00:58 +02:00 committed by GitHub
parent 3adacb8799
commit 8665f4a251
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 215 additions and 270 deletions

View File

@ -1,17 +1,13 @@
"""The habitica integration.""" """The habitica integration."""
from http import HTTPStatus from http import HTTPStatus
import logging
from typing import Any
from aiohttp import ClientResponseError from aiohttp import ClientResponseError
from habitipy.aio import HabitipyAsync from habitipy.aio import HabitipyAsync
import voluptuous as vol
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
APPLICATION_NAME, APPLICATION_NAME,
ATTR_NAME,
CONF_API_KEY, CONF_API_KEY,
CONF_NAME, CONF_NAME,
CONF_URL, CONF_URL,
@ -19,140 +15,27 @@ from homeassistant.const import (
Platform, Platform,
__version__, __version__,
) )
from homeassistant.core import ( from homeassistant.core import HomeAssistant
HomeAssistant, from homeassistant.exceptions import ConfigEntryNotReady
ServiceCall,
ServiceResponse,
SupportsResponse,
)
from homeassistant.exceptions import (
ConfigEntryNotReady,
HomeAssistantError,
ServiceValidationError,
)
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.selector import ConfigEntrySelector
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import ( from .const import CONF_API_USER, DEVELOPER_ID, DOMAIN
ATTR_ARGS,
ATTR_CONFIG_ENTRY,
ATTR_DATA,
ATTR_PATH,
ATTR_SKILL,
ATTR_TASK,
CONF_API_USER,
DEVELOPER_ID,
DOMAIN,
EVENT_API_CALL_SUCCESS,
SERVICE_API_CALL,
SERVICE_CAST_SKILL,
)
from .coordinator import HabiticaDataUpdateCoordinator from .coordinator import HabiticaDataUpdateCoordinator
from .services import async_setup_services
from .types import HabiticaConfigEntry
_LOGGER = logging.getLogger(__name__)
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
type HabiticaConfigEntry = ConfigEntry[HabiticaDataUpdateCoordinator]
PLATFORMS = [Platform.BUTTON, Platform.SENSOR, Platform.SWITCH, Platform.TODO] PLATFORMS = [Platform.BUTTON, Platform.SENSOR, Platform.SWITCH, Platform.TODO]
SERVICE_API_CALL_SCHEMA = vol.Schema(
{
vol.Required(ATTR_NAME): str,
vol.Required(ATTR_PATH): vol.All(cv.ensure_list, [str]),
vol.Optional(ATTR_ARGS): dict,
}
)
SERVICE_CAST_SKILL_SCHEMA = vol.Schema(
{
vol.Required(ATTR_CONFIG_ENTRY): ConfigEntrySelector(),
vol.Required(ATTR_SKILL): cv.string,
vol.Optional(ATTR_TASK): cv.string,
}
)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the Habitica service.""" """Set up the Habitica service."""
async def cast_skill(call: ServiceCall) -> ServiceResponse: async_setup_services(hass)
"""Skill action."""
entry: HabiticaConfigEntry | None
if not (
entry := hass.config_entries.async_get_entry(call.data[ATTR_CONFIG_ENTRY])
):
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="entry_not_found",
)
coordinator = entry.runtime_data
skill = {
"pickpocket": {"spellId": "pickPocket", "cost": "10 MP"},
"backstab": {"spellId": "backStab", "cost": "15 MP"},
"smash": {"spellId": "smash", "cost": "10 MP"},
"fireball": {"spellId": "fireball", "cost": "10 MP"},
}
try:
task_id = next(
task["id"]
for task in coordinator.data.tasks
if call.data[ATTR_TASK] in (task["id"], task.get("alias"))
or call.data[ATTR_TASK] == task["text"]
)
except StopIteration as e:
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="task_not_found",
translation_placeholders={"task": f"'{call.data[ATTR_TASK]}'"},
) from e
try:
response: dict[str, Any] = await coordinator.api.user.class_.cast[
skill[call.data[ATTR_SKILL]]["spellId"]
].post(targetId=task_id)
except ClientResponseError as e:
if e.status == HTTPStatus.TOO_MANY_REQUESTS:
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="setup_rate_limit_exception",
) from e
if e.status == HTTPStatus.UNAUTHORIZED:
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="not_enough_mana",
translation_placeholders={
"cost": skill[call.data[ATTR_SKILL]]["cost"],
"mana": f"{int(coordinator.data.user.get("stats", {}).get("mp", 0))} MP",
},
) from e
if e.status == HTTPStatus.NOT_FOUND:
# could also be task not found, but the task is looked up
# before the request, so most likely wrong skill selected
# or the skill hasn't been unlocked yet.
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="skill_not_found",
translation_placeholders={"skill": call.data[ATTR_SKILL]},
) from e
raise HomeAssistantError(
translation_domain=DOMAIN,
translation_key="service_call_exception",
) from e
else:
await coordinator.async_request_refresh()
return response
hass.services.async_register(
DOMAIN,
SERVICE_CAST_SKILL,
cast_skill,
schema=SERVICE_CAST_SKILL_SCHEMA,
supports_response=SupportsResponse.ONLY,
)
return True return True
@ -174,33 +57,6 @@ async def async_setup_entry(
) )
return headers return headers
async def handle_api_call(call: ServiceCall) -> None:
name = call.data[ATTR_NAME]
path = call.data[ATTR_PATH]
entries = hass.config_entries.async_entries(DOMAIN)
api = None
for entry in entries:
if entry.data[CONF_NAME] == name:
api = entry.runtime_data.api
break
if api is None:
_LOGGER.error("API_CALL: User '%s' not configured", name)
return
try:
for element in path:
api = api[element]
except KeyError:
_LOGGER.error(
"API_CALL: Path %s is invalid for API on '{%s}' element", path, element
)
return
kwargs = call.data.get(ATTR_ARGS, {})
data = await api(**kwargs)
hass.bus.async_fire(
EVENT_API_CALL_SUCCESS, {ATTR_NAME: name, ATTR_PATH: path, ATTR_DATA: data}
)
websession = async_get_clientsession( websession = async_get_clientsession(
hass, verify_ssl=config_entry.data.get(CONF_VERIFY_SSL, True) hass, verify_ssl=config_entry.data.get(CONF_VERIFY_SSL, True)
) )
@ -236,16 +92,9 @@ async def async_setup_entry(
config_entry.runtime_data = coordinator config_entry.runtime_data = coordinator
await hass.config_entries.async_forward_entry_setups(config_entry, PLATFORMS) await hass.config_entries.async_forward_entry_setups(config_entry, PLATFORMS)
if not hass.services.has_service(DOMAIN, SERVICE_API_CALL):
hass.services.async_register(
DOMAIN, SERVICE_API_CALL, handle_api_call, schema=SERVICE_API_CALL_SCHEMA
)
return True return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
if len(hass.config_entries.async_entries(DOMAIN)) == 1:
hass.services.async_remove(DOMAIN, SERVICE_API_CALL)
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS) return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)

View File

@ -20,10 +20,10 @@ from homeassistant.exceptions import ServiceValidationError
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from . import HabiticaConfigEntry
from .const import ASSETS_URL, DOMAIN, HEALER, MAGE, ROGUE, WARRIOR from .const import ASSETS_URL, DOMAIN, HEALER, MAGE, ROGUE, WARRIOR
from .coordinator import HabiticaData, HabiticaDataUpdateCoordinator from .coordinator import HabiticaData, HabiticaDataUpdateCoordinator
from .entity import HabiticaBase from .entity import HabiticaBase
from .types import HabiticaConfigEntry
@dataclass(kw_only=True, frozen=True) @dataclass(kw_only=True, frozen=True)

View File

@ -24,9 +24,9 @@ from homeassistant.helpers.issue_registry import (
) )
from homeassistant.helpers.typing import StateType from homeassistant.helpers.typing import StateType
from . import HabiticaConfigEntry
from .const import DOMAIN, UNIT_TASKS from .const import DOMAIN, UNIT_TASKS
from .entity import HabiticaBase from .entity import HabiticaBase
from .types import HabiticaConfigEntry
from .util import entity_used_in from .util import entity_used_in
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)

View File

@ -0,0 +1,167 @@
"""Actions for the Habitica integration."""
from __future__ import annotations
from http import HTTPStatus
import logging
from typing import Any
from aiohttp import ClientResponseError
import voluptuous as vol
from homeassistant.const import ATTR_NAME, CONF_NAME
from homeassistant.core import (
HomeAssistant,
ServiceCall,
ServiceResponse,
SupportsResponse,
)
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.selector import ConfigEntrySelector
from .const import (
ATTR_ARGS,
ATTR_CONFIG_ENTRY,
ATTR_DATA,
ATTR_PATH,
ATTR_SKILL,
ATTR_TASK,
DOMAIN,
EVENT_API_CALL_SUCCESS,
SERVICE_API_CALL,
SERVICE_CAST_SKILL,
)
from .types import HabiticaConfigEntry
_LOGGER = logging.getLogger(__name__)
SERVICE_API_CALL_SCHEMA = vol.Schema(
{
vol.Required(ATTR_NAME): str,
vol.Required(ATTR_PATH): vol.All(cv.ensure_list, [str]),
vol.Optional(ATTR_ARGS): dict,
}
)
SERVICE_CAST_SKILL_SCHEMA = vol.Schema(
{
vol.Required(ATTR_CONFIG_ENTRY): ConfigEntrySelector(),
vol.Required(ATTR_SKILL): cv.string,
vol.Optional(ATTR_TASK): cv.string,
}
)
def async_setup_services(hass: HomeAssistant) -> None:
"""Set up services for Habitica integration."""
async def handle_api_call(call: ServiceCall) -> None:
name = call.data[ATTR_NAME]
path = call.data[ATTR_PATH]
entries = hass.config_entries.async_entries(DOMAIN)
api = None
for entry in entries:
if entry.data[CONF_NAME] == name:
api = entry.runtime_data.api
break
if api is None:
_LOGGER.error("API_CALL: User '%s' not configured", name)
return
try:
for element in path:
api = api[element]
except KeyError:
_LOGGER.error(
"API_CALL: Path %s is invalid for API on '{%s}' element", path, element
)
return
kwargs = call.data.get(ATTR_ARGS, {})
data = await api(**kwargs)
hass.bus.async_fire(
EVENT_API_CALL_SUCCESS, {ATTR_NAME: name, ATTR_PATH: path, ATTR_DATA: data}
)
async def cast_skill(call: ServiceCall) -> ServiceResponse:
"""Skill action."""
entry: HabiticaConfigEntry | None
if not (
entry := hass.config_entries.async_get_entry(call.data[ATTR_CONFIG_ENTRY])
):
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="entry_not_found",
)
coordinator = entry.runtime_data
skill = {
"pickpocket": {"spellId": "pickPocket", "cost": "10 MP"},
"backstab": {"spellId": "backStab", "cost": "15 MP"},
"smash": {"spellId": "smash", "cost": "10 MP"},
"fireball": {"spellId": "fireball", "cost": "10 MP"},
}
try:
task_id = next(
task["id"]
for task in coordinator.data.tasks
if call.data[ATTR_TASK] in (task["id"], task.get("alias"))
or call.data[ATTR_TASK] == task["text"]
)
except StopIteration as e:
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="task_not_found",
translation_placeholders={"task": f"'{call.data[ATTR_TASK]}'"},
) from e
try:
response: dict[str, Any] = await coordinator.api.user.class_.cast[
skill[call.data[ATTR_SKILL]]["spellId"]
].post(targetId=task_id)
except ClientResponseError as e:
if e.status == HTTPStatus.TOO_MANY_REQUESTS:
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="setup_rate_limit_exception",
) from e
if e.status == HTTPStatus.UNAUTHORIZED:
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="not_enough_mana",
translation_placeholders={
"cost": skill[call.data[ATTR_SKILL]]["cost"],
"mana": f"{int(coordinator.data.user.get("stats", {}).get("mp", 0))} MP",
},
) from e
if e.status == HTTPStatus.NOT_FOUND:
# could also be task not found, but the task is looked up
# before the request, so most likely wrong skill selected
# or the skill hasn't been unlocked yet.
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="skill_not_found",
translation_placeholders={"skill": call.data[ATTR_SKILL]},
) from e
raise HomeAssistantError(
translation_domain=DOMAIN,
translation_key="service_call_exception",
) from e
else:
await coordinator.async_request_refresh()
return response
hass.services.async_register(
DOMAIN,
SERVICE_API_CALL,
handle_api_call,
schema=SERVICE_API_CALL_SCHEMA,
)
hass.services.async_register(
DOMAIN,
SERVICE_CAST_SKILL,
cast_skill,
schema=SERVICE_CAST_SKILL_SCHEMA,
supports_response=SupportsResponse.ONLY,
)

View File

@ -15,9 +15,9 @@ from homeassistant.components.switch import (
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from . import HabiticaConfigEntry
from .coordinator import HabiticaData, HabiticaDataUpdateCoordinator from .coordinator import HabiticaData, HabiticaDataUpdateCoordinator
from .entity import HabiticaBase from .entity import HabiticaBase
from .types import HabiticaConfigEntry
@dataclass(kw_only=True, frozen=True) @dataclass(kw_only=True, frozen=True)

View File

@ -21,10 +21,10 @@ from homeassistant.helpers.entity import EntityDescription
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from . import HabiticaConfigEntry
from .const import ASSETS_URL, DOMAIN from .const import ASSETS_URL, DOMAIN
from .coordinator import HabiticaDataUpdateCoordinator from .coordinator import HabiticaDataUpdateCoordinator
from .entity import HabiticaBase from .entity import HabiticaBase
from .types import HabiticaConfigEntry
from .util import next_due_date from .util import next_due_date

View File

@ -0,0 +1,7 @@
"""Types for Habitica integration."""
from homeassistant.config_entries import ConfigEntry
from .coordinator import HabiticaDataUpdateCoordinator
type HabiticaConfigEntry = ConfigEntry[HabiticaDataUpdateCoordinator]

View File

@ -38,121 +38,47 @@ def capture_api_call_success(hass: HomeAssistant) -> list[Event]:
return async_capture_events(hass, EVENT_API_CALL_SUCCESS) return async_capture_events(hass, EVENT_API_CALL_SUCCESS)
@pytest.fixture @pytest.mark.usefixtures("mock_habitica")
def habitica_entry(hass: HomeAssistant) -> MockConfigEntry: async def test_entry_setup_unload(
"""Test entry for the following tests.""" hass: HomeAssistant, config_entry: MockConfigEntry
entry = MockConfigEntry( ) -> None:
domain=DOMAIN, """Test integration setup and unload."""
unique_id="test-api-user",
data={ config_entry.add_to_hass(hass)
"api_user": "test-api-user", assert await hass.config_entries.async_setup(config_entry.entry_id)
"api_key": "test-api-key", await hass.async_block_till_done()
"url": DEFAULT_URL,
}, assert config_entry.state is ConfigEntryState.LOADED
)
entry.add_to_hass(hass) assert await hass.config_entries.async_unload(config_entry.entry_id)
return entry
assert config_entry.state is ConfigEntryState.NOT_LOADED
@pytest.fixture @pytest.mark.usefixtures("mock_habitica")
def common_requests(aioclient_mock: AiohttpClientMocker) -> AiohttpClientMocker: async def test_service_call(
"""Register requests for the tests.""" hass: HomeAssistant,
aioclient_mock.get( config_entry: MockConfigEntry,
"https://habitica.com/api/v3/user", capture_api_call_success: list[Event],
json={ mock_habitica: AiohttpClientMocker,
"data": { ) -> None:
"auth": {"local": {"username": TEST_USER_NAME}}, """Test integration setup, service call and unload."""
"api_user": "test-api-user", config_entry.add_to_hass(hass)
"profile": {"name": TEST_USER_NAME}, assert await hass.config_entries.async_setup(config_entry.entry_id)
"stats": { await hass.async_block_till_done()
"class": "warrior",
"con": 1,
"exp": 2,
"gp": 3,
"hp": 4,
"int": 5,
"lvl": 6,
"maxHealth": 7,
"maxMP": 8,
"mp": 9,
"per": 10,
"points": 11,
"str": 12,
"toNextLevel": 13,
},
}
},
)
aioclient_mock.get( assert config_entry.state is ConfigEntryState.LOADED
"https://habitica.com/api/v3/tasks/user",
json={
"data": [
{
"text": f"this is a mock {task} #{i}",
"id": f"{i}",
"type": task,
"completed": False,
}
for i, task in enumerate(("habit", "daily", "todo", "reward"), start=1)
]
},
)
aioclient_mock.get(
"https://habitica.com/api/v3/tasks/user?type=completedTodos",
json={
"data": [
{
"text": "this is a mock todo #5",
"id": 5,
"type": "todo",
"completed": True,
}
]
},
)
aioclient_mock.post( assert len(capture_api_call_success) == 0
mock_habitica.post(
"https://habitica.com/api/v3/tasks/user", "https://habitica.com/api/v3/tasks/user",
status=HTTPStatus.CREATED, status=HTTPStatus.CREATED,
json={"data": TEST_API_CALL_ARGS}, json={"data": TEST_API_CALL_ARGS},
) )
return aioclient_mock
@pytest.mark.usefixtures("common_requests")
async def test_entry_setup_unload(
hass: HomeAssistant, habitica_entry: MockConfigEntry
) -> None:
"""Test integration setup and unload."""
assert await hass.config_entries.async_setup(habitica_entry.entry_id)
await hass.async_block_till_done()
assert hass.services.has_service(DOMAIN, SERVICE_API_CALL)
assert await hass.config_entries.async_unload(habitica_entry.entry_id)
assert not hass.services.has_service(DOMAIN, SERVICE_API_CALL)
@pytest.mark.usefixtures("common_requests")
async def test_service_call(
hass: HomeAssistant,
habitica_entry: MockConfigEntry,
capture_api_call_success: list[Event],
) -> None:
"""Test integration setup, service call and unload."""
assert await hass.config_entries.async_setup(habitica_entry.entry_id)
await hass.async_block_till_done()
assert hass.services.has_service(DOMAIN, SERVICE_API_CALL)
assert len(capture_api_call_success) == 0
TEST_SERVICE_DATA = { TEST_SERVICE_DATA = {
ATTR_NAME: "test_user", ATTR_NAME: "test-user",
ATTR_PATH: ["tasks", "user", "post"], ATTR_PATH: ["tasks", "user", "post"],
ATTR_ARGS: TEST_API_CALL_ARGS, ATTR_ARGS: TEST_API_CALL_ARGS,
} }
@ -166,10 +92,6 @@ async def test_service_call(
del captured_data[ATTR_DATA] del captured_data[ATTR_DATA]
assert captured_data == TEST_SERVICE_DATA assert captured_data == TEST_SERVICE_DATA
assert await hass.config_entries.async_unload(habitica_entry.entry_id)
assert not hass.services.has_service(DOMAIN, SERVICE_API_CALL)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("status"), [HTTPStatus.NOT_FOUND, HTTPStatus.TOO_MANY_REQUESTS] ("status"), [HTTPStatus.NOT_FOUND, HTTPStatus.TOO_MANY_REQUESTS]