Allow fetching multiple statistics (#51996)

This commit is contained in:
Paulus Schoutsen 2021-06-18 12:03:13 -07:00 committed by GitHub
parent 87a43eacb7
commit 805ef3f90b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 42 additions and 47 deletions

View File

@@ -119,7 +119,7 @@ class LazyState(history_models.LazyState):
vol.Required("type"): "history/statistics_during_period", vol.Required("type"): "history/statistics_during_period",
vol.Required("start_time"): str, vol.Required("start_time"): str,
vol.Optional("end_time"): str, vol.Optional("end_time"): str,
vol.Optional("statistic_id"): str, vol.Optional("statistic_ids"): [str],
} }
) )
@websocket_api.async_response @websocket_api.async_response
@@ -152,9 +152,9 @@ async def ws_get_statistics_during_period(
hass, hass,
start_time, start_time,
end_time, end_time,
msg.get("statistic_id"), msg.get("statistic_ids"),
) )
connection.send_result(msg["id"], {"statistics": statistics}) connection.send_result(msg["id"], statistics)
class HistoryPeriodView(HomeAssistantView): class HistoryPeriodView(HomeAssistantView):

View File

@@ -74,7 +74,7 @@ def compile_statistics(instance: Recorder, start: datetime.datetime) -> bool:
return True return True
def statistics_during_period(hass, start_time, end_time=None, statistic_id=None): def statistics_during_period(hass, start_time, end_time=None, statistic_ids=None):
"""Return states changes during UTC period start_time - end_time.""" """Return states changes during UTC period start_time - end_time."""
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
baked_query = hass.data[STATISTICS_BAKERY]( baked_query = hass.data[STATISTICS_BAKERY](
@@ -86,20 +86,20 @@ def statistics_during_period(hass, start_time, end_time=None, statistic_id=None)
if end_time is not None: if end_time is not None:
baked_query += lambda q: q.filter(Statistics.start < bindparam("end_time")) baked_query += lambda q: q.filter(Statistics.start < bindparam("end_time"))
if statistic_id is not None: if statistic_ids is not None:
baked_query += lambda q: q.filter_by(statistic_id=bindparam("statistic_id")) baked_query += lambda q: q.filter(
statistic_id = statistic_id.lower() Statistics.statistic_id.in_(bindparam("statistic_ids"))
)
statistic_ids = [statistic_id.lower() for statistic_id in statistic_ids]
baked_query += lambda q: q.order_by(Statistics.statistic_id, Statistics.start) baked_query += lambda q: q.order_by(Statistics.statistic_id, Statistics.start)
stats = execute( stats = execute(
baked_query(session).params( baked_query(session).params(
start_time=start_time, end_time=end_time, statistic_id=statistic_id start_time=start_time, end_time=end_time, statistic_ids=statistic_ids
) )
) )
statistic_ids = [statistic_id] if statistic_id is not None else None
return _sorted_statistics_to_dict(stats, statistic_ids) return _sorted_statistics_to_dict(stats, statistic_ids)

View File

@@ -834,11 +834,7 @@ async def test_statistics_during_period(hass, hass_ws_client):
now = dt_util.utcnow() now = dt_util.utcnow()
await hass.async_add_executor_job(init_recorder_component, hass) await hass.async_add_executor_job(init_recorder_component, hass)
await async_setup_component( await async_setup_component(hass, "history", {})
hass,
"history",
{"history": {}},
)
await async_setup_component(hass, "sensor", {}) await async_setup_component(hass, "sensor", {})
await hass.async_add_executor_job(hass.data[recorder.DATA_INSTANCE].block_till_done) await hass.async_add_executor_job(hass.data[recorder.DATA_INSTANCE].block_till_done)
hass.states.async_set( hass.states.async_set(
@@ -861,12 +857,12 @@ async def test_statistics_during_period(hass, hass_ws_client):
"type": "history/statistics_during_period", "type": "history/statistics_during_period",
"start_time": now.isoformat(), "start_time": now.isoformat(),
"end_time": now.isoformat(), "end_time": now.isoformat(),
"statistic_id": "sensor.test", "statistic_ids": ["sensor.test"],
} }
) )
response = await client.receive_json() response = await client.receive_json()
assert response["success"] assert response["success"]
assert response["result"] == {"statistics": {}} assert response["result"] == {}
client = await hass_ws_client() client = await hass_ws_client()
await client.send_json( await client.send_json(
@@ -874,26 +870,24 @@ async def test_statistics_during_period(hass, hass_ws_client):
"id": 1, "id": 1,
"type": "history/statistics_during_period", "type": "history/statistics_during_period",
"start_time": now.isoformat(), "start_time": now.isoformat(),
"statistic_id": "sensor.test", "statistic_ids": ["sensor.test"],
} }
) )
response = await client.receive_json() response = await client.receive_json()
assert response["success"] assert response["success"]
assert response["result"] == { assert response["result"] == {
"statistics": { "sensor.test": [
"sensor.test": [ {
{ "statistic_id": "sensor.test",
"statistic_id": "sensor.test", "start": now.isoformat(),
"start": now.isoformat(), "mean": approx(10.0),
"mean": approx(10.0), "min": approx(10.0),
"min": approx(10.0), "max": approx(10.0),
"max": approx(10.0), "last_reset": None,
"last_reset": None, "state": None,
"state": None, "sum": None,
"sum": None, }
} ]
]
}
} }

View File

@@ -26,21 +26,22 @@ def test_compile_hourly_statistics(hass_recorder):
recorder.do_adhoc_statistics(period="hourly", start=zero) recorder.do_adhoc_statistics(period="hourly", start=zero)
wait_recording_done(hass) wait_recording_done(hass)
stats = statistics_during_period(hass, zero) for kwargs in ({}, {"statistic_ids": ["sensor.test1"]}):
assert stats == { stats = statistics_during_period(hass, zero, **kwargs)
"sensor.test1": [ assert stats == {
{ "sensor.test1": [
"statistic_id": "sensor.test1", {
"start": process_timestamp_to_utc_isoformat(zero), "statistic_id": "sensor.test1",
"mean": approx(14.915254237288135), "start": process_timestamp_to_utc_isoformat(zero),
"min": approx(10.0), "mean": approx(14.915254237288135),
"max": approx(20.0), "min": approx(10.0),
"last_reset": None, "max": approx(20.0),
"state": None, "last_reset": None,
"sum": None, "state": None,
} "sum": None,
] }
} ]
}
def record_states(hass): def record_states(hass):