Fix Tibber get_prices when called with aware datetime (#123289)

* Tibber: Add extra test to expose aware/naive datetime issue

* Tibber: Fix get_prices action not working with aware datetimes

* Tibber: Simplify comparison

* Tibber: Combine timezone tests into single parametrized one

* Tibber: Split test again to prevent if statement
This commit is contained in:
functionpointer 2024-10-02 08:43:31 +02:00 committed by GitHub
parent cd090ff000
commit 5bd2d27488
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 78 additions and 9 deletions

View File

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import datetime as dt import datetime as dt
from datetime import date, datetime from datetime import datetime
from functools import partial from functools import partial
from typing import Any, Final from typing import Any, Final
@ -61,27 +61,24 @@ async def __get_prices(call: ServiceCall, *, hass: HomeAssistant) -> ServiceResp
] ]
selected_data = [ selected_data = [
price price for price in price_data if start <= price["start_time"] < end
for price in price_data
if price["start_time"].replace(tzinfo=None) >= start
and price["start_time"].replace(tzinfo=None) < end
] ]
tibber_prices[home_nickname] = selected_data tibber_prices[home_nickname] = selected_data
return {"prices": tibber_prices} return {"prices": tibber_prices}
def __get_date(date_input: str | None, mode: str | None) -> date | datetime: def __get_date(date_input: str | None, mode: str | None) -> datetime:
"""Get date.""" """Get date."""
if not date_input: if not date_input:
if mode == "end": if mode == "end":
increment = dt.timedelta(days=1) increment = dt.timedelta(days=1)
else: else:
increment = dt.timedelta() increment = dt.timedelta()
return datetime.fromisoformat(dt_util.now().date().isoformat()) + increment return dt_util.start_of_local_day() + increment
if value := dt_util.parse_datetime(date_input): if value := dt_util.parse_datetime(date_input):
return value return dt_util.as_local(value)
raise ServiceValidationError( raise ServiceValidationError(
"Invalid datetime provided.", "Invalid datetime provided.",

View File

@ -11,8 +11,11 @@ from homeassistant.components.tibber.const import DOMAIN
from homeassistant.components.tibber.services import PRICE_SERVICE_NAME, __get_prices from homeassistant.components.tibber.services import PRICE_SERVICE_NAME, __get_prices
from homeassistant.core import ServiceCall from homeassistant.core import ServiceCall
from homeassistant.exceptions import ServiceValidationError from homeassistant.exceptions import ServiceValidationError
from homeassistant.util import dt as dt_util
STARTTIME = dt.datetime.fromtimestamp(1615766400) STARTTIME = dt.datetime.fromtimestamp(1615766400).replace(
tzinfo=dt_util.get_default_time_zone()
)
def generate_mock_home_data(): def generate_mock_home_data():
@ -246,6 +249,75 @@ async def test_get_prices_start_tomorrow(
} }
@pytest.mark.parametrize(
"start_time",
[
STARTTIME.isoformat(),
STARTTIME.replace(tzinfo=None).isoformat(),
(STARTTIME + dt.timedelta(hours=4))
.replace(tzinfo=dt.timezone(dt.timedelta(hours=4)))
.isoformat(),
],
)
async def test_get_prices_with_timezones(
freezer: FrozenDateTimeFactory,
start_time: str,
) -> None:
"""Test __get_prices with timezone and without."""
freezer.move_to(STARTTIME)
call = ServiceCall(DOMAIN, PRICE_SERVICE_NAME, {"start": start_time})
result = await __get_prices(call, hass=create_mock_hass())
assert result == {
"prices": {
"first_home": [
{
"start_time": STARTTIME,
"price": 0.46914,
"level": "VERY_EXPENSIVE",
},
{
"start_time": STARTTIME + dt.timedelta(hours=1),
"price": 0.46914,
"level": "VERY_EXPENSIVE",
},
],
"second_home": [
{
"start_time": STARTTIME,
"price": 0.46914,
"level": "VERY_EXPENSIVE",
},
{
"start_time": STARTTIME + dt.timedelta(hours=1),
"price": 0.46914,
"level": "VERY_EXPENSIVE",
},
],
}
}
@pytest.mark.parametrize(
"start_time",
[
(STARTTIME + dt.timedelta(hours=4)).isoformat(),
(STARTTIME + dt.timedelta(hours=4)).replace(tzinfo=None).isoformat(),
],
)
async def test_get_prices_with_wrong_timezones(
freezer: FrozenDateTimeFactory,
start_time: str,
) -> None:
"""Test __get_prices with timezone and without, while expecting it to fail."""
freezer.move_to(STARTTIME)
call = ServiceCall(DOMAIN, PRICE_SERVICE_NAME, {"start": start_time})
result = await __get_prices(call, hass=create_mock_hass())
assert result == {"prices": {"first_home": [], "second_home": []}}
async def test_get_prices_invalid_input() -> None: async def test_get_prices_invalid_input() -> None:
"""Test __get_prices with invalid input.""" """Test __get_prices with invalid input."""