Merge pull request #73193 from home-assistant/rc

Paulus Schoutsen 2022-06-07 17:32:30 -07:00 committed by GitHub
commit 8b0e10d8a5
35 changed files with 489 additions and 297 deletions

View File

@@ -3,7 +3,7 @@
   "name": "DLNA Digital Media Renderer",
   "config_flow": true,
   "documentation": "https://www.home-assistant.io/integrations/dlna_dmr",
-  "requirements": ["async-upnp-client==0.30.1"],
+  "requirements": ["async-upnp-client==0.31.1"],
   "dependencies": ["ssdp"],
   "after_dependencies": ["media_source"],
   "ssdp": [

View File

@@ -3,7 +3,7 @@
   "name": "DLNA Digital Media Server",
   "config_flow": true,
   "documentation": "https://www.home-assistant.io/integrations/dlna_dms",
-  "requirements": ["async-upnp-client==0.30.1"],
+  "requirements": ["async-upnp-client==0.31.1"],
   "dependencies": ["ssdp"],
   "after_dependencies": ["media_source"],
   "ssdp": [

View File

@@ -25,7 +25,12 @@ from homeassistant.const import STATE_IDLE, STATE_PAUSED, STATE_PLAYING
 from homeassistant.core import HomeAssistant
 from homeassistant.helpers.entity_platform import AddEntitiesCallback
 
-from . import EsphomeEntity, EsphomeEnumMapper, platform_async_setup_entry
+from . import (
+    EsphomeEntity,
+    EsphomeEnumMapper,
+    esphome_state_property,
+    platform_async_setup_entry,
+)
 
 
 async def async_setup_entry(
@@ -54,6 +59,10 @@ _STATES: EsphomeEnumMapper[MediaPlayerState, str] = EsphomeEnumMapper(
 )
 
 
+# https://github.com/PyCQA/pylint/issues/3150 for all @esphome_state_property
+# pylint: disable=invalid-overridden-method
+
+
 class EsphomeMediaPlayer(
     EsphomeEntity[MediaPlayerInfo, MediaPlayerEntityState], MediaPlayerEntity
 ):
@@ -61,17 +70,17 @@ class EsphomeMediaPlayer(
 
     _attr_device_class = MediaPlayerDeviceClass.SPEAKER
 
-    @property
+    @esphome_state_property
     def state(self) -> str | None:
         """Return current state."""
         return _STATES.from_esphome(self._state.state)
 
-    @property
+    @esphome_state_property
    def is_volume_muted(self) -> bool:
        """Return true if volume is muted."""
        return self._state.muted
 
-    @property
+    @esphome_state_property
    def volume_level(self) -> float | None:
        """Volume level of the media player (0..1)."""
        return self._state.volume
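Note for reviewers: `esphome_state_property` replaces the plain `@property` decorators above. A minimal sketch of the idea, assuming (as the pylint suppression suggests) a decorator that wraps the method in a property returning None until the device has pushed state; the `_has_state` attribute and the whole implementation below are illustrative, not the integration's actual code:

```python
from __future__ import annotations

from collections.abc import Callable
from typing import Any, TypeVar

_T = TypeVar("_T")


def esphome_state_property(func: Callable[[Any], _T | None]) -> property:
    """Hypothetical reimplementation, for illustration only."""

    def _wrapper(self: Any) -> _T | None:
        # Guard: the ESPHome device may not have reported any state yet.
        if not getattr(self, "_has_state", False):
            return None
        return func(self)

    # Returning a property object is what the pylint
    # invalid-overridden-method suppression above accounts for.
    return property(_wrapper)
```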

View File

@@ -169,7 +169,16 @@ def wifi_entities_list(
     }
     for i, network in networks.copy().items():
         networks[i]["switch_name"] = network["ssid"]
-        if len([j for j, n in networks.items() if n["ssid"] == network["ssid"]]) > 1:
+        if (
+            len(
+                [
+                    j
+                    for j, n in networks.items()
+                    if slugify(n["ssid"]) == slugify(network["ssid"])
+                ]
+            )
+            > 1
+        ):
             networks[i]["switch_name"] += f" ({WIFI_STANDARD[i]})"
 
     _LOGGER.debug("WiFi networks list: %s", networks)
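Why slugify on both sides: two SSIDs that differ only in characters slugification collapses would otherwise produce colliding switch names and entity ids, so the duplicate check has to compare slugified forms. A quick illustration using Home Assistant's own `slugify`:

```python
from homeassistant.util import slugify

# Distinct raw SSIDs can collapse to the same slug, which is why the
# duplicate check above compares slugify(n["ssid"]) rather than raw names.
ssids = ["My WiFi", "my-wifi", "Guest"]
print([slugify(s) for s in ssids])  # ['my_wifi', 'my_wifi', 'guest']
```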

View File

@@ -96,7 +96,9 @@ SENSOR_TYPES: Final[tuple[FritzSensorEntityDescription, ...]] = (
         device_class=SensorDeviceClass.VOLTAGE,
         state_class=SensorStateClass.MEASUREMENT,
         suitable=lambda device: device.has_powermeter,  # type: ignore[no-any-return]
-        native_value=lambda device: device.voltage / 1000 if device.voltage else 0.0,
+        native_value=lambda device: device.voltage / 1000
+        if getattr(device, "voltage", None)
+        else 0.0,
     ),
     FritzSensorEntityDescription(
         key="electric_current",
@@ -106,7 +108,7 @@ SENSOR_TYPES: Final[tuple[FritzSensorEntityDescription, ...]] = (
         state_class=SensorStateClass.MEASUREMENT,
         suitable=lambda device: device.has_powermeter,  # type: ignore[no-any-return]
         native_value=lambda device: device.power / device.voltage
-        if device.power and device.voltage
+        if device.power and getattr(device, "voltage", None)
         else 0.0,
     ),
     FritzSensorEntityDescription(
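The `getattr(device, "voltage", None)` guard covers device objects that lack the attribute entirely (not merely report a falsy value), where bare `device.voltage` would raise `AttributeError` inside the `native_value` lambda. A minimal illustration with a hypothetical device class:

```python
class LegacyDevice:
    """Hypothetical device model exposing power but no voltage attribute."""

    power = 5000  # milliwatts


device = LegacyDevice()

# Bare attribute access would raise AttributeError; getattr degrades
# gracefully to the 0.0 fallback used by the sensor description.
voltage = getattr(device, "voltage", None)
current = device.power / voltage if device.power and voltage else 0.0
print(current)  # 0.0
```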

View File

@@ -2,8 +2,9 @@
 from __future__ import annotations
 
 from datetime import datetime as dt
+import json
 
-from sqlalchemy.sql.lambdas import StatementLambdaElement
+from sqlalchemy.sql.selectable import Select
 
 from homeassistant.components.recorder.filters import Filters
@@ -21,7 +22,7 @@ def statement_for_request(
     device_ids: list[str] | None = None,
     filters: Filters | None = None,
     context_id: str | None = None,
-) -> StatementLambdaElement:
+) -> Select:
     """Generate the logbook statement for a logbook request."""
 
     # No entities: logbook sends everything for the timeframe
@@ -38,41 +39,36 @@ def statement_for_request(
         context_id,
     )
 
-    # sqlalchemy caches object quoting, the
-    # json quotable ones must be a different
-    # object from the non-json ones to prevent
-    # sqlalchemy from quoting them incorrectly
     # entities and devices: logbook sends everything for the timeframe for the entities and devices
     if entity_ids and device_ids:
-        json_quotable_entity_ids = list(entity_ids)
-        json_quotable_device_ids = list(device_ids)
+        json_quoted_entity_ids = [json.dumps(entity_id) for entity_id in entity_ids]
+        json_quoted_device_ids = [json.dumps(device_id) for device_id in device_ids]
         return entities_devices_stmt(
             start_day,
             end_day,
             event_types,
             entity_ids,
-            json_quotable_entity_ids,
-            json_quotable_device_ids,
+            json_quoted_entity_ids,
+            json_quoted_device_ids,
         )
 
     # entities: logbook sends everything for the timeframe for the entities
     if entity_ids:
-        json_quotable_entity_ids = list(entity_ids)
+        json_quoted_entity_ids = [json.dumps(entity_id) for entity_id in entity_ids]
         return entities_stmt(
             start_day,
             end_day,
             event_types,
             entity_ids,
-            json_quotable_entity_ids,
+            json_quoted_entity_ids,
         )
 
     # devices: logbook sends everything for the timeframe for the devices
     assert device_ids is not None
-    json_quotable_device_ids = list(device_ids)
+    json_quoted_device_ids = [json.dumps(device_id) for device_id in device_ids]
     return devices_stmt(
         start_day,
         end_day,
         event_types,
-        json_quotable_device_ids,
+        json_quoted_device_ids,
     )
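The rename from json_quotable to json_quoted reflects a behavior change: the ids are now actually JSON-encoded before being handed to the `.in_()` matchers, because those matchers compare against fragments of the raw JSON `event_data` payload, where every string carries quotes. A small illustration:

```python
import json

entity_ids = ["light.kitchen", 'name "with" quotes']

# The matchers compare against the serialized JSON payload, so each id
# needs JSON string quoting (and escaping) to match what is stored.
json_quoted = [json.dumps(e) for e in entity_ids]
print(json_quoted)  # ['"light.kitchen"', '"name \\"with\\" quotes"']
```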

View File

@@ -3,10 +3,9 @@ from __future__ import annotations
 
 from datetime import datetime as dt
 
-from sqlalchemy import lambda_stmt
 from sqlalchemy.orm import Query
 from sqlalchemy.sql.elements import ClauseList
-from sqlalchemy.sql.lambdas import StatementLambdaElement
+from sqlalchemy.sql.selectable import Select
 
 from homeassistant.components.recorder.models import LAST_UPDATED_INDEX, Events, States
@@ -25,32 +24,29 @@ def all_stmt(
     states_entity_filter: ClauseList | None = None,
     events_entity_filter: ClauseList | None = None,
     context_id: str | None = None,
-) -> StatementLambdaElement:
+) -> Select:
     """Generate a logbook query for all entities."""
-    stmt = lambda_stmt(
-        lambda: select_events_without_states(start_day, end_day, event_types)
-    )
+    stmt = select_events_without_states(start_day, end_day, event_types)
     if context_id is not None:
         # Once all the old `state_changed` events
         # are gone from the database remove the
         # _legacy_select_events_context_id()
-        stmt += lambda s: s.where(Events.context_id == context_id).union_all(
+        stmt = stmt.where(Events.context_id == context_id).union_all(
             _states_query_for_context_id(start_day, end_day, context_id),
             legacy_select_events_context_id(start_day, end_day, context_id),
         )
     else:
         if events_entity_filter is not None:
-            stmt += lambda s: s.where(events_entity_filter)
+            stmt = stmt.where(events_entity_filter)
         if states_entity_filter is not None:
-            stmt += lambda s: s.union_all(
+            stmt = stmt.union_all(
                 _states_query_for_all(start_day, end_day).where(states_entity_filter)
             )
         else:
-            stmt += lambda s: s.union_all(_states_query_for_all(start_day, end_day))
-    stmt += lambda s: s.order_by(Events.time_fired)
-    return stmt
+            stmt = stmt.union_all(_states_query_for_all(start_day, end_day))
+    return stmt.order_by(Events.time_fired)
 
 
 def _states_query_for_all(start_day: dt, end_day: dt) -> Query:
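The recurring pattern in this release: the logbook and recorder drop SQLAlchemy's `lambda_stmt` caching layer and compose plain `Select` objects instead. A minimal side-by-side sketch of the two styles (the table definition is illustrative):

```python
from sqlalchemy import Column, Integer, MetaData, Table, lambda_stmt, select

metadata = MetaData()
events = Table("events", metadata, Column("event_id", Integer))

# Old style: statements built incrementally from lambdas, cached by the
# lambda's code object; values captured in closures are easy to mis-cache.
old = lambda_stmt(lambda: select(events))
old += lambda s: s.where(events.c.event_id > 10)
old += lambda s: s.order_by(events.c.event_id)

# New style in this diff: ordinary Select objects chained directly,
# trading the statement cache for predictable parameter handling.
new = select(events).where(events.c.event_id > 10).order_by(events.c.event_id)
```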

View File

@@ -4,11 +4,10 @@ from __future__ import annotations
 from collections.abc import Iterable
 from datetime import datetime as dt
 
-from sqlalchemy import lambda_stmt, select
+from sqlalchemy import select
 from sqlalchemy.orm import Query
 from sqlalchemy.sql.elements import ClauseList
-from sqlalchemy.sql.lambdas import StatementLambdaElement
-from sqlalchemy.sql.selectable import CTE, CompoundSelect
+from sqlalchemy.sql.selectable import CTE, CompoundSelect, Select
 
 from homeassistant.components.recorder.models import (
     DEVICE_ID_IN_EVENT,
@@ -31,11 +30,11 @@ def _select_device_id_context_ids_sub_query(
     start_day: dt,
     end_day: dt,
     event_types: tuple[str, ...],
-    json_quotable_device_ids: list[str],
+    json_quoted_device_ids: list[str],
 ) -> CompoundSelect:
     """Generate a subquery to find context ids for multiple devices."""
     inner = select_events_context_id_subquery(start_day, end_day, event_types).where(
-        apply_event_device_id_matchers(json_quotable_device_ids)
+        apply_event_device_id_matchers(json_quoted_device_ids)
     )
     return select(inner.c.context_id).group_by(inner.c.context_id)
 
@@ -45,14 +44,14 @@ def _apply_devices_context_union(
     start_day: dt,
     end_day: dt,
     event_types: tuple[str, ...],
-    json_quotable_device_ids: list[str],
+    json_quoted_device_ids: list[str],
 ) -> CompoundSelect:
     """Generate a CTE to find the device context ids and a query to find linked row."""
     devices_cte: CTE = _select_device_id_context_ids_sub_query(
         start_day,
         end_day,
         event_types,
-        json_quotable_device_ids,
+        json_quoted_device_ids,
     ).cte()
     return query.union_all(
         apply_events_context_hints(
@@ -72,25 +71,22 @@ def devices_stmt(
     start_day: dt,
     end_day: dt,
     event_types: tuple[str, ...],
-    json_quotable_device_ids: list[str],
-) -> StatementLambdaElement:
+    json_quoted_device_ids: list[str],
+) -> Select:
     """Generate a logbook query for multiple devices."""
-    stmt = lambda_stmt(
-        lambda: _apply_devices_context_union(
-            select_events_without_states(start_day, end_day, event_types).where(
-                apply_event_device_id_matchers(json_quotable_device_ids)
-            ),
-            start_day,
-            end_day,
-            event_types,
-            json_quotable_device_ids,
-        ).order_by(Events.time_fired)
-    )
-    return stmt
+    return _apply_devices_context_union(
+        select_events_without_states(start_day, end_day, event_types).where(
+            apply_event_device_id_matchers(json_quoted_device_ids)
+        ),
+        start_day,
+        end_day,
+        event_types,
+        json_quoted_device_ids,
+    ).order_by(Events.time_fired)
 
 
 def apply_event_device_id_matchers(
-    json_quotable_device_ids: Iterable[str],
+    json_quoted_device_ids: Iterable[str],
 ) -> ClauseList:
     """Create matchers for the device_ids in the event_data."""
-    return DEVICE_ID_IN_EVENT.in_(json_quotable_device_ids)
+    return DEVICE_ID_IN_EVENT.in_(json_quoted_device_ids)

View File

@@ -5,10 +5,9 @@ from collections.abc import Iterable
 from datetime import datetime as dt
 
 import sqlalchemy
-from sqlalchemy import lambda_stmt, select, union_all
+from sqlalchemy import select, union_all
 from sqlalchemy.orm import Query
-from sqlalchemy.sql.lambdas import StatementLambdaElement
-from sqlalchemy.sql.selectable import CTE, CompoundSelect
+from sqlalchemy.sql.selectable import CTE, CompoundSelect, Select
 
 from homeassistant.components.recorder.models import (
     ENTITY_ID_IN_EVENT,
@@ -36,12 +35,12 @@ def _select_entities_context_ids_sub_query(
     end_day: dt,
     event_types: tuple[str, ...],
     entity_ids: list[str],
-    json_quotable_entity_ids: list[str],
+    json_quoted_entity_ids: list[str],
 ) -> CompoundSelect:
     """Generate a subquery to find context ids for multiple entities."""
     union = union_all(
         select_events_context_id_subquery(start_day, end_day, event_types).where(
-            apply_event_entity_id_matchers(json_quotable_entity_ids)
+            apply_event_entity_id_matchers(json_quoted_entity_ids)
         ),
         apply_entities_hints(select(States.context_id))
         .filter((States.last_updated > start_day) & (States.last_updated < end_day))
@@ -56,7 +55,7 @@ def _apply_entities_context_union(
     end_day: dt,
     event_types: tuple[str, ...],
     entity_ids: list[str],
-    json_quotable_entity_ids: list[str],
+    json_quoted_entity_ids: list[str],
 ) -> CompoundSelect:
     """Generate a CTE to find the entity and device context ids and a query to find linked row."""
     entities_cte: CTE = _select_entities_context_ids_sub_query(
@@ -64,7 +63,7 @@ def _apply_entities_context_union(
         end_day,
         event_types,
         entity_ids,
-        json_quotable_entity_ids,
+        json_quoted_entity_ids,
     ).cte()
     # We used to optimize this to exclude rows we already in the union with
     # a States.entity_id.not_in(entity_ids) but that made the
@@ -91,21 +90,19 @@ def entities_stmt(
     end_day: dt,
     event_types: tuple[str, ...],
     entity_ids: list[str],
-    json_quotable_entity_ids: list[str],
-) -> StatementLambdaElement:
+    json_quoted_entity_ids: list[str],
+) -> Select:
     """Generate a logbook query for multiple entities."""
-    return lambda_stmt(
-        lambda: _apply_entities_context_union(
-            select_events_without_states(start_day, end_day, event_types).where(
-                apply_event_entity_id_matchers(json_quotable_entity_ids)
-            ),
-            start_day,
-            end_day,
-            event_types,
-            entity_ids,
-            json_quotable_entity_ids,
-        ).order_by(Events.time_fired)
-    )
+    return _apply_entities_context_union(
+        select_events_without_states(start_day, end_day, event_types).where(
+            apply_event_entity_id_matchers(json_quoted_entity_ids)
+        ),
+        start_day,
+        end_day,
+        event_types,
+        entity_ids,
+        json_quoted_entity_ids,
+    ).order_by(Events.time_fired)
 
 
 def states_query_for_entity_ids(
@@ -118,12 +115,12 @@ def states_query_for_entity_ids(
 
 def apply_event_entity_id_matchers(
-    json_quotable_entity_ids: Iterable[str],
+    json_quoted_entity_ids: Iterable[str],
 ) -> sqlalchemy.or_:
     """Create matchers for the entity_id in the event_data."""
-    return ENTITY_ID_IN_EVENT.in_(
-        json_quotable_entity_ids
-    ) | OLD_ENTITY_ID_IN_EVENT.in_(json_quotable_entity_ids)
+    return ENTITY_ID_IN_EVENT.in_(json_quoted_entity_ids) | OLD_ENTITY_ID_IN_EVENT.in_(
+        json_quoted_entity_ids
+    )
 
 
 def apply_entities_hints(query: Query) -> Query:

View File

@@ -5,10 +5,9 @@ from collections.abc import Iterable
 from datetime import datetime as dt
 
 import sqlalchemy
-from sqlalchemy import lambda_stmt, select, union_all
+from sqlalchemy import select, union_all
 from sqlalchemy.orm import Query
-from sqlalchemy.sql.lambdas import StatementLambdaElement
-from sqlalchemy.sql.selectable import CTE, CompoundSelect
+from sqlalchemy.sql.selectable import CTE, CompoundSelect, Select
 
 from homeassistant.components.recorder.models import EventData, Events, States
@@ -33,14 +32,14 @@ def _select_entities_device_id_context_ids_sub_query(
     end_day: dt,
     event_types: tuple[str, ...],
     entity_ids: list[str],
-    json_quotable_entity_ids: list[str],
-    json_quotable_device_ids: list[str],
+    json_quoted_entity_ids: list[str],
+    json_quoted_device_ids: list[str],
 ) -> CompoundSelect:
     """Generate a subquery to find context ids for multiple entities and multiple devices."""
     union = union_all(
         select_events_context_id_subquery(start_day, end_day, event_types).where(
             _apply_event_entity_id_device_id_matchers(
-                json_quotable_entity_ids, json_quotable_device_ids
+                json_quoted_entity_ids, json_quoted_device_ids
             )
         ),
         apply_entities_hints(select(States.context_id))
@@ -56,16 +55,16 @@ def _apply_entities_devices_context_union(
     end_day: dt,
     event_types: tuple[str, ...],
     entity_ids: list[str],
-    json_quotable_entity_ids: list[str],
-    json_quotable_device_ids: list[str],
+    json_quoted_entity_ids: list[str],
+    json_quoted_device_ids: list[str],
 ) -> CompoundSelect:
     devices_entities_cte: CTE = _select_entities_device_id_context_ids_sub_query(
         start_day,
         end_day,
         event_types,
         entity_ids,
-        json_quotable_entity_ids,
-        json_quotable_device_ids,
+        json_quoted_entity_ids,
+        json_quoted_device_ids,
     ).cte()
     # We used to optimize this to exclude rows we already in the union with
     # a States.entity_id.not_in(entity_ids) but that made the
@@ -92,32 +91,30 @@ def entities_devices_stmt(
     end_day: dt,
     event_types: tuple[str, ...],
     entity_ids: list[str],
-    json_quotable_entity_ids: list[str],
-    json_quotable_device_ids: list[str],
-) -> StatementLambdaElement:
+    json_quoted_entity_ids: list[str],
+    json_quoted_device_ids: list[str],
+) -> Select:
     """Generate a logbook query for multiple entities."""
-    stmt = lambda_stmt(
-        lambda: _apply_entities_devices_context_union(
-            select_events_without_states(start_day, end_day, event_types).where(
-                _apply_event_entity_id_device_id_matchers(
-                    json_quotable_entity_ids, json_quotable_device_ids
-                )
-            ),
-            start_day,
-            end_day,
-            event_types,
-            entity_ids,
-            json_quotable_entity_ids,
-            json_quotable_device_ids,
-        ).order_by(Events.time_fired)
-    )
+    stmt = _apply_entities_devices_context_union(
+        select_events_without_states(start_day, end_day, event_types).where(
+            _apply_event_entity_id_device_id_matchers(
+                json_quoted_entity_ids, json_quoted_device_ids
+            )
+        ),
+        start_day,
+        end_day,
+        event_types,
+        entity_ids,
+        json_quoted_entity_ids,
+        json_quoted_device_ids,
+    ).order_by(Events.time_fired)
     return stmt
 
 
 def _apply_event_entity_id_device_id_matchers(
-    json_quotable_entity_ids: Iterable[str], json_quotable_device_ids: Iterable[str]
+    json_quoted_entity_ids: Iterable[str], json_quoted_device_ids: Iterable[str]
 ) -> sqlalchemy.or_:
     """Create matchers for the device_id and entity_id in the event_data."""
     return apply_event_entity_id_matchers(
-        json_quotable_entity_ids
-    ) | apply_event_device_id_matchers(json_quotable_device_ids)
+        json_quoted_entity_ids
+    ) | apply_event_device_id_matchers(json_quoted_device_ids)

View File

@@ -4,6 +4,8 @@ from __future__ import annotations
 from dataclasses import dataclass
 from datetime import datetime, timedelta
 
+from regenmaschine.controller import Controller
+
 from homeassistant.components.sensor import (
     SensorDeviceClass,
     SensorEntity,
@@ -13,8 +15,9 @@ from homeassistant.components.sensor import (
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.const import TEMP_CELSIUS, VOLUME_CUBIC_METERS
 from homeassistant.core import HomeAssistant, callback
-from homeassistant.helpers.entity import EntityCategory
+from homeassistant.helpers.entity import EntityCategory, EntityDescription
 from homeassistant.helpers.entity_platform import AddEntitiesCallback
+from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
 from homeassistant.util.dt import utcnow
 
 from . import RainMachineEntity
@@ -205,16 +208,33 @@ class ZoneTimeRemainingSensor(RainMachineEntity, SensorEntity):
 
     entity_description: RainMachineSensorDescriptionUid
 
+    def __init__(
+        self,
+        entry: ConfigEntry,
+        coordinator: DataUpdateCoordinator,
+        controller: Controller,
+        description: EntityDescription,
+    ) -> None:
+        """Initialize."""
+        super().__init__(entry, coordinator, controller, description)
+
+        self._running_or_queued: bool = False
+
     @callback
     def update_from_latest_data(self) -> None:
         """Update the state."""
         data = self.coordinator.data[self.entity_description.uid]
         now = utcnow()
 
-        if RUN_STATE_MAP.get(data["state"]) != RunStates.RUNNING:
-            # If the zone isn't actively running, return immediately:
+        if RUN_STATE_MAP.get(data["state"]) == RunStates.NOT_RUNNING:
+            if self._running_or_queued:
+                # If we go from running to not running, update the state to be right
+                # now (i.e., the time the zone stopped running):
+                self._attr_native_value = now
+            self._running_or_queued = False
             return
 
+        self._running_or_queued = True
         new_timestamp = now + timedelta(seconds=data["remaining"])
 
         if self._attr_native_value:
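The new `_running_or_queued` flag is edge detection: the sensor latches its value to the stop time only on the running-to-stopped transition, instead of rewriting it on every refresh while the zone sits idle. The pattern in isolation (illustrative class, not the integration's code):

```python
from datetime import datetime, timezone


class StopTimeLatch:
    """Illustrative sketch: latch a timestamp on the falling edge."""

    def __init__(self) -> None:
        self.running_or_queued = False
        self.stopped_at: datetime | None = None

    def update(self, is_running: bool) -> None:
        if not is_running:
            if self.running_or_queued:
                # Falling edge: record the moment the zone stopped.
                self.stopped_at = datetime.now(timezone.utc)
            self.running_or_queued = False
            return
        self.running_or_queued = True
```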

View File

@@ -9,13 +9,11 @@ import logging
 import time
 from typing import Any, cast
 
-from sqlalchemy import Column, Text, and_, func, lambda_stmt, or_, select
+from sqlalchemy import Column, Text, and_, func, or_, select
 from sqlalchemy.engine.row import Row
-from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.expression import literal
-from sqlalchemy.sql.lambdas import StatementLambdaElement
-from sqlalchemy.sql.selectable import Subquery
+from sqlalchemy.sql.selectable import Select, Subquery
 
 from homeassistant.components import recorder
 from homeassistant.components.websocket_api.const import (
@@ -36,7 +34,7 @@ from .models import (
     process_timestamp_to_utc_isoformat,
     row_to_compressed_state,
 )
-from .util import execute_stmt_lambda_element, session_scope
+from .util import execute_stmt, session_scope
 
 # mypy: allow-untyped-defs, no-check-untyped-defs
@@ -116,22 +114,18 @@ def _schema_version(hass: HomeAssistant) -> int:
     return recorder.get_instance(hass).schema_version
 
 
-def lambda_stmt_and_join_attributes(
+def stmt_and_join_attributes(
     schema_version: int, no_attributes: bool, include_last_changed: bool = True
-) -> tuple[StatementLambdaElement, bool]:
-    """Return the lambda_stmt and if StateAttributes should be joined.
-
-    Because these are lambda_stmt the values inside the lambdas need
-    to be explicitly written out to avoid caching the wrong values.
-    """
+) -> tuple[Select, bool]:
+    """Return the stmt and if StateAttributes should be joined."""
     # If no_attributes was requested we do the query
     # without the attributes fields and do not join the
     # state_attributes table
     if no_attributes:
         if include_last_changed:
-            return lambda_stmt(lambda: select(*QUERY_STATE_NO_ATTR)), False
+            return select(*QUERY_STATE_NO_ATTR), False
         return (
-            lambda_stmt(lambda: select(*QUERY_STATE_NO_ATTR_NO_LAST_CHANGED)),
+            select(*QUERY_STATE_NO_ATTR_NO_LAST_CHANGED),
             False,
         )
     # If we in the process of migrating schema we do
@@ -140,19 +134,19 @@ def lambda_stmt_and_join_attributes(
     if schema_version < 25:
         if include_last_changed:
             return (
-                lambda_stmt(lambda: select(*QUERY_STATES_PRE_SCHEMA_25)),
+                select(*QUERY_STATES_PRE_SCHEMA_25),
                 False,
             )
         return (
-            lambda_stmt(lambda: select(*QUERY_STATES_PRE_SCHEMA_25_NO_LAST_CHANGED)),
+            select(*QUERY_STATES_PRE_SCHEMA_25_NO_LAST_CHANGED),
             False,
         )
     # Finally if no migration is in progress and no_attributes
     # was not requested, we query both attributes columns and
     # join state_attributes
     if include_last_changed:
-        return lambda_stmt(lambda: select(*QUERY_STATES)), True
-    return lambda_stmt(lambda: select(*QUERY_STATES_NO_LAST_CHANGED)), True
+        return select(*QUERY_STATES), True
+    return select(*QUERY_STATES_NO_LAST_CHANGED), True
@@ -184,7 +178,7 @@ def get_significant_states(
     )
 
 
-def _ignore_domains_filter(query: Query) -> Query:
+def _ignore_domains_filter(query: Select) -> Select:
     """Add a filter to ignore domains we do not fetch history for."""
     return query.filter(
         and_(
@@ -204,9 +198,9 @@ def _significant_states_stmt(
     filters: Filters | None,
     significant_changes_only: bool,
     no_attributes: bool,
-) -> StatementLambdaElement:
+) -> Select:
     """Query the database for significant state changes."""
-    stmt, join_attributes = lambda_stmt_and_join_attributes(
+    stmt, join_attributes = stmt_and_join_attributes(
         schema_version, no_attributes, include_last_changed=not significant_changes_only
     )
     if (
@@ -215,11 +209,11 @@ def _significant_states_stmt(
         and significant_changes_only
         and split_entity_id(entity_ids[0])[0] not in SIGNIFICANT_DOMAINS
     ):
-        stmt += lambda q: q.filter(
+        stmt = stmt.filter(
            (States.last_changed == States.last_updated) | States.last_changed.is_(None)
        )
     elif significant_changes_only:
-        stmt += lambda q: q.filter(
+        stmt = stmt.filter(
             or_(
                 *[
                     States.entity_id.like(entity_domain)
@@ -233,25 +227,22 @@ def _significant_states_stmt(
     )
 
     if entity_ids:
-        stmt += lambda q: q.filter(States.entity_id.in_(entity_ids))
+        stmt = stmt.filter(States.entity_id.in_(entity_ids))
     else:
-        stmt += _ignore_domains_filter
+        stmt = _ignore_domains_filter(stmt)
         if filters and filters.has_config:
             entity_filter = filters.states_entity_filter()
-            stmt = stmt.add_criteria(
-                lambda q: q.filter(entity_filter), track_on=[filters]
-            )
+            stmt = stmt.filter(entity_filter)
 
-    stmt += lambda q: q.filter(States.last_updated > start_time)
+    stmt = stmt.filter(States.last_updated > start_time)
     if end_time:
-        stmt += lambda q: q.filter(States.last_updated < end_time)
+        stmt = stmt.filter(States.last_updated < end_time)
 
     if join_attributes:
-        stmt += lambda q: q.outerjoin(
+        stmt = stmt.outerjoin(
             StateAttributes, States.attributes_id == StateAttributes.attributes_id
         )
-    stmt += lambda q: q.order_by(States.entity_id, States.last_updated)
-    return stmt
+    return stmt.order_by(States.entity_id, States.last_updated)
 
 
 def get_significant_states_with_session(
@@ -288,9 +279,7 @@ def get_significant_states_with_session(
         significant_changes_only,
         no_attributes,
     )
-    states = execute_stmt_lambda_element(
-        session, stmt, None if entity_ids else start_time, end_time
-    )
+    states = execute_stmt(session, stmt, None if entity_ids else start_time, end_time)
     return _sorted_states_to_dict(
         hass,
         session,
@@ -342,28 +331,28 @@ def _state_changed_during_period_stmt(
     no_attributes: bool,
     descending: bool,
     limit: int | None,
-) -> StatementLambdaElement:
-    stmt, join_attributes = lambda_stmt_and_join_attributes(
+) -> Select:
+    stmt, join_attributes = stmt_and_join_attributes(
         schema_version, no_attributes, include_last_changed=False
     )
-    stmt += lambda q: q.filter(
+    stmt = stmt.filter(
         ((States.last_changed == States.last_updated) | States.last_changed.is_(None))
         & (States.last_updated > start_time)
     )
     if end_time:
-        stmt += lambda q: q.filter(States.last_updated < end_time)
+        stmt = stmt.filter(States.last_updated < end_time)
     if entity_id:
-        stmt += lambda q: q.filter(States.entity_id == entity_id)
+        stmt = stmt.filter(States.entity_id == entity_id)
     if join_attributes:
-        stmt += lambda q: q.outerjoin(
+        stmt = stmt.outerjoin(
             StateAttributes, States.attributes_id == StateAttributes.attributes_id
         )
     if descending:
-        stmt += lambda q: q.order_by(States.entity_id, States.last_updated.desc())
+        stmt = stmt.order_by(States.entity_id, States.last_updated.desc())
     else:
-        stmt += lambda q: q.order_by(States.entity_id, States.last_updated)
+        stmt = stmt.order_by(States.entity_id, States.last_updated)
     if limit:
-        stmt += lambda q: q.limit(limit)
+        stmt = stmt.limit(limit)
     return stmt
@@ -391,7 +380,7 @@ def state_changes_during_period(
             descending,
             limit,
         )
-        states = execute_stmt_lambda_element(
+        states = execute_stmt(
             session, stmt, None if entity_id else start_time, end_time
         )
         return cast(
@@ -409,23 +398,22 @@ def state_changes_during_period(
 
 def _get_last_state_changes_stmt(
     schema_version: int, number_of_states: int, entity_id: str | None
-) -> StatementLambdaElement:
-    stmt, join_attributes = lambda_stmt_and_join_attributes(
+) -> Select:
+    stmt, join_attributes = stmt_and_join_attributes(
         schema_version, False, include_last_changed=False
     )
-    stmt += lambda q: q.filter(
+    stmt = stmt.filter(
         (States.last_changed == States.last_updated) | States.last_changed.is_(None)
     )
     if entity_id:
-        stmt += lambda q: q.filter(States.entity_id == entity_id)
+        stmt = stmt.filter(States.entity_id == entity_id)
     if join_attributes:
-        stmt += lambda q: q.outerjoin(
+        stmt = stmt.outerjoin(
             StateAttributes, States.attributes_id == StateAttributes.attributes_id
         )
-    stmt += lambda q: q.order_by(States.entity_id, States.last_updated.desc()).limit(
+    return stmt.order_by(States.entity_id, States.last_updated.desc()).limit(
         number_of_states
     )
-    return stmt
@@ -440,7 +428,7 @@ def get_last_state_changes(
         stmt = _get_last_state_changes_stmt(
             _schema_version(hass), number_of_states, entity_id
         )
-        states = list(execute_stmt_lambda_element(session, stmt))
+        states = list(execute_stmt(session, stmt))
         return cast(
             MutableMapping[str, list[State]],
             _sorted_states_to_dict(
@@ -460,14 +448,14 @@ def _get_states_for_entites_stmt(
     utc_point_in_time: datetime,
     entity_ids: list[str],
     no_attributes: bool,
-) -> StatementLambdaElement:
+) -> Select:
     """Baked query to get states for specific entities."""
-    stmt, join_attributes = lambda_stmt_and_join_attributes(
+    stmt, join_attributes = stmt_and_join_attributes(
         schema_version, no_attributes, include_last_changed=True
     )
     # We got an include-list of entities, accelerate the query by filtering already
     # in the inner query.
-    stmt += lambda q: q.where(
+    stmt = stmt.where(
         States.state_id
         == (
             select(func.max(States.state_id).label("max_state_id"))
@@ -481,7 +469,7 @@ def _get_states_for_entites_stmt(
         ).c.max_state_id
     )
     if join_attributes:
-        stmt += lambda q: q.outerjoin(
+        stmt = stmt.outerjoin(
             StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
         )
     return stmt
@@ -512,9 +500,9 @@ def _get_states_for_all_stmt(
     utc_point_in_time: datetime,
     filters: Filters | None,
     no_attributes: bool,
-) -> StatementLambdaElement:
+) -> Select:
     """Baked query to get states for all entities."""
-    stmt, join_attributes = lambda_stmt_and_join_attributes(
+    stmt, join_attributes = stmt_and_join_attributes(
         schema_version, no_attributes, include_last_changed=True
     )
     # We did not get an include-list of entities, query all states in the inner
@@ -524,7 +512,7 @@ def _get_states_for_all_stmt(
     most_recent_states_by_date = _generate_most_recent_states_by_date(
         run_start, utc_point_in_time
     )
-    stmt += lambda q: q.where(
+    stmt = stmt.where(
         States.state_id
         == (
             select(func.max(States.state_id).label("max_state_id"))
@@ -540,12 +528,12 @@ def _get_states_for_all_stmt(
             .subquery()
         ).c.max_state_id,
     )
-    stmt += _ignore_domains_filter
+    stmt = _ignore_domains_filter(stmt)
     if filters and filters.has_config:
         entity_filter = filters.states_entity_filter()
-        stmt = stmt.add_criteria(lambda q: q.filter(entity_filter), track_on=[filters])
+        stmt = stmt.filter(entity_filter)
     if join_attributes:
-        stmt += lambda q: q.outerjoin(
+        stmt = stmt.outerjoin(
             StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
         )
     return stmt
@@ -563,7 +551,7 @@ def _get_rows_with_session(
     """Return the states at a specific point in time."""
     schema_version = _schema_version(hass)
     if entity_ids and len(entity_ids) == 1:
-        return execute_stmt_lambda_element(
+        return execute_stmt(
             session,
             _get_single_entity_states_stmt(
                 schema_version, utc_point_in_time, entity_ids[0], no_attributes
@@ -588,7 +576,7 @@ def _get_rows_with_session(
         schema_version, run.start, utc_point_in_time, filters, no_attributes
     )
-    return execute_stmt_lambda_element(session, stmt)
+    return execute_stmt(session, stmt)
@@ -596,14 +584,14 @@ def _get_single_entity_states_stmt(
     utc_point_in_time: datetime,
     entity_id: str,
     no_attributes: bool = False,
-) -> StatementLambdaElement:
+) -> Select:
     # Use an entirely different (and extremely fast) query if we only
     # have a single entity id
-    stmt, join_attributes = lambda_stmt_and_join_attributes(
+    stmt, join_attributes = stmt_and_join_attributes(
         schema_version, no_attributes, include_last_changed=True
     )
-    stmt += (
-        lambda q: q.filter(
+    stmt = (
+        stmt.filter(
             States.last_updated < utc_point_in_time,
             States.entity_id == entity_id,
         )
@@ -611,7 +599,7 @@ def _get_single_entity_states_stmt(
         .limit(1)
     )
     if join_attributes:
-        stmt += lambda q: q.outerjoin(
+        stmt = stmt.outerjoin(
             StateAttributes, States.attributes_id == StateAttributes.attributes_id
         )
     return stmt

View File

@@ -14,13 +14,12 @@ import re
 from statistics import mean
 from typing import TYPE_CHECKING, Any, Literal, overload
 
-from sqlalchemy import bindparam, func, lambda_stmt, select
+from sqlalchemy import bindparam, func, select
 from sqlalchemy.engine.row import Row
 from sqlalchemy.exc import SQLAlchemyError, StatementError
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.expression import literal_column, true
-from sqlalchemy.sql.lambdas import StatementLambdaElement
-from sqlalchemy.sql.selectable import Subquery
+from sqlalchemy.sql.selectable import Select, Subquery
 import voluptuous as vol
 
 from homeassistant.const import (
@@ -53,12 +52,7 @@ from .models import (
     process_timestamp,
     process_timestamp_to_utc_isoformat,
 )
-from .util import (
-    execute,
-    execute_stmt_lambda_element,
-    retryable_database_job,
-    session_scope,
-)
+from .util import execute, execute_stmt, retryable_database_job, session_scope
 
 if TYPE_CHECKING:
     from . import Recorder
@@ -483,10 +477,10 @@ def delete_statistics_meta_duplicates(session: Session) -> None:
 
 def _compile_hourly_statistics_summary_mean_stmt(
     start_time: datetime, end_time: datetime
-) -> StatementLambdaElement:
+) -> Select:
     """Generate the summary mean statement for hourly statistics."""
-    return lambda_stmt(
-        lambda: select(*QUERY_STATISTICS_SUMMARY_MEAN)
+    return (
+        select(*QUERY_STATISTICS_SUMMARY_MEAN)
         .filter(StatisticsShortTerm.start >= start_time)
         .filter(StatisticsShortTerm.start < end_time)
         .group_by(StatisticsShortTerm.metadata_id)
@@ -509,7 +503,7 @@ def compile_hourly_statistics(
     # Compute last hour's average, min, max
     summary: dict[str, StatisticData] = {}
     stmt = _compile_hourly_statistics_summary_mean_stmt(start_time, end_time)
-    stats = execute_stmt_lambda_element(session, stmt)
+    stats = execute_stmt(session, stmt)
 
     if stats:
         for stat in stats:
@@ -691,17 +685,17 @@ def _generate_get_metadata_stmt(
     statistic_ids: list[str] | tuple[str] | None = None,
     statistic_type: Literal["mean"] | Literal["sum"] | None = None,
     statistic_source: str | None = None,
-) -> StatementLambdaElement:
+) -> Select:
     """Generate a statement to fetch metadata."""
-    stmt = lambda_stmt(lambda: select(*QUERY_STATISTIC_META))
+    stmt = select(*QUERY_STATISTIC_META)
     if statistic_ids is not None:
-        stmt += lambda q: q.where(StatisticsMeta.statistic_id.in_(statistic_ids))
+        stmt = stmt.where(StatisticsMeta.statistic_id.in_(statistic_ids))
     if statistic_source is not None:
-        stmt += lambda q: q.where(StatisticsMeta.source == statistic_source)
+        stmt = stmt.where(StatisticsMeta.source == statistic_source)
     if statistic_type == "mean":
-        stmt += lambda q: q.where(StatisticsMeta.has_mean == true())
+        stmt = stmt.where(StatisticsMeta.has_mean == true())
     elif statistic_type == "sum":
-        stmt += lambda q: q.where(StatisticsMeta.has_sum == true())
+        stmt = stmt.where(StatisticsMeta.has_sum == true())
     return stmt
@@ -723,7 +717,7 @@ def get_metadata_with_session(
 
     # Fetch metatadata from the database
     stmt = _generate_get_metadata_stmt(statistic_ids, statistic_type, statistic_source)
-    result = execute_stmt_lambda_element(session, stmt)
+    result = execute_stmt(session, stmt)
     if not result:
         return {}
@@ -985,44 +979,30 @@ def _statistics_during_period_stmt(
     start_time: datetime,
     end_time: datetime | None,
     metadata_ids: list[int] | None,
-) -> StatementLambdaElement:
-    """Prepare a database query for statistics during a given period.
-
-    This prepares a lambda_stmt query, so we don't insert the parameters yet.
-    """
-    stmt = lambda_stmt(
-        lambda: select(*QUERY_STATISTICS).filter(Statistics.start >= start_time)
-    )
+) -> Select:
+    """Prepare a database query for statistics during a given period."""
+    stmt = select(*QUERY_STATISTICS).filter(Statistics.start >= start_time)
     if end_time is not None:
-        stmt += lambda q: q.filter(Statistics.start < end_time)
+        stmt = stmt.filter(Statistics.start < end_time)
     if metadata_ids:
-        stmt += lambda q: q.filter(Statistics.metadata_id.in_(metadata_ids))
+        stmt = stmt.filter(Statistics.metadata_id.in_(metadata_ids))
-    stmt += lambda q: q.order_by(Statistics.metadata_id, Statistics.start)
-    return stmt
+    return stmt.order_by(Statistics.metadata_id, Statistics.start)
 
 
 def _statistics_during_period_stmt_short_term(
     start_time: datetime,
     end_time: datetime | None,
     metadata_ids: list[int] | None,
-) -> StatementLambdaElement:
-    """Prepare a database query for short term statistics during a given period.
-
-    This prepares a lambda_stmt query, so we don't insert the parameters yet.
-    """
-    stmt = lambda_stmt(
-        lambda: select(*QUERY_STATISTICS_SHORT_TERM).filter(
-            StatisticsShortTerm.start >= start_time
-        )
-    )
+) -> Select:
+    """Prepare a database query for short term statistics during a given period."""
+    stmt = select(*QUERY_STATISTICS_SHORT_TERM).filter(
+        StatisticsShortTerm.start >= start_time
+    )
     if end_time is not None:
-        stmt += lambda q: q.filter(StatisticsShortTerm.start < end_time)
+        stmt = stmt.filter(StatisticsShortTerm.start < end_time)
     if metadata_ids:
-        stmt += lambda q: q.filter(StatisticsShortTerm.metadata_id.in_(metadata_ids))
+        stmt = stmt.filter(StatisticsShortTerm.metadata_id.in_(metadata_ids))
-    stmt += lambda q: q.order_by(
-        StatisticsShortTerm.metadata_id, StatisticsShortTerm.start
-    )
-    return stmt
+    return stmt.order_by(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start)
 
 
 def statistics_during_period(
@@ -1057,7 +1037,7 @@ def statistics_during_period(
         else:
             table = Statistics
             stmt = _statistics_during_period_stmt(start_time, end_time, metadata_ids)
-        stats = execute_stmt_lambda_element(session, stmt)
+        stats = execute_stmt(session, stmt)
         if not stats:
             return {}
@@ -1088,10 +1068,10 @@ def statistics_during_period(
 def _get_last_statistics_stmt(
     metadata_id: int,
     number_of_stats: int,
-) -> StatementLambdaElement:
+) -> Select:
     """Generate a statement for number_of_stats statistics for a given statistic_id."""
-    return lambda_stmt(
-        lambda: select(*QUERY_STATISTICS)
+    return (
+        select(*QUERY_STATISTICS)
         .filter_by(metadata_id=metadata_id)
         .order_by(Statistics.metadata_id, Statistics.start.desc())
         .limit(number_of_stats)
@@ -1101,10 +1081,10 @@ def _get_last_statistics_stmt(
 def _get_last_statistics_short_term_stmt(
     metadata_id: int,
     number_of_stats: int,
-) -> StatementLambdaElement:
+) -> Select:
     """Generate a statement for number_of_stats short term statistics for a given statistic_id."""
-    return lambda_stmt(
-        lambda: select(*QUERY_STATISTICS_SHORT_TERM)
+    return (
+        select(*QUERY_STATISTICS_SHORT_TERM)
        .filter_by(metadata_id=metadata_id)
        .order_by(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc())
        .limit(number_of_stats)
@@ -1130,7 +1110,7 @@ def _get_last_statistics(
             stmt = _get_last_statistics_stmt(metadata_id, number_of_stats)
         else:
             stmt = _get_last_statistics_short_term_stmt(metadata_id, number_of_stats)
-        stats = execute_stmt_lambda_element(session, stmt)
+        stats = execute_stmt(session, stmt)
         if not stats:
             return {}
@@ -1180,11 +1160,11 @@ def _generate_most_recent_statistic_row(metadata_ids: list[int]) -> Subquery:
 
 def _latest_short_term_statistics_stmt(
     metadata_ids: list[int],
-) -> StatementLambdaElement:
+) -> Select:
     """Create the statement for finding the latest short term stat rows."""
-    stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
+    stmt = select(*QUERY_STATISTICS_SHORT_TERM)
     most_recent_statistic_row = _generate_most_recent_statistic_row(metadata_ids)
-    stmt += lambda s: s.join(
+    return stmt.join(
         most_recent_statistic_row,
         (
             StatisticsShortTerm.metadata_id  # pylint: disable=comparison-with-callable
@@ -1192,7 +1172,6 @@ def _latest_short_term_statistics_stmt(
         )
         & (StatisticsShortTerm.start == most_recent_statistic_row.c.start_max),
     )
-    return stmt
 
 
 def get_latest_short_term_statistics(
@@ -1215,7 +1194,7 @@ def get_latest_short_term_statistics(
             if statistic_id in metadata
         ]
         stmt = _latest_short_term_statistics_stmt(metadata_ids)
-        stats = execute_stmt_lambda_element(session, stmt)
+        stats = execute_stmt(session, stmt)
         if not stats:
             return {}

View File

@@ -22,7 +22,6 @@ from sqlalchemy.engine.row import Row
 from sqlalchemy.exc import OperationalError, SQLAlchemyError
 from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.session import Session
-from sqlalchemy.sql.lambdas import StatementLambdaElement
 from typing_extensions import Concatenate, ParamSpec
 
 from homeassistant.core import HomeAssistant
@@ -167,9 +166,9 @@ def execute(
     assert False  # unreachable # pragma: no cover
 
 
-def execute_stmt_lambda_element(
+def execute_stmt(
     session: Session,
-    stmt: StatementLambdaElement,
+    query: Query,
     start_time: datetime | None = None,
     end_time: datetime | None = None,
     yield_per: int | None = DEFAULT_YIELD_STATES_ROWS,
@@ -185,11 +184,12 @@ def execute_stmt_lambda_element(
     specific entities) since they are usually faster
     with .all().
     """
-    executed = session.execute(stmt)
     use_all = not start_time or ((end_time or dt_util.utcnow()) - start_time).days <= 1
     for tryno in range(0, RETRIES):
         try:
-            return executed.all() if use_all else executed.yield_per(yield_per)  # type: ignore[no-any-return]
+            if use_all:
+                return session.execute(query).all()  # type: ignore[no-any-return]
+            return session.execute(query).yield_per(yield_per)  # type: ignore[no-any-return]
         except SQLAlchemyError as err:
             _LOGGER.error("Error executing query: %s", err)
             if tryno == RETRIES - 1:
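The substantive fix here: `session.execute()` moves inside the retry loop, so a transient `SQLAlchemyError` re-issues the query instead of retrying against an already-consumed result. The shape of the pattern (constants assumed for illustration):

```python
import logging
import time

from sqlalchemy.exc import SQLAlchemyError

_LOGGER = logging.getLogger(__name__)
RETRIES = 3
QUERY_RETRY_WAIT = 0.1  # illustrative; the recorder defines its own constant


def execute_with_retry(session, query):
    """Sketch: every retry issues a fresh database round trip."""
    for tryno in range(RETRIES):
        try:
            # Executing inside the loop is the point of the fix above.
            return session.execute(query).all()
        except SQLAlchemyError as err:
            _LOGGER.error("Error executing query: %s", err)
            if tryno == RETRIES - 1:
                raise
            time.sleep(QUERY_RETRY_WAIT)
```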

View File

@@ -7,7 +7,7 @@
     "samsungctl[websocket]==0.7.1",
     "samsungtvws[async,encrypted]==2.5.0",
     "wakeonlan==2.0.1",
-    "async-upnp-client==0.30.1"
+    "async-upnp-client==0.31.1"
   ],
   "ssdp": [
     {

View File

@@ -751,17 +751,23 @@ class SonosMediaPlayerEntity(SonosEntity, MediaPlayerEntity):
             media_content_type,
         )

-    def join_players(self, group_members):
+    async def async_join_players(self, group_members):
         """Join `group_members` as a player group with the current player."""
-        speakers = []
-        for entity_id in group_members:
-            if speaker := self.hass.data[DATA_SONOS].entity_id_mappings.get(entity_id):
-                speakers.append(speaker)
-            else:
-                raise HomeAssistantError(f"Not a known Sonos entity_id: {entity_id}")
+        async with self.hass.data[DATA_SONOS].topology_condition:
+            speakers = []
+            for entity_id in group_members:
+                if speaker := self.hass.data[DATA_SONOS].entity_id_mappings.get(
+                    entity_id
+                ):
+                    speakers.append(speaker)
+                else:
+                    raise HomeAssistantError(
+                        f"Not a known Sonos entity_id: {entity_id}"
+                    )

-        self.speaker.join(speakers)
+            await self.hass.async_add_executor_job(self.speaker.join, speakers)

-    def unjoin_player(self):
+    async def async_unjoin_player(self):
         """Remove this player from any group."""
-        self.speaker.unjoin()
+        async with self.hass.data[DATA_SONOS].topology_condition:
+            await self.hass.async_add_executor_job(self.speaker.unjoin)
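
The change swaps the blocking join/unjoin service handlers for async wrappers: the topology condition serializes concurrent group changes, and async_add_executor_job pushes the blocking soco call into a thread pool so it never runs on the event loop. A minimal, self-contained sketch of that pattern — the condition object and blocking_join below are illustrative stand-ins, not the Sonos integration's own names:

import asyncio


def blocking_join(speakers):
    """Stand-in for a blocking library call such as speaker.join()."""
    return f"joined {len(speakers)} speakers"


async def async_join(topology_condition, speakers):
    # Hold the condition so concurrent topology changes are serialized,
    # then run the blocking call in the default thread-pool executor --
    # the same thing hass.async_add_executor_job does for integrations.
    async with topology_condition:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, blocking_join, speakers)


async def main():
    condition = asyncio.Condition()
    print(await async_join(condition, ["bedroom", "kitchen"]))


asyncio.run(main())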

View File

@@ -2,7 +2,7 @@
   "domain": "ssdp",
   "name": "Simple Service Discovery Protocol (SSDP)",
   "documentation": "https://www.home-assistant.io/integrations/ssdp",
-  "requirements": ["async-upnp-client==0.30.1"],
+  "requirements": ["async-upnp-client==0.31.1"],
   "dependencies": ["network"],
   "after_dependencies": ["zeroconf"],
   "codeowners": [],

View File

@@ -3,7 +3,7 @@
   "name": "UPnP/IGD",
   "config_flow": true,
   "documentation": "https://www.home-assistant.io/integrations/upnp",
-  "requirements": ["async-upnp-client==0.30.1", "getmac==0.8.2"],
+  "requirements": ["async-upnp-client==0.31.1", "getmac==0.8.2"],
   "dependencies": ["network", "ssdp"],
   "codeowners": ["@StevenLooman", "@ehendrix23"],
   "ssdp": [

View File

@@ -3,7 +3,7 @@
   "name": "Belkin WeMo",
   "config_flow": true,
   "documentation": "https://www.home-assistant.io/integrations/wemo",
-  "requirements": ["pywemo==0.8.1"],
+  "requirements": ["pywemo==0.9.1"],
   "ssdp": [
     {
       "manufacturer": "Belkin International Inc."

View File

@@ -2,7 +2,7 @@
   "domain": "yeelight",
   "name": "Yeelight",
   "documentation": "https://www.home-assistant.io/integrations/yeelight",
-  "requirements": ["yeelight==0.7.10", "async-upnp-client==0.30.1"],
+  "requirements": ["yeelight==0.7.10", "async-upnp-client==0.31.1"],
   "codeowners": ["@zewelor", "@shenxn", "@starkillerOG", "@alexyao2015"],
   "config_flow": true,
   "dependencies": ["network"],

View File

@@ -7,7 +7,7 @@ from .backports.enum import StrEnum

 MAJOR_VERSION: Final = 2022
 MINOR_VERSION: Final = 6
-PATCH_VERSION: Final = "3"
+PATCH_VERSION: Final = "4"
 __short_version__: Final = f"{MAJOR_VERSION}.{MINOR_VERSION}"
 __version__: Final = f"{__short_version__}.{PATCH_VERSION}"
 REQUIRED_PYTHON_VER: Final[tuple[int, int, int]] = (3, 9, 0)

View File

@@ -4,7 +4,7 @@ aiodiscover==1.4.11
 aiohttp==3.8.1
 aiohttp_cors==0.7.0
 astral==2.2
-async-upnp-client==0.30.1
+async-upnp-client==0.31.1
 async_timeout==4.0.2
 atomicwrites==1.4.0
 attrs==21.2.0

View File

@@ -336,7 +336,7 @@ asterisk_mbox==0.5.0
 # homeassistant.components.ssdp
 # homeassistant.components.upnp
 # homeassistant.components.yeelight
-async-upnp-client==0.30.1
+async-upnp-client==0.31.1

 # homeassistant.components.supla
 asyncpysupla==0.0.5
@@ -2023,7 +2023,7 @@ pyvolumio==0.1.5
 pywebpush==1.9.2

 # homeassistant.components.wemo
-pywemo==0.8.1
+pywemo==0.9.1

 # homeassistant.components.wilight
 pywilight==0.0.70

View File

@@ -278,7 +278,7 @@ arcam-fmj==0.12.0
 # homeassistant.components.ssdp
 # homeassistant.components.upnp
 # homeassistant.components.yeelight
-async-upnp-client==0.30.1
+async-upnp-client==0.31.1

 # homeassistant.components.sleepiq
 asyncsleepiq==1.2.3
@@ -1343,7 +1343,7 @@ pyvolumio==0.1.5
 pywebpush==1.9.2

 # homeassistant.components.wemo
-pywemo==0.8.1
+pywemo==0.9.1

 # homeassistant.components.wilight
 pywilight==0.0.70

View File

@@ -1,5 +1,5 @@
 [metadata]
-version = 2022.6.3
+version = 2022.6.4
 url = https://www.home-assistant.io/

 [options]

View File

@@ -1,4 +1,4 @@
-"""Common stuff for AVM Fritz!Box tests."""
+"""Common stuff for Fritz!Tools tests."""
 import logging
 from unittest.mock import MagicMock, patch

@@ -73,13 +73,19 @@ class FritzHostMock(FritzHosts):
         return MOCK_MESH_DATA


+@pytest.fixture(name="fc_data")
+def fc_data_mock():
+    """Fixture for default fc_data."""
+    return MOCK_FB_SERVICES
+
+
 @pytest.fixture()
-def fc_class_mock():
+def fc_class_mock(fc_data):
     """Fixture that sets up a mocked FritzConnection class."""
     with patch(
         "homeassistant.components.fritz.common.FritzConnection", autospec=True
     ) as result:
-        result.return_value = FritzConnectionMock(MOCK_FB_SERVICES)
+        result.return_value = FritzConnectionMock(fc_data)
         yield result
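
The new fc_data fixture exists so tests can swap the mocked FritzConnection payload without re-patching anything: pytest lets @pytest.mark.parametrize override a fixture by name, even when the fixture is only requested indirectly by another fixture. A minimal, self-contained sketch of that mechanism (all names below are illustrative, not from this tree):

import pytest


@pytest.fixture(name="fc_data")
def fc_data_mock():
    """Default payload, used when a test does not parametrize fc_data."""
    return {"source": "default"}


@pytest.fixture()
def fc_class_mock(fc_data):
    """Build the mocked connection from whatever fc_data resolves to."""
    return {"services": fc_data}


@pytest.mark.parametrize("fc_data", [{"source": "override"}])
def test_parametrized(fc_class_mock):
    # The parametrize value shadows the fc_data fixture above.
    assert fc_class_mock["services"]["source"] == "override"


def test_default(fc_class_mock):
    assert fc_class_mock["services"]["source"] == "default"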

View File

@@ -1,4 +1,4 @@
-"""Common stuff for AVM Fritz!Box tests."""
+"""Common stuff for Fritz!Tools tests."""
 from homeassistant.components import ssdp
 from homeassistant.components.fritz.const import DOMAIN
 from homeassistant.components.ssdp import ATTR_UPNP_FRIENDLY_NAME, ATTR_UPNP_UDN
@@ -194,6 +194,7 @@ MOCK_FB_SERVICES: dict[str, dict] = {
     },
 }

+
 MOCK_MESH_DATA = {
     "schema_version": "1.9",
     "nodes": [

View File

@@ -1,4 +1,4 @@
-"""Tests for Shelly button platform."""
+"""Tests for Fritz!Tools button platform."""
 from unittest.mock import patch

 import pytest

View File

@@ -1,4 +1,4 @@
-"""Tests for AVM Fritz!Box config flow."""
+"""Tests for Fritz!Tools config flow."""
 import dataclasses
 from unittest.mock import patch


View File

@@ -1,4 +1,4 @@
-"""Tests for the AVM Fritz!Box integration."""
+"""Tests for Fritz!Tools diagnostics platform."""
 from __future__ import annotations

 from aiohttp import ClientSession

View File

@@ -1,4 +1,4 @@
-"""Tests for AVM Fritz!Box."""
+"""Tests for Fritz!Tools."""
 from unittest.mock import patch

 from fritzconnection.core.exceptions import FritzSecurityError

View File

@@ -1,4 +1,4 @@
-"""Tests for Shelly button platform."""
+"""Tests for Fritz!Tools sensor platform."""
 from __future__ import annotations

 from datetime import timedelta

View File

@@ -0,0 +1,189 @@
"""Tests for Fritz!Tools switch platform."""
from __future__ import annotations
import pytest
from homeassistant.components.fritz.const import DOMAIN
from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from .const import MOCK_FB_SERVICES, MOCK_USER_DATA
from tests.common import MockConfigEntry
MOCK_WLANCONFIGS_SAME_SSID: dict[str, dict] = {
"WLANConfiguration1": {
"GetInfo": {
"NewEnable": True,
"NewStatus": "Up",
"NewMaxBitRate": "Auto",
"NewChannel": 13,
"NewSSID": "WiFi",
"NewBeaconType": "11iandWPA3",
"NewX_AVM-DE_PossibleBeaconTypes": "None,11i,11iandWPA3",
"NewMACAddressControlEnabled": False,
"NewStandard": "ax",
"NewBSSID": "1C:ED:6F:12:34:12",
"NewBasicEncryptionModes": "None",
"NewBasicAuthenticationMode": "None",
"NewMaxCharsSSID": 32,
"NewMinCharsSSID": 1,
"NewAllowedCharsSSID": "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz !\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~",
"NewMinCharsPSK": 64,
"NewMaxCharsPSK": 64,
"NewAllowedCharsPSK": "0123456789ABCDEFabcdef",
}
},
"WLANConfiguration2": {
"GetInfo": {
"NewEnable": True,
"NewStatus": "Up",
"NewMaxBitRate": "Auto",
"NewChannel": 52,
"NewSSID": "WiFi",
"NewBeaconType": "11iandWPA3",
"NewX_AVM-DE_PossibleBeaconTypes": "None,11i,11iandWPA3",
"NewMACAddressControlEnabled": False,
"NewStandard": "ax",
"NewBSSID": "1C:ED:6F:12:34:13",
"NewBasicEncryptionModes": "None",
"NewBasicAuthenticationMode": "None",
"NewMaxCharsSSID": 32,
"NewMinCharsSSID": 1,
"NewAllowedCharsSSID": "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz !\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~",
"NewMinCharsPSK": 64,
"NewMaxCharsPSK": 64,
"NewAllowedCharsPSK": "0123456789ABCDEFabcdef",
}
},
}
MOCK_WLANCONFIGS_DIFF_SSID: dict[str, dict] = {
"WLANConfiguration1": {
"GetInfo": {
"NewEnable": True,
"NewStatus": "Up",
"NewMaxBitRate": "Auto",
"NewChannel": 13,
"NewSSID": "WiFi",
"NewBeaconType": "11iandWPA3",
"NewX_AVM-DE_PossibleBeaconTypes": "None,11i,11iandWPA3",
"NewMACAddressControlEnabled": False,
"NewStandard": "ax",
"NewBSSID": "1C:ED:6F:12:34:12",
"NewBasicEncryptionModes": "None",
"NewBasicAuthenticationMode": "None",
"NewMaxCharsSSID": 32,
"NewMinCharsSSID": 1,
"NewAllowedCharsSSID": "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz !\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~",
"NewMinCharsPSK": 64,
"NewMaxCharsPSK": 64,
"NewAllowedCharsPSK": "0123456789ABCDEFabcdef",
}
},
"WLANConfiguration2": {
"GetInfo": {
"NewEnable": True,
"NewStatus": "Up",
"NewMaxBitRate": "Auto",
"NewChannel": 52,
"NewSSID": "WiFi2",
"NewBeaconType": "11iandWPA3",
"NewX_AVM-DE_PossibleBeaconTypes": "None,11i,11iandWPA3",
"NewMACAddressControlEnabled": False,
"NewStandard": "ax",
"NewBSSID": "1C:ED:6F:12:34:13",
"NewBasicEncryptionModes": "None",
"NewBasicAuthenticationMode": "None",
"NewMaxCharsSSID": 32,
"NewMinCharsSSID": 1,
"NewAllowedCharsSSID": "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz !\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~",
"NewMinCharsPSK": 64,
"NewMaxCharsPSK": 64,
"NewAllowedCharsPSK": "0123456789ABCDEFabcdef",
}
},
}
MOCK_WLANCONFIGS_DIFF2_SSID: dict[str, dict] = {
"WLANConfiguration1": {
"GetInfo": {
"NewEnable": True,
"NewStatus": "Up",
"NewMaxBitRate": "Auto",
"NewChannel": 13,
"NewSSID": "WiFi",
"NewBeaconType": "11iandWPA3",
"NewX_AVM-DE_PossibleBeaconTypes": "None,11i,11iandWPA3",
"NewMACAddressControlEnabled": False,
"NewStandard": "ax",
"NewBSSID": "1C:ED:6F:12:34:12",
"NewBasicEncryptionModes": "None",
"NewBasicAuthenticationMode": "None",
"NewMaxCharsSSID": 32,
"NewMinCharsSSID": 1,
"NewAllowedCharsSSID": "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz !\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~",
"NewMinCharsPSK": 64,
"NewMaxCharsPSK": 64,
"NewAllowedCharsPSK": "0123456789ABCDEFabcdef",
}
},
"WLANConfiguration2": {
"GetInfo": {
"NewEnable": True,
"NewStatus": "Up",
"NewMaxBitRate": "Auto",
"NewChannel": 52,
"NewSSID": "WiFi+",
"NewBeaconType": "11iandWPA3",
"NewX_AVM-DE_PossibleBeaconTypes": "None,11i,11iandWPA3",
"NewMACAddressControlEnabled": False,
"NewStandard": "ax",
"NewBSSID": "1C:ED:6F:12:34:13",
"NewBasicEncryptionModes": "None",
"NewBasicAuthenticationMode": "None",
"NewMaxCharsSSID": 32,
"NewMinCharsSSID": 1,
"NewAllowedCharsSSID": "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz !\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~",
"NewMinCharsPSK": 64,
"NewMaxCharsPSK": 64,
"NewAllowedCharsPSK": "0123456789ABCDEFabcdef",
}
},
}
@pytest.mark.parametrize(
"fc_data, expected_wifi_names",
[
(
{**MOCK_FB_SERVICES, **MOCK_WLANCONFIGS_SAME_SSID},
["WiFi (2.4Ghz)", "WiFi (5Ghz)"],
),
({**MOCK_FB_SERVICES, **MOCK_WLANCONFIGS_DIFF_SSID}, ["WiFi", "WiFi2"]),
(
{**MOCK_FB_SERVICES, **MOCK_WLANCONFIGS_DIFF2_SSID},
["WiFi (2.4Ghz)", "WiFi+ (5Ghz)"],
),
],
)
async def test_switch_setup(
hass: HomeAssistant,
expected_wifi_names: list[str],
fc_class_mock,
fh_class_mock,
):
"""Test setup of Fritz!Tools switches."""
entry = MockConfigEntry(domain=DOMAIN, data=MOCK_USER_DATA)
entry.add_to_hass(hass)
assert await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done()
assert entry.state == ConfigEntryState.LOADED
switches = hass.states.async_all(Platform.SWITCH)
assert len(switches) == 3
assert switches[0].name == f"Mock Title Wi-Fi {expected_wifi_names[0]}"
assert switches[1].name == f"Mock Title Wi-Fi {expected_wifi_names[1]}"
assert switches[2].name == "printer Internet Access"
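
Note how the fc_data entries in the parametrize above never appear in the test signature: they are consumed by the fc_class_mock fixture from conftest.py (shown earlier), which hands them to the mocked FritzConnection, so each case boots the integration against a different set of WLAN configurations.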

View File

@@ -1,4 +1,4 @@
-"""The tests for the Fritzbox update entity."""
+"""Tests for Fritz!Tools update platform."""
 from unittest.mock import patch


View File

@@ -9,7 +9,6 @@ from sqlalchemy import text
 from sqlalchemy.engine.result import ChunkedIteratorResult
 from sqlalchemy.exc import SQLAlchemyError
 from sqlalchemy.sql.elements import TextClause
-from sqlalchemy.sql.lambdas import StatementLambdaElement

 from homeassistant.components import recorder
 from homeassistant.components.recorder import history, util
@@ -712,8 +711,8 @@ def test_build_mysqldb_conv():


 @patch("homeassistant.components.recorder.util.QUERY_RETRY_WAIT", 0)
-def test_execute_stmt_lambda_element(hass_recorder):
-    """Test executing with execute_stmt_lambda_element."""
+def test_execute_stmt(hass_recorder):
+    """Test executing with execute_stmt."""
     hass = hass_recorder()
     instance = recorder.get_instance(hass)
     hass.states.set("sensor.on", "on")
@@ -724,13 +723,15 @@ def test_execute_stmt_lambda_element(hass_recorder):
     one_week_from_now = now + timedelta(days=7)

     class MockExecutor:
+        _calls = 0
+
         def __init__(self, stmt):
-            assert isinstance(stmt, StatementLambdaElement)
-            self.calls = 0
+            """Init the mock."""

         def all(self):
-            self.calls += 1
-            if self.calls == 2:
+            MockExecutor._calls += 1
+            if MockExecutor._calls == 2:
                 return ["mock_row"]
             raise SQLAlchemyError
@@ -739,24 +740,24 @@ def test_execute_stmt(hass_recorder):
     stmt = history._get_single_entity_states_stmt(
         instance.schema_version, dt_util.utcnow(), "sensor.on", False
     )
-    rows = util.execute_stmt_lambda_element(session, stmt)
+    rows = util.execute_stmt(session, stmt)
     assert isinstance(rows, list)
     assert rows[0].state == new_state.state
     assert rows[0].entity_id == new_state.entity_id

     # Time window >= 2 days, we get a ChunkedIteratorResult
-    rows = util.execute_stmt_lambda_element(session, stmt, now, one_week_from_now)
+    rows = util.execute_stmt(session, stmt, now, one_week_from_now)
     assert isinstance(rows, ChunkedIteratorResult)
     row = next(rows)
     assert row.state == new_state.state
     assert row.entity_id == new_state.entity_id

     # Time window < 2 days, we get a list
-    rows = util.execute_stmt_lambda_element(session, stmt, now, tomorrow)
+    rows = util.execute_stmt(session, stmt, now, tomorrow)
     assert isinstance(rows, list)
     assert rows[0].state == new_state.state
     assert rows[0].entity_id == new_state.entity_id

     with patch.object(session, "execute", MockExecutor):
-        rows = util.execute_stmt_lambda_element(session, stmt, now, tomorrow)
+        rows = util.execute_stmt(session, stmt, now, tomorrow)
         assert rows == ["mock_row"]
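
The move from self.calls to the class attribute MockExecutor._calls is what makes this assertion work: patch.object(session, "execute", MockExecutor) replaces the method with the class itself, so every retried session.execute(...) constructs a brand-new instance, and a per-instance counter would always restart at zero. A small stand-alone sketch of that behavior (names are illustrative):

class FlakyExecutor:
    # Class-level counter: shared across the fresh instance that each
    # patched execute() call creates, unlike per-instance state.
    _calls = 0

    def __init__(self, stmt):
        """Each call to the patched execute() builds a new instance."""

    def all(self):
        FlakyExecutor._calls += 1
        if FlakyExecutor._calls == 2:
            return ["mock_row"]
        raise RuntimeError("transient failure")


def run_with_retry(execute, stmt, retries=3):
    """Retry loop mirroring the shape of the recorder helper under test."""
    for tryno in range(retries):
        try:
            return execute(stmt).all()
        except RuntimeError:
            if tryno == retries - 1:
                raise


# First attempt raises, second attempt succeeds via the shared counter.
assert run_with_retry(FlakyExecutor, "SELECT 1") == ["mock_row"]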