Look up todoist collaborators only when adding new task (#87957)

* Look up collaborators only when adding new task.

Also fixed a few api call arguments that were incorrect. The `labels`
key should have been a list of strings and the `assignee` key should
have been `assignee_id`.

* Add missing type in test.

* Remove print
This commit is contained in:
Aaron Godfrey 2023-03-28 00:33:32 -07:00 committed by GitHub
parent ff135ecdc6
commit 8b7594ae08
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 25 deletions

View File

@ -1,9 +1,7 @@
"""Support for Todoist task management (https://todoist.com).""" """Support for Todoist task management (https://todoist.com)."""
from __future__ import annotations from __future__ import annotations
import asyncio
from datetime import date, datetime, timedelta from datetime import date, datetime, timedelta
from itertools import chain
import logging import logging
from typing import Any from typing import Any
import uuid import uuid
@ -117,8 +115,6 @@ async def async_setup_platform(
# Look up IDs based on (lowercase) names. # Look up IDs based on (lowercase) names.
project_id_lookup = {} project_id_lookup = {}
label_id_lookup = {}
collaborator_id_lookup = {}
api = TodoistAPIAsync(token) api = TodoistAPIAsync(token)
@ -126,9 +122,6 @@ async def async_setup_platform(
# Grab all projects. # Grab all projects.
projects = await api.get_projects() projects = await api.get_projects()
collaborator_tasks = (api.get_collaborators(project.id) for project in projects)
collaborators = list(chain.from_iterable(await asyncio.gather(*collaborator_tasks)))
# Grab all labels # Grab all labels
labels = await api.get_labels() labels = await api.get_labels()
@ -142,13 +135,6 @@ async def async_setup_platform(
# Cache the names so we can easily look up name->ID. # Cache the names so we can easily look up name->ID.
project_id_lookup[project.name.lower()] = project.id project_id_lookup[project.name.lower()] = project.id
# Cache all label names
label_id_lookup = {label.name.lower(): label.id for label in labels}
collaborator_id_lookup = {
collab.name.lower(): collab.id for collab in collaborators
}
# Check config for more projects. # Check config for more projects.
extra_projects: list[CustomProject] = config[CONF_EXTRA_PROJECTS] extra_projects: list[CustomProject] = config[CONF_EXTRA_PROJECTS]
for extra_project in extra_projects: for extra_project in extra_projects:
@ -194,14 +180,16 @@ async def async_setup_platform(
data: dict[str, Any] = {"project_id": project_id} data: dict[str, Any] = {"project_id": project_id}
if task_labels := call.data.get(LABELS): if task_labels := call.data.get(LABELS):
data["label_ids"] = [ data["labels"] = task_labels
label_id_lookup[label.lower()] for label in task_labels
]
if ASSIGNEE in call.data: if ASSIGNEE in call.data:
collaborators = await api.get_collaborators(project_id)
collaborator_id_lookup = {
collab.name.lower(): collab.id for collab in collaborators
}
task_assignee = call.data[ASSIGNEE].lower() task_assignee = call.data[ASSIGNEE].lower()
if task_assignee in collaborator_id_lookup: if task_assignee in collaborator_id_lookup:
data["assignee"] = collaborator_id_lookup[task_assignee] data["assignee_id"] = collaborator_id_lookup[task_assignee]
else: else:
raise ValueError( raise ValueError(
f"User is not part of the shared project. user: {task_assignee}" f"User is not part of the shared project. user: {task_assignee}"

View File

@ -1,14 +1,21 @@
"""Unit tests for the Todoist calendar platform.""" """Unit tests for the Todoist calendar platform."""
from datetime import datetime, timedelta from datetime import timedelta
from http import HTTPStatus from http import HTTPStatus
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
import urllib import urllib
import pytest import pytest
from todoist_api_python.models import Due, Label, Project, Task from todoist_api_python.models import Collaborator, Due, Label, Project, Task
from homeassistant import setup from homeassistant import setup
from homeassistant.components.todoist.calendar import DOMAIN from homeassistant.components.todoist.const import (
ASSIGNEE,
CONTENT,
DOMAIN,
LABELS,
PROJECT_NAME,
SERVICE_NEW_TASK,
)
from homeassistant.const import CONF_TOKEN from homeassistant.const import CONF_TOKEN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
@ -30,9 +37,7 @@ def mock_task() -> Task:
created_at="2021-10-01T00:00:00", created_at="2021-10-01T00:00:00",
creator_id="1", creator_id="1",
description="A task", description="A task",
due=Due( due=Due(is_recurring=False, date=dt.now().strftime("%Y-%m-%d"), string="today"),
is_recurring=False, date=datetime.now().strftime("%Y-%m-%d"), string="today"
),
id="1", id="1",
labels=["Label1"], labels=["Label1"],
order=1, order=1,
@ -68,7 +73,9 @@ def mock_api(task) -> AsyncMock:
api.get_labels.return_value = [ api.get_labels.return_value = [
Label(id="1", name="Label1", color="1", order=1, is_favorite=False) Label(id="1", name="Label1", color="1", order=1, is_favorite=False)
] ]
api.get_collaborators.return_value = [] api.get_collaborators.return_value = [
Collaborator(email="user@gmail.com", id="1", name="user")
]
api.get_tasks.return_value = [task] api.get_tasks.return_value = [task]
return api return api
@ -193,3 +200,31 @@ async def test_all_day_event(
} }
] ]
assert events == expected assert events == expected
@patch("homeassistant.components.todoist.calendar.TodoistAPIAsync")
async def test_create_task_service_call(todoist_api, hass: HomeAssistant, api) -> None:
"""Test api is called correctly after a new task service call."""
todoist_api.return_value = api
assert await setup.async_setup_component(
hass,
"calendar",
{
"calendar": {
"platform": DOMAIN,
CONF_TOKEN: "token",
}
},
)
await hass.async_block_till_done()
await hass.services.async_call(
DOMAIN,
SERVICE_NEW_TASK,
{ASSIGNEE: "user", CONTENT: "task", LABELS: ["Label1"], PROJECT_NAME: "Name"},
)
await hass.async_block_till_done()
api.add_task.assert_called_with(
"task", project_id="12345", labels=["Label1"], assignee_id="1"
)