Improve Google Tasks coordinator updates behavior (#133316)

This commit is contained in:
Allen Porter 2024-12-19 07:41:47 -08:00 committed by GitHub
parent 255f85eb2f
commit a3ef3cce3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 107 additions and 58 deletions

View File

@ -2,6 +2,8 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from aiohttp import ClientError, ClientResponseError from aiohttp import ClientError, ClientResponseError
from homeassistant.const import Platform from homeassistant.const import Platform
@ -11,8 +13,9 @@ from homeassistant.helpers import config_entry_oauth2_flow
from . import api from . import api
from .const import DOMAIN from .const import DOMAIN
from .coordinator import TaskUpdateCoordinator
from .exceptions import GoogleTasksApiError from .exceptions import GoogleTasksApiError
from .types import GoogleTasksConfigEntry, GoogleTasksData from .types import GoogleTasksConfigEntry
__all__ = [ __all__ = [
"DOMAIN", "DOMAIN",
@ -46,7 +49,23 @@ async def async_setup_entry(hass: HomeAssistant, entry: GoogleTasksConfigEntry)
except GoogleTasksApiError as err: except GoogleTasksApiError as err:
raise ConfigEntryNotReady from err raise ConfigEntryNotReady from err
entry.runtime_data = GoogleTasksData(auth, task_lists) coordinators = [
TaskUpdateCoordinator(
hass,
auth,
task_list["id"],
task_list["title"],
)
for task_list in task_lists
]
# Refresh all coordinators in parallel
await asyncio.gather(
*(
coordinator.async_config_entry_first_refresh()
for coordinator in coordinators
)
)
entry.runtime_data = coordinators
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)

View File

@ -20,7 +20,11 @@ class TaskUpdateCoordinator(DataUpdateCoordinator[list[dict[str, Any]]]):
"""Coordinator for fetching Google Tasks for a Task List form the API.""" """Coordinator for fetching Google Tasks for a Task List form the API."""
def __init__( def __init__(
self, hass: HomeAssistant, api: AsyncConfigEntryAuth, task_list_id: str self,
hass: HomeAssistant,
api: AsyncConfigEntryAuth,
task_list_id: str,
task_list_title: str,
) -> None: ) -> None:
"""Initialize TaskUpdateCoordinator.""" """Initialize TaskUpdateCoordinator."""
super().__init__( super().__init__(
@ -30,9 +34,10 @@ class TaskUpdateCoordinator(DataUpdateCoordinator[list[dict[str, Any]]]):
update_interval=UPDATE_INTERVAL, update_interval=UPDATE_INTERVAL,
) )
self.api = api self.api = api
self._task_list_id = task_list_id self.task_list_id = task_list_id
self.task_list_title = task_list_title
async def _async_update_data(self) -> list[dict[str, Any]]: async def _async_update_data(self) -> list[dict[str, Any]]:
"""Fetch tasks from API endpoint.""" """Fetch tasks from API endpoint."""
async with asyncio.timeout(TIMEOUT): async with asyncio.timeout(TIMEOUT):
return await self.api.list_tasks(self._task_list_id) return await self.api.list_tasks(self.task_list_id)

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import UTC, date, datetime, timedelta from datetime import UTC, date, datetime
from typing import Any, cast from typing import Any, cast
from homeassistant.components.todo import ( from homeassistant.components.todo import (
@ -20,7 +20,6 @@ from .coordinator import TaskUpdateCoordinator
from .types import GoogleTasksConfigEntry from .types import GoogleTasksConfigEntry
PARALLEL_UPDATES = 0 PARALLEL_UPDATES = 0
SCAN_INTERVAL = timedelta(minutes=15)
TODO_STATUS_MAP = { TODO_STATUS_MAP = {
"needsAction": TodoItemStatus.NEEDS_ACTION, "needsAction": TodoItemStatus.NEEDS_ACTION,
@ -76,14 +75,13 @@ async def async_setup_entry(
async_add_entities( async_add_entities(
( (
GoogleTaskTodoListEntity( GoogleTaskTodoListEntity(
TaskUpdateCoordinator(hass, entry.runtime_data.api, task_list["id"]), coordinator,
task_list["title"], coordinator.task_list_title,
entry.entry_id, entry.entry_id,
task_list["id"], coordinator.task_list_id,
) )
for task_list in entry.runtime_data.task_lists for coordinator in entry.runtime_data
), ),
True,
) )
@ -118,8 +116,6 @@ class GoogleTaskTodoListEntity(
@property @property
def todo_items(self) -> list[TodoItem] | None: def todo_items(self) -> list[TodoItem] | None:
"""Get the current set of To-do items.""" """Get the current set of To-do items."""
if self.coordinator.data is None:
return None
return [_convert_api_item(item) for item in _order_tasks(self.coordinator.data)] return [_convert_api_item(item) for item in _order_tasks(self.coordinator.data)]
async def async_create_todo_item(self, item: TodoItem) -> None: async def async_create_todo_item(self, item: TodoItem) -> None:

View File

@ -1,19 +1,7 @@
"""Types for the Google Tasks integration.""" """Types for the Google Tasks integration."""
from dataclasses import dataclass
from typing import Any
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from .api import AsyncConfigEntryAuth from .coordinator import TaskUpdateCoordinator
type GoogleTasksConfigEntry = ConfigEntry[list[TaskUpdateCoordinator]]
@dataclass
class GoogleTasksData:
"""Class to hold Google Tasks data."""
api: AsyncConfigEntryAuth
task_lists: list[dict[str, Any]]
type GoogleTasksConfigEntry = ConfigEntry[GoogleTasksData]

View File

@ -34,6 +34,18 @@ LIST_TASK_LIST_RESPONSE = {
"items": [TASK_LIST], "items": [TASK_LIST],
} }
LIST_TASKS_RESPONSE_WATER = {
"items": [
{
"id": "some-task-id",
"title": "Water",
"status": "needsAction",
"description": "Any size is ok",
"position": "00000000000000000001",
},
],
}
@pytest.fixture @pytest.fixture
def platforms() -> list[Platform]: def platforms() -> list[Platform]:
@ -44,7 +56,7 @@ def platforms() -> list[Platform]:
@pytest.fixture(name="expires_at") @pytest.fixture(name="expires_at")
def mock_expires_at() -> int: def mock_expires_at() -> int:
"""Fixture to set the oauth token expiration time.""" """Fixture to set the oauth token expiration time."""
return time.time() + 3600 return time.time() + 86400
@pytest.fixture(name="token_entry") @pytest.fixture(name="token_entry")

View File

@ -3,6 +3,7 @@
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
import http import http
from http import HTTPStatus from http import HTTPStatus
import json
import time import time
from unittest.mock import Mock from unittest.mock import Mock
@ -15,13 +16,15 @@ from homeassistant.components.google_tasks.const import OAUTH2_TOKEN
from homeassistant.config_entries import ConfigEntryState from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from .conftest import LIST_TASK_LIST_RESPONSE from .conftest import LIST_TASK_LIST_RESPONSE, LIST_TASKS_RESPONSE_WATER
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
from tests.test_util.aiohttp import AiohttpClientMocker from tests.test_util.aiohttp import AiohttpClientMocker
@pytest.mark.parametrize("api_responses", [[LIST_TASK_LIST_RESPONSE]]) @pytest.mark.parametrize(
"api_responses", [[LIST_TASK_LIST_RESPONSE, LIST_TASKS_RESPONSE_WATER]]
)
async def test_setup( async def test_setup(
hass: HomeAssistant, hass: HomeAssistant,
integration_setup: Callable[[], Awaitable[bool]], integration_setup: Callable[[], Awaitable[bool]],
@ -42,8 +45,10 @@ async def test_setup(
assert not hass.services.async_services().get(DOMAIN) assert not hass.services.async_services().get(DOMAIN)
@pytest.mark.parametrize("expires_at", [time.time() - 3600], ids=["expired"]) @pytest.mark.parametrize("expires_at", [time.time() - 86400], ids=["expired"])
@pytest.mark.parametrize("api_responses", [[LIST_TASK_LIST_RESPONSE]]) @pytest.mark.parametrize(
"api_responses", [[LIST_TASK_LIST_RESPONSE, LIST_TASKS_RESPONSE_WATER]]
)
async def test_expired_token_refresh_success( async def test_expired_token_refresh_success(
hass: HomeAssistant, hass: HomeAssistant,
integration_setup: Callable[[], Awaitable[bool]], integration_setup: Callable[[], Awaitable[bool]],
@ -60,8 +65,8 @@ async def test_expired_token_refresh_success(
json={ json={
"access_token": "updated-access-token", "access_token": "updated-access-token",
"refresh_token": "updated-refresh-token", "refresh_token": "updated-refresh-token",
"expires_at": time.time() + 3600, "expires_at": time.time() + 86400,
"expires_in": 3600, "expires_in": 86400,
}, },
) )
@ -69,26 +74,26 @@ async def test_expired_token_refresh_success(
assert config_entry.state is ConfigEntryState.LOADED assert config_entry.state is ConfigEntryState.LOADED
assert config_entry.data["token"]["access_token"] == "updated-access-token" assert config_entry.data["token"]["access_token"] == "updated-access-token"
assert config_entry.data["token"]["expires_in"] == 3600 assert config_entry.data["token"]["expires_in"] == 86400
@pytest.mark.parametrize( @pytest.mark.parametrize(
("expires_at", "status", "exc", "expected_state"), ("expires_at", "status", "exc", "expected_state"),
[ [
( (
time.time() - 3600, time.time() - 86400,
http.HTTPStatus.UNAUTHORIZED, http.HTTPStatus.UNAUTHORIZED,
None, None,
ConfigEntryState.SETUP_ERROR, ConfigEntryState.SETUP_ERROR,
), ),
( (
time.time() - 3600, time.time() - 86400,
http.HTTPStatus.INTERNAL_SERVER_ERROR, http.HTTPStatus.INTERNAL_SERVER_ERROR,
None, None,
ConfigEntryState.SETUP_RETRY, ConfigEntryState.SETUP_RETRY,
), ),
( (
time.time() - 3600, time.time() - 86400,
None, None,
ClientError("error"), ClientError("error"),
ConfigEntryState.SETUP_RETRY, ConfigEntryState.SETUP_RETRY,
@ -124,6 +129,16 @@ async def test_expired_token_refresh_failure(
"response_handler", "response_handler",
[ [
([(Response({"status": HTTPStatus.INTERNAL_SERVER_ERROR}), b"")]), ([(Response({"status": HTTPStatus.INTERNAL_SERVER_ERROR}), b"")]),
# First request succeeds, second request fails
(
[
(
Response({"status": HTTPStatus.OK}),
json.dumps(LIST_TASK_LIST_RESPONSE),
),
(Response({"status": HTTPStatus.INTERNAL_SERVER_ERROR}), b""),
]
),
], ],
) )
async def test_setup_error( async def test_setup_error(

View File

@ -6,10 +6,12 @@ import json
from typing import Any from typing import Any
from unittest.mock import Mock from unittest.mock import Mock
from freezegun.api import FrozenDateTimeFactory
from httplib2 import Response from httplib2 import Response
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components.google_tasks.coordinator import UPDATE_INTERVAL
from homeassistant.components.todo import ( from homeassistant.components.todo import (
ATTR_DESCRIPTION, ATTR_DESCRIPTION,
ATTR_DUE_DATE, ATTR_DUE_DATE,
@ -19,12 +21,17 @@ from homeassistant.components.todo import (
DOMAIN as TODO_DOMAIN, DOMAIN as TODO_DOMAIN,
TodoServices, TodoServices,
) )
from homeassistant.const import ATTR_ENTITY_ID, Platform from homeassistant.const import ATTR_ENTITY_ID, STATE_UNAVAILABLE, Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from .conftest import LIST_TASK_LIST_RESPONSE, create_response_object from .conftest import (
LIST_TASK_LIST_RESPONSE,
LIST_TASKS_RESPONSE_WATER,
create_response_object,
)
from tests.common import async_fire_time_changed
from tests.typing import WebSocketGenerator from tests.typing import WebSocketGenerator
ENTITY_ID = "todo.my_tasks" ENTITY_ID = "todo.my_tasks"
@ -44,17 +51,6 @@ ERROR_RESPONSE = {
CONTENT_ID = "Content-ID" CONTENT_ID = "Content-ID"
BOUNDARY = "batch_00972cc8-75bd-11ee-9692-0242ac110002" # Arbitrary uuid BOUNDARY = "batch_00972cc8-75bd-11ee-9692-0242ac110002" # Arbitrary uuid
LIST_TASKS_RESPONSE_WATER = {
"items": [
{
"id": "some-task-id",
"title": "Water",
"status": "needsAction",
"description": "Any size is ok",
"position": "00000000000000000001",
},
],
}
LIST_TASKS_RESPONSE_MULTIPLE = { LIST_TASKS_RESPONSE_MULTIPLE = {
"items": [ "items": [
{ {
@ -311,7 +307,9 @@ async def test_empty_todo_list(
[ [
[ [
LIST_TASK_LIST_RESPONSE, LIST_TASK_LIST_RESPONSE,
ERROR_RESPONSE, LIST_TASKS_RESPONSE_WATER,
ERROR_RESPONSE, # Fail after one update interval
LIST_TASKS_RESPONSE_WATER,
] ]
], ],
) )
@ -319,18 +317,34 @@ async def test_task_items_error_response(
hass: HomeAssistant, hass: HomeAssistant,
setup_credentials: None, setup_credentials: None,
integration_setup: Callable[[], Awaitable[bool]], integration_setup: Callable[[], Awaitable[bool]],
hass_ws_client: WebSocketGenerator, freezer: FrozenDateTimeFactory,
ws_get_items: Callable[[], Awaitable[dict[str, str]]],
) -> None: ) -> None:
"""Test an error while getting todo list items.""" """Test an error while the entity updates getting a new list of todo list items."""
assert await integration_setup() assert await integration_setup()
await hass_ws_client(hass) # Test successful setup and first data fetch
state = hass.states.get("todo.my_tasks")
assert state
assert state.state == "1"
# Next update fails
freezer.tick(UPDATE_INTERVAL)
async_fire_time_changed(hass)
await hass.async_block_till_done(wait_background_tasks=True)
state = hass.states.get("todo.my_tasks") state = hass.states.get("todo.my_tasks")
assert state assert state
assert state.state == "unavailable" assert state.state == STATE_UNAVAILABLE
# Next update succeeds
freezer.tick(UPDATE_INTERVAL)
async_fire_time_changed(hass)
await hass.async_block_till_done(wait_background_tasks=True)
state = hass.states.get("todo.my_tasks")
assert state
assert state.state == "1"
@pytest.mark.parametrize( @pytest.mark.parametrize(