Move recorder chunk utils to shared collection utils (#118065)

This commit is contained in:
Jan Bouwhuis 2024-05-25 00:49:39 +02:00 committed by GitHub
parent 7522bbfa9d
commit c616fc036e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 72 additions and 60 deletions

View File

@ -11,6 +11,8 @@ from typing import TYPE_CHECKING
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from homeassistant.util.collection import chunked_or_all
from .db_schema import Events, States, StatesMeta from .db_schema import Events, States, StatesMeta
from .models import DatabaseEngine from .models import DatabaseEngine
from .queries import ( from .queries import (
@ -40,7 +42,7 @@ from .queries import (
find_statistics_runs_to_purge, find_statistics_runs_to_purge,
) )
from .repack import repack_database from .repack import repack_database
from .util import chunked_or_all, retryable_database_job, session_scope from .util import retryable_database_job, session_scope
if TYPE_CHECKING: if TYPE_CHECKING:
from . import Recorder from . import Recorder

View File

@ -9,11 +9,12 @@ from typing import TYPE_CHECKING, cast
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from homeassistant.core import Event from homeassistant.core import Event
from homeassistant.util.collection import chunked
from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS
from ..db_schema import EventData from ..db_schema import EventData
from ..queries import get_shared_event_datas from ..queries import get_shared_event_datas
from ..util import chunked, execute_stmt_lambda_element from ..util import execute_stmt_lambda_element
from . import BaseLRUTableManager from . import BaseLRUTableManager
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@ -9,12 +9,13 @@ from lru import LRU
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from homeassistant.core import Event from homeassistant.core import Event
from homeassistant.util.collection import chunked
from homeassistant.util.event_type import EventType from homeassistant.util.event_type import EventType
from ..db_schema import EventTypes from ..db_schema import EventTypes
from ..queries import find_event_type_ids from ..queries import find_event_type_ids
from ..tasks import RefreshEventTypesTask from ..tasks import RefreshEventTypesTask
from ..util import chunked, execute_stmt_lambda_element from ..util import execute_stmt_lambda_element
from . import BaseLRUTableManager from . import BaseLRUTableManager
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@ -9,11 +9,12 @@ from typing import TYPE_CHECKING, cast
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from homeassistant.core import Event, EventStateChangedData from homeassistant.core import Event, EventStateChangedData
from homeassistant.util.collection import chunked
from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS
from ..db_schema import StateAttributes from ..db_schema import StateAttributes
from ..queries import get_shared_attributes from ..queries import get_shared_attributes
from ..util import chunked, execute_stmt_lambda_element from ..util import execute_stmt_lambda_element
from . import BaseLRUTableManager from . import BaseLRUTableManager
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@ -8,10 +8,11 @@ from typing import TYPE_CHECKING, cast
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from homeassistant.core import Event, EventStateChangedData from homeassistant.core import Event, EventStateChangedData
from homeassistant.util.collection import chunked
from ..db_schema import StatesMeta from ..db_schema import StatesMeta
from ..queries import find_all_states_metadata_ids, find_states_metadata_ids from ..queries import find_all_states_metadata_ids, find_states_metadata_ids
from ..util import chunked, execute_stmt_lambda_element from ..util import execute_stmt_lambda_element
from . import BaseLRUTableManager from . import BaseLRUTableManager
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@ -2,13 +2,11 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Collection, Generator, Iterable, Sequence from collections.abc import Callable, Generator, Sequence
import contextlib import contextlib
from contextlib import contextmanager from contextlib import contextmanager
from datetime import date, datetime, timedelta from datetime import date, datetime, timedelta
import functools import functools
from functools import partial
from itertools import islice
import logging import logging
import os import os
import time import time
@ -859,36 +857,6 @@ def resolve_period(
return (start_time, end_time) return (start_time, end_time)
def take(take_num: int, iterable: Iterable) -> list[Any]:
"""Return first n items of the iterable as a list.
From itertools recipes
"""
return list(islice(iterable, take_num))
def chunked(iterable: Iterable, chunked_num: int) -> Iterable[Any]:
"""Break *iterable* into lists of length *n*.
From more-itertools
"""
return iter(partial(take, chunked_num, iter(iterable)), [])
def chunked_or_all(iterable: Collection[Any], chunked_num: int) -> Iterable[Any]:
"""Break *collection* into iterables of length *n*.
Returns the collection if its length is less than *n*.
Unlike chunked, this function requires a collection so it can
determine the length of the collection and return the collection
if it is less than *n*.
"""
if len(iterable) <= chunked_num:
return (iterable,)
return chunked(iterable, chunked_num)
def get_index_by_name(session: Session, table_name: str, index_name: str) -> str | None: def get_index_by_name(session: Session, table_name: str, index_name: str) -> str | None:
"""Get an index by name.""" """Get an index by name."""
connection = session.connection() connection = session.connection()

View File

@ -0,0 +1,36 @@
"""Helpers for working with collections."""
from collections.abc import Collection, Iterable
from functools import partial
from itertools import islice
from typing import Any
def take(take_num: int, iterable: Iterable) -> list[Any]:
"""Return first n items of the iterable as a list.
From itertools recipes
"""
return list(islice(iterable, take_num))
def chunked(iterable: Iterable, chunked_num: int) -> Iterable[Any]:
"""Break *iterable* into lists of length *n*.
From more-itertools
"""
return iter(partial(take, chunked_num, iter(iterable)), [])
def chunked_or_all(iterable: Collection[Any], chunked_num: int) -> Iterable[Any]:
"""Break *collection* into iterables of length *n*.
Returns the collection if its length is less than *n*.
Unlike chunked, this function requires a collection so it can
determine the length of the collection and return the collection
if it is less than *n*.
"""
if len(iterable) <= chunked_num:
return (iterable,)
return chunked(iterable, chunked_num)

View File

@ -26,7 +26,6 @@ from homeassistant.components.recorder.models import (
process_timestamp, process_timestamp,
) )
from homeassistant.components.recorder.util import ( from homeassistant.components.recorder.util import (
chunked_or_all,
end_incomplete_runs, end_incomplete_runs,
is_second_sunday, is_second_sunday,
resolve_period, resolve_period,
@ -1051,24 +1050,3 @@ async def test_resolve_period(hass: HomeAssistant) -> None:
} }
} }
) == (now - timedelta(hours=1, minutes=25), now - timedelta(minutes=25)) ) == (now - timedelta(hours=1, minutes=25), now - timedelta(minutes=25))
def test_chunked_or_all():
"""Test chunked_or_all can iterate chunk sizes larger than the passed in collection."""
all_items = []
incoming = (1, 2, 3, 4)
for chunk in chunked_or_all(incoming, 2):
assert len(chunk) == 2
all_items.extend(chunk)
assert all_items == [1, 2, 3, 4]
all_items = []
incoming = (1, 2, 3, 4)
for chunk in chunked_or_all(incoming, 5):
assert len(chunk) == 4
# Verify the chunk is the same object as the incoming
# collection since we want to avoid copying the collection
# if we don't need to
assert chunk is incoming
all_items.extend(chunk)
assert all_items == [1, 2, 3, 4]

View File

@ -0,0 +1,24 @@
"""Test collection utils."""
from homeassistant.util.collection import chunked_or_all
def test_chunked_or_all() -> None:
"""Test chunked_or_all can iterate chunk sizes larger than the passed in collection."""
all_items = []
incoming = (1, 2, 3, 4)
for chunk in chunked_or_all(incoming, 2):
assert len(chunk) == 2
all_items.extend(chunk)
assert all_items == [1, 2, 3, 4]
all_items = []
incoming = (1, 2, 3, 4)
for chunk in chunked_or_all(incoming, 5):
assert len(chunk) == 4
# Verify the chunk is the same object as the incoming
# collection since we want to avoid copying the collection
# if we don't need to
assert chunk is incoming
all_items.extend(chunk)
assert all_items == [1, 2, 3, 4]