commit 3dca4c2f23
Franck Nijhof, 2023-03-08 18:35:50 +01:00, committed by GitHub
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
54 changed files with 1330 additions and 555 deletions


@@ -1100,6 +1100,7 @@ build.json @home-assistant/supervisor
 /homeassistant/components/smhi/ @gjohansson-ST
 /tests/components/smhi/ @gjohansson-ST
 /homeassistant/components/sms/ @ocalvo
+/homeassistant/components/snapcast/ @luar123
 /homeassistant/components/snooz/ @AustinBrunkhorst
 /tests/components/snooz/ @AustinBrunkhorst
 /homeassistant/components/solaredge/ @frenck


@@ -68,7 +68,6 @@ SENSOR_TYPES: list[AirQEntityDescription] = [
     AirQEntityDescription(
         key="co",
         name="CO",
-        device_class=SensorDeviceClass.CO,
         native_unit_of_measurement=CONCENTRATION_MILLIGRAMS_PER_CUBIC_METER,
         state_class=SensorStateClass.MEASUREMENT,
         value=lambda data: data.get("co"),
@@ -289,7 +288,6 @@ SENSOR_TYPES: list[AirQEntityDescription] = [
     AirQEntityDescription(
         key="tvoc",
         name="VOC",
-        device_class=SensorDeviceClass.VOLATILE_ORGANIC_COMPOUNDS,
         native_unit_of_measurement=CONCENTRATION_PARTS_PER_BILLION,
         state_class=SensorStateClass.MEASUREMENT,
         value=lambda data: data.get("tvoc"),
@@ -297,7 +295,6 @@ SENSOR_TYPES: list[AirQEntityDescription] = [
     AirQEntityDescription(
         key="tvoc_ionsc",
         name="VOC (Industrial)",
-        device_class=SensorDeviceClass.VOLATILE_ORGANIC_COMPOUNDS,
         native_unit_of_measurement=CONCENTRATION_PARTS_PER_BILLION,
         state_class=SensorStateClass.MEASUREMENT,
         value=lambda data: data.get("tvoc_ionsc"),


@@ -1,5 +1,6 @@
 """Rest API for Home Assistant."""
 import asyncio
+from functools import lru_cache
 from http import HTTPStatus
 import logging
@@ -350,6 +351,12 @@ class APIComponentsView(HomeAssistantView):
         return self.json(request.app["hass"].config.components)


+@lru_cache
+def _cached_template(template_str: str, hass: ha.HomeAssistant) -> template.Template:
+    """Return a cached template."""
+    return template.Template(template_str, hass)
+
+
 class APITemplateView(HomeAssistantView):
     """View to handle Template requests."""

@@ -362,7 +369,7 @@ class APITemplateView(HomeAssistantView):
             raise Unauthorized()
         try:
             data = await request.json()
-            tpl = template.Template(data["template"], request.app["hass"])
+            tpl = _cached_template(data["template"], request.app["hass"])
             return tpl.async_render(variables=data.get("variables"), parse_result=False)
         except (ValueError, TemplateError) as ex:
             return self.json_message(
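The change above replaces per-request Template construction with a module-level helper memoized by functools.lru_cache, so repeated renders of the same template string reuse one parsed Template object. A minimal, self-contained sketch of that pattern (the Template class below is a stand-in for illustration, not Home Assistant's; lru_cache needs hashable arguments, which is why the real helper is keyed on the template string plus the hass instance):

from functools import lru_cache

CONSTRUCTIONS = 0


class Template:
    """Stand-in for an object that is expensive to construct (parsing, compiling, ...)."""

    def __init__(self, source: str) -> None:
        global CONSTRUCTIONS
        CONSTRUCTIONS += 1  # count how often the expensive work actually runs
        self.source = source


@lru_cache(maxsize=128)
def cached_template(source: str) -> Template:
    """Return a cached Template; identical sources reuse one instance."""
    return Template(source)


first = cached_template("{{ states('sensor.temperature') }}")
second = cached_template("{{ states('sensor.temperature') }}")
assert first is second and CONSTRUCTIONS == 1  # the second call was a cache hit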


@@ -227,20 +227,21 @@ class BaseHaRemoteScanner(BaseHaScanner):
             self.hass, self._async_expire_devices, timedelta(seconds=30)
         )
         cancel_stop = self.hass.bus.async_listen(
-            EVENT_HOMEASSISTANT_STOP, self._save_history
+            EVENT_HOMEASSISTANT_STOP, self._async_save_history
         )
         self._async_setup_scanner_watchdog()

         @hass_callback
         def _cancel() -> None:
-            self._save_history()
+            self._async_save_history()
             self._async_stop_scanner_watchdog()
             cancel_track()
             cancel_stop()

         return _cancel

-    def _save_history(self, event: Event | None = None) -> None:
+    @hass_callback
+    def _async_save_history(self, event: Event | None = None) -> None:
         """Save the history."""
         self._storage.async_set_advertisement_history(
             self.source,
@@ -252,6 +253,7 @@ class BaseHaRemoteScanner(BaseHaScanner):
             ),
         )

+    @hass_callback
     def _async_expire_devices(self, _datetime: datetime.datetime) -> None:
         """Expire old devices."""
         now = MONOTONIC_TIME()


@@ -14,6 +14,6 @@
   "integration_type": "device",
   "iot_class": "local_push",
   "loggers": ["aioesphomeapi", "noiseprotocol"],
-  "requirements": ["aioesphomeapi==13.4.1", "esphome-dashboard-api==1.2.3"],
+  "requirements": ["aioesphomeapi==13.4.2", "esphome-dashboard-api==1.2.3"],
   "zeroconf": ["_esphomelib._tcp.local."]
 }


@@ -7,5 +7,5 @@
   "integration_type": "hub",
   "iot_class": "local_push",
   "loggers": ["pyfibaro"],
-  "requirements": ["pyfibaro==0.6.8"]
+  "requirements": ["pyfibaro==0.6.9"]
 }


@@ -20,5 +20,5 @@
   "documentation": "https://www.home-assistant.io/integrations/frontend",
   "integration_type": "system",
   "quality_scale": "internal",
-  "requirements": ["home-assistant-frontend==20230302.0"]
+  "requirements": ["home-assistant-frontend==20230306.0"]
 }


@@ -41,7 +41,7 @@ async def async_setup_platform(
         [
             GeniusClimateZone(broker, z)
             for z in broker.client.zone_objs
-            if z.data["type"] in GH_ZONES
+            if z.data.get("type") in GH_ZONES
         ]
     )


@@ -42,7 +42,7 @@ async def async_setup_platform(
         [
             GeniusSwitch(broker, z)
             for z in broker.client.zone_objs
-            if z.data["type"] == GH_ON_OFF_ZONE
+            if z.data.get("type") == GH_ON_OFF_ZONE
         ]
     )


@@ -48,7 +48,7 @@ async def async_setup_platform(
         [
             GeniusWaterHeater(broker, z)
             for z in broker.client.zone_objs
-            if z.data["type"] in GH_HEATERS
+            if z.data.get("type") in GH_HEATERS
         ]
     )


@@ -36,6 +36,7 @@ X_AUTH_TOKEN = "X-Supervisor-Token"
 X_INGRESS_PATH = "X-Ingress-Path"
 X_HASS_USER_ID = "X-Hass-User-ID"
 X_HASS_IS_ADMIN = "X-Hass-Is-Admin"
+X_HASS_SOURCE = "X-Hass-Source"

 WS_TYPE = "type"
 WS_ID = "id"


@@ -17,7 +17,7 @@ from homeassistant.const import SERVER_PORT
 from homeassistant.core import HomeAssistant
 from homeassistant.loader import bind_hass

-from .const import ATTR_DISCOVERY, DOMAIN
+from .const import ATTR_DISCOVERY, DOMAIN, X_HASS_SOURCE

 _LOGGER = logging.getLogger(__name__)
@@ -445,6 +445,8 @@ class HassIO:
         payload=None,
         timeout=10,
         return_text=False,
+        *,
+        source="core.handler",
     ):
         """Send API command to Hass.io.
@@ -458,7 +460,8 @@ class HassIO:
                 headers={
                     aiohttp.hdrs.AUTHORIZATION: (
                         f"Bearer {os.environ.get('SUPERVISOR_TOKEN', '')}"
-                    )
+                    ),
+                    X_HASS_SOURCE: source,
                 },
                 timeout=aiohttp.ClientTimeout(total=timeout),
             )


@@ -6,6 +6,7 @@ from http import HTTPStatus
 import logging
 import os
 import re
+from urllib.parse import quote, unquote

 import aiohttp
 from aiohttp import web
@@ -19,13 +20,16 @@ from aiohttp.hdrs import (
     TRANSFER_ENCODING,
 )
 from aiohttp.web_exceptions import HTTPBadGateway
-from multidict import istr

-from homeassistant.components.http import KEY_AUTHENTICATED, HomeAssistantView
+from homeassistant.components.http import (
+    KEY_AUTHENTICATED,
+    KEY_HASS_USER,
+    HomeAssistantView,
+)
 from homeassistant.components.onboarding import async_is_onboarded
 from homeassistant.core import HomeAssistant

-from .const import X_HASS_IS_ADMIN, X_HASS_USER_ID
+from .const import X_HASS_SOURCE

 _LOGGER = logging.getLogger(__name__)
@@ -34,23 +38,53 @@ MAX_UPLOAD_SIZE = 1024 * 1024 * 1024
 # pylint: disable=implicit-str-concat
 NO_TIMEOUT = re.compile(
     r"^(?:"
-    r"|homeassistant/update"
-    r"|hassos/update"
-    r"|hassos/update/cli"
-    r"|supervisor/update"
-    r"|addons/[^/]+/(?:update|install|rebuild)"
     r"|backups/.+/full"
     r"|backups/.+/partial"
    r"|backups/[^/]+/(?:upload|download)"
     r")$"
 )

-NO_AUTH_ONBOARDING = re.compile(r"^(?:" r"|supervisor/logs" r"|backups/[^/]+/.+" r")$")
+# fmt: off
+# Onboarding can upload backups and restore it
+PATHS_NOT_ONBOARDED = re.compile(
+    r"^(?:"
+    r"|backups/[a-f0-9]{8}(/info|/new/upload|/download|/restore/full|/restore/partial)?"
+    r"|backups/new/upload"
+    r")$"
+)

-NO_AUTH = re.compile(r"^(?:" r"|app/.*" r"|[store\/]*addons/[^/]+/(logo|icon)" r")$")
+# Authenticated users manage backups + download logs
+PATHS_ADMIN = re.compile(
+    r"^(?:"
+    r"|backups/[a-f0-9]{8}(/info|/download|/restore/full|/restore/partial)?"
+    r"|backups/new/upload"
+    r"|audio/logs"
+    r"|cli/logs"
+    r"|core/logs"
+    r"|dns/logs"
+    r"|host/logs"
+    r"|multicast/logs"
+    r"|observer/logs"
+    r"|supervisor/logs"
+    r"|addons/[^/]+/logs"
+    r")$"
+)

-NO_STORE = re.compile(r"^(?:" r"|app/entrypoint.js" r")$")
+# Unauthenticated requests come in for Supervisor panel + add-on images
+PATHS_NO_AUTH = re.compile(
+    r"^(?:"
+    r"|app/.*"
+    r"|(store/)?addons/[^/]+/(logo|icon)"
+    r")$"
+)
+
+NO_STORE = re.compile(
+    r"^(?:"
+    r"|app/entrypoint.js"
+    r")$"
+)
 # pylint: enable=implicit-str-concat
+# fmt: on
@@ -65,28 +99,56 @@ class HassIOView(HomeAssistantView):
         self._host = host
         self._websession = websession

-    async def _handle(
-        self, request: web.Request, path: str
-    ) -> web.Response | web.StreamResponse:
-        """Route data to Hass.io."""
-        hass = request.app["hass"]
-        if _need_auth(hass, path) and not request[KEY_AUTHENTICATED]:
-            return web.Response(status=HTTPStatus.UNAUTHORIZED)
-
-        return await self._command_proxy(path, request)
-
-    delete = _handle
-    get = _handle
-    post = _handle
-
-    async def _command_proxy(
-        self, path: str, request: web.Request
-    ) -> web.StreamResponse:
+    async def _handle(self, request: web.Request, path: str) -> web.StreamResponse:
         """Return a client request with proxy origin for Hass.io supervisor.

-        This method is a coroutine.
+        Use cases:
+        - Onboarding allows restoring backups
+        - Load Supervisor panel and add-on logo unauthenticated
+        - User upload/restore backups
         """
-        headers = _init_header(request)
+        # No bullshit
+        if path != unquote(path):
+            return web.Response(status=HTTPStatus.BAD_REQUEST)
+
+        hass: HomeAssistant = request.app["hass"]
+        is_admin = request[KEY_AUTHENTICATED] and request[KEY_HASS_USER].is_admin
+        authorized = is_admin
+
+        if is_admin:
+            allowed_paths = PATHS_ADMIN
+        elif not async_is_onboarded(hass):
+            allowed_paths = PATHS_NOT_ONBOARDED
+            # During onboarding we need the user to manage backups
+            authorized = True
+        else:
+            # Either unauthenticated or not an admin
+            allowed_paths = PATHS_NO_AUTH
+
+        no_auth_path = PATHS_NO_AUTH.match(path)
+        headers = {
+            X_HASS_SOURCE: "core.http",
+        }
+
+        if no_auth_path:
+            if request.method != "GET":
+                return web.Response(status=HTTPStatus.METHOD_NOT_ALLOWED)
+        else:
+            if not allowed_paths.match(path):
+                return web.Response(status=HTTPStatus.UNAUTHORIZED)
+            if authorized:
+                headers[
+                    AUTHORIZATION
+                ] = f"Bearer {os.environ.get('SUPERVISOR_TOKEN', '')}"
+
+        if request.method == "POST":
+            headers[CONTENT_TYPE] = request.content_type
+            # _stored_content_type is only computed once `content_type` is accessed
-        if path == "backups/new/upload":
-            # We need to reuse the full content type that includes the boundary
-            headers[
+            if path == "backups/new/upload":
+                # We need to reuse the full content type that includes the boundary
+                headers[
@@ -96,7 +158,7 @@ class HassIOView(HomeAssistantView):
         try:
             client = await self._websession.request(
                 method=request.method,
-                url=f"http://{self._host}/{path}",
+                url=f"http://{self._host}/{quote(path)}",
                 params=request.query,
                 data=request.content,
                 headers=headers,
@@ -123,20 +185,8 @@ class HassIOView(HomeAssistantView):
             raise HTTPBadGateway()

-
-def _init_header(request: web.Request) -> dict[istr, str]:
-    """Create initial header."""
-    headers = {
-        AUTHORIZATION: f"Bearer {os.environ.get('SUPERVISOR_TOKEN', '')}",
-        CONTENT_TYPE: request.content_type,
-    }
-
-    # Add user data
-    if request.get("hass_user") is not None:
-        headers[istr(X_HASS_USER_ID)] = request["hass_user"].id
-        headers[istr(X_HASS_IS_ADMIN)] = str(int(request["hass_user"].is_admin))
-
-    return headers
+    get = _handle
+    post = _handle


 def _response_header(response: aiohttp.ClientResponse, path: str) -> dict[str, str]:
@@ -164,12 +214,3 @@ def _get_timeout(path: str) -> ClientTimeout:
     if NO_TIMEOUT.match(path):
         return ClientTimeout(connect=10, total=None)
     return ClientTimeout(connect=10, total=300)
-
-
-def _need_auth(hass: HomeAssistant, path: str) -> bool:
-    """Return if a path need authentication."""
-    if not async_is_onboarded(hass) and NO_AUTH_ONBOARDING.match(path):
-        return False
-    if NO_AUTH.match(path):
-        return False
-    return True
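The rewritten _handle method above replaces the old blanket proxy with explicit allow-lists: unauthenticated requests may only GET panel assets and add-on images, onboarding may only touch backup endpoints, and everything else requires an admin user before the Supervisor token is attached. A stripped-down sketch of that allow-list check, with simplified regexes that only illustrate the idea (they are not the full patterns used above):

import re
from http import HTTPStatus

# Illustrative subsets of PATHS_ADMIN / PATHS_NO_AUTH above.
ADMIN_PATHS = re.compile(r"^(?:backups/[a-f0-9]{8}(/download)?|supervisor/logs)$")
PUBLIC_PATHS = re.compile(r"^(?:app/.*|addons/[^/]+/logo)$")


def check_access(path: str, method: str, is_admin: bool) -> HTTPStatus:
    """Return the status an allow-list based proxy would give this request."""
    if PUBLIC_PATHS.match(path):
        # Public assets may only be fetched, never modified.
        return HTTPStatus.OK if method == "GET" else HTTPStatus.METHOD_NOT_ALLOWED
    if is_admin and ADMIN_PATHS.match(path):
        return HTTPStatus.OK
    return HTTPStatus.UNAUTHORIZED


assert check_access("app/entrypoint.js", "GET", is_admin=False) is HTTPStatus.OK
assert check_access("app/entrypoint.js", "POST", is_admin=False) is HTTPStatus.METHOD_NOT_ALLOWED
assert check_access("supervisor/logs", "GET", is_admin=False) is HTTPStatus.UNAUTHORIZED
assert check_access("supervisor/logs", "GET", is_admin=True) is HTTPStatus.OK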


@@ -3,20 +3,22 @@ from __future__ import annotations
 import asyncio
 from collections.abc import Iterable
+from functools import lru_cache
 from ipaddress import ip_address
 import logging
-import os
+from urllib.parse import quote

 import aiohttp
 from aiohttp import ClientTimeout, hdrs, web
 from aiohttp.web_exceptions import HTTPBadGateway, HTTPBadRequest
 from multidict import CIMultiDict
+from yarl import URL

 from homeassistant.components.http import HomeAssistantView
 from homeassistant.core import HomeAssistant, callback
 from homeassistant.helpers.aiohttp_client import async_get_clientsession

-from .const import X_AUTH_TOKEN, X_INGRESS_PATH
+from .const import X_HASS_SOURCE, X_INGRESS_PATH

 _LOGGER = logging.getLogger(__name__)
@@ -42,9 +44,19 @@ class HassIOIngress(HomeAssistantView):
         self._host = host
         self._websession = websession

+    @lru_cache
     def _create_url(self, token: str, path: str) -> str:
         """Create URL to service."""
-        return f"http://{self._host}/ingress/{token}/{path}"
+        base_path = f"/ingress/{token}/"
+        url = f"http://{self._host}{base_path}{quote(path)}"
+
+        try:
+            if not URL(url).path.startswith(base_path):
+                raise HTTPBadRequest()
+        except ValueError as err:
+            raise HTTPBadRequest() from err
+
+        return url

     async def _handle(
         self, request: web.Request, token: str, path: str
@@ -185,10 +197,8 @@ def _init_header(request: web.Request, token: str) -> CIMultiDict | dict[str, str]:
             continue
         headers[name] = value

-    # Inject token / cleanup later on Supervisor
-    headers[X_AUTH_TOKEN] = os.environ.get("SUPERVISOR_TOKEN", "")
-
     # Ingress information
+    headers[X_HASS_SOURCE] = "core.ingress"
     headers[X_INGRESS_PATH] = f"/api/hassio_ingress/{token}"

     # Set X-Forwarded-For
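_create_url now percent-encodes the client-supplied path and then checks that the normalized URL still lives under the /ingress/<token>/ prefix, which blocks ../ style escapes out of the ingress namespace. A standalone sketch of that guard (the function name is illustrative; the real code raises HTTPBadRequest instead of ValueError):

from urllib.parse import quote

from yarl import URL


def build_ingress_url(host: str, token: str, path: str) -> str:
    """Build the proxied URL, refusing paths that escape the ingress prefix."""
    base_path = f"/ingress/{token}/"
    url = f"http://{host}{base_path}{quote(path)}"
    # yarl normalizes dot segments, so a path that resolves outside the prefix is caught here.
    if not URL(url).path.startswith(base_path):
        raise ValueError("path escapes the ingress prefix")
    return url


print(build_ingress_url("supervisor", "abc123", "app/index.html"))
try:
    build_ingress_url("supervisor", "abc123", "../../supervisor/info")
except ValueError as err:
    print(f"rejected: {err}")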


@@ -116,6 +116,7 @@ async def websocket_supervisor_api(
             method=msg[ATTR_METHOD],
             timeout=msg.get(ATTR_TIMEOUT, 10),
             payload=msg.get(ATTR_DATA, {}),
+            source="core.websocket_api",
         )

     if result.get(ATTR_RESULT) == "error":


@@ -153,6 +153,7 @@ async def async_setup_entry(  # noqa: C901
                     system.serial,
                     svc_exception,
                 )
+                await system.aqualink.close()
             else:
                 cur = system.online
                 if cur and not prev:


@@ -3,6 +3,7 @@ from __future__ import annotations

 from collections.abc import Awaitable

+import httpx
 from iaqualink.exception import AqualinkServiceException

 from homeassistant.exceptions import HomeAssistantError
@@ -12,5 +13,5 @@ async def await_or_reraise(awaitable: Awaitable) -> None:
     """Execute API call while catching service exceptions."""
     try:
         await awaitable
-    except AqualinkServiceException as svc_exception:
+    except (AqualinkServiceException, httpx.HTTPError) as svc_exception:
         raise HomeAssistantError(f"Aqualink error: {svc_exception}") from svc_exception


@@ -17,8 +17,8 @@
   "iot_class": "local_push",
   "loggers": ["pyinsteon", "pypubsub"],
   "requirements": [
-    "pyinsteon==1.3.3",
-    "insteon-frontend-home-assistant==0.3.2"
+    "pyinsteon==1.3.4",
+    "insteon-frontend-home-assistant==0.3.3"
   ],
   "usb": [
     {


@@ -1,11 +1,13 @@
 """Utilities used by insteon component."""
 import asyncio
+from collections.abc import Callable
 import logging

 from pyinsteon import devices
 from pyinsteon.address import Address
 from pyinsteon.constants import ALDBStatus, DeviceAction
-from pyinsteon.events import OFF_EVENT, OFF_FAST_EVENT, ON_EVENT, ON_FAST_EVENT
+from pyinsteon.device_types.device_base import Device
+from pyinsteon.events import OFF_EVENT, OFF_FAST_EVENT, ON_EVENT, ON_FAST_EVENT, Event
 from pyinsteon.managers.link_manager import (
     async_enter_linking_mode,
     async_enter_unlinking_mode,
@@ -27,7 +29,7 @@ from homeassistant.const import (
     CONF_PLATFORM,
     ENTITY_MATCH_ALL,
 )
-from homeassistant.core import ServiceCall, callback
+from homeassistant.core import HomeAssistant, ServiceCall, callback
 from homeassistant.helpers import device_registry as dr
 from homeassistant.helpers.dispatcher import (
     async_dispatcher_connect,
@@ -89,49 +91,52 @@ from .schemas import (
 _LOGGER = logging.getLogger(__name__)


-def add_on_off_event_device(hass, device):
+def _register_event(event: Event, listener: Callable) -> None:
+    """Register the events raised by a device."""
+    _LOGGER.debug(
+        "Registering on/off event for %s %d %s",
+        str(event.address),
+        event.group,
+        event.name,
+    )
+    event.subscribe(listener, force_strong_ref=True)
+
+
+def add_on_off_event_device(hass: HomeAssistant, device: Device) -> None:
     """Register an Insteon device as an on/off event device."""

     @callback
-    def async_fire_group_on_off_event(name, address, group, button):
+    def async_fire_group_on_off_event(
+        name: str, address: Address, group: int, button: str
+    ):
         # Firing an event when a button is pressed.
         if button and button[-2] == "_":
             button_id = button[-1].lower()
         else:
             button_id = None

-        schema = {CONF_ADDRESS: address}
+        schema = {CONF_ADDRESS: address, "group": group}
         if button_id:
             schema[EVENT_CONF_BUTTON] = button_id
         if name == ON_EVENT:
             event = EVENT_GROUP_ON
-        if name == OFF_EVENT:
+        elif name == OFF_EVENT:
             event = EVENT_GROUP_OFF
-        if name == ON_FAST_EVENT:
+        elif name == ON_FAST_EVENT:
             event = EVENT_GROUP_ON_FAST
-        if name == OFF_FAST_EVENT:
+        elif name == OFF_FAST_EVENT:
             event = EVENT_GROUP_OFF_FAST
+        else:
+            event = f"insteon.{name}"
+
         _LOGGER.debug("Firing event %s with %s", event, schema)
         hass.bus.async_fire(event, schema)

-    for group in device.events:
-        if isinstance(group, int):
-            for event in device.events[group]:
-                if event in [
-                    OFF_EVENT,
-                    ON_EVENT,
-                    OFF_FAST_EVENT,
-                    ON_FAST_EVENT,
-                ]:
-                    _LOGGER.debug(
-                        "Registering on/off event for %s %d %s",
-                        str(device.address),
-                        group,
-                        event,
-                    )
-                    device.events[group][event].subscribe(
-                        async_fire_group_on_off_event, force_strong_ref=True
-                    )
+    for name_or_group, event in device.events.items():
+        if isinstance(name_or_group, int):
+            for _, event in device.events[name_or_group].items():
+                _register_event(event, async_fire_group_on_off_event)
+        else:
+            _register_event(event, async_fire_group_on_off_event)


 def register_new_device_callback(hass):
def register_new_device_callback(hass): def register_new_device_callback(hass):


@@ -84,7 +84,7 @@ def ensure_zone(value):
     if value is None:
         raise vol.Invalid("zone value is None")

-    if str(value) not in ZONES is None:
+    if str(value) not in ZONES:
         raise vol.Invalid("zone not valid")

     return str(value)


@@ -140,7 +140,7 @@ ROBOT_SENSOR_MAP: dict[type[Robot], list[RobotSensorEntityDescription]] = {
             name="Pet weight",
             native_unit_of_measurement=UnitOfMass.POUNDS,
             device_class=SensorDeviceClass.WEIGHT,
-            state_class=SensorStateClass.TOTAL,
+            state_class=SensorStateClass.MEASUREMENT,
         ),
     ],
     FeederRobot: [


@@ -4,7 +4,7 @@ from __future__ import annotations
 import asyncio
 from collections.abc import Callable, Coroutine
 from contextlib import suppress
-from functools import wraps
+from functools import lru_cache, wraps
 from http import HTTPStatus
 import logging
 import secrets
@@ -365,6 +365,12 @@ async def webhook_stream_camera(
     return webhook_response(resp, registration=config_entry.data)


+@lru_cache
+def _cached_template(template_str: str, hass: HomeAssistant) -> template.Template:
+    """Return a cached template."""
+    return template.Template(template_str, hass)
+
+
 @WEBHOOK_COMMANDS.register("render_template")
 @validate_schema(
     {
@@ -381,7 +387,7 @@ async def webhook_render_template(
     resp = {}
     for key, item in data.items():
         try:
-            tpl = template.Template(item[ATTR_TEMPLATE], hass)
+            tpl = _cached_template(item[ATTR_TEMPLATE], hass)
             resp[key] = tpl.async_render(item.get(ATTR_TEMPLATE_VARIABLES))
         except TemplateError as ex:
             resp[key] = {"error": str(ex)}


@@ -17,7 +17,6 @@ from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.expression import literal
 from sqlalchemy.sql.lambdas import StatementLambdaElement
-from sqlalchemy.sql.selectable import Subquery

 from homeassistant.const import COMPRESSED_STATE_LAST_UPDATED, COMPRESSED_STATE_STATE
 from homeassistant.core import HomeAssistant, State, split_entity_id
@@ -592,17 +591,25 @@ def get_last_state_changes(
     )


-def _generate_most_recent_states_for_entities_by_date(
+def _get_states_for_entities_stmt(
     schema_version: int,
     run_start: datetime,
     utc_point_in_time: datetime,
     entity_ids: list[str],
-) -> Subquery:
-    """Generate the sub query for the most recent states for specific entities by date."""
+    no_attributes: bool,
+) -> StatementLambdaElement:
+    """Baked query to get states for specific entities."""
+    stmt, join_attributes = lambda_stmt_and_join_attributes(
+        schema_version, no_attributes, include_last_changed=True
+    )
+    # We got an include-list of entities, accelerate the query by filtering already
+    # in the inner query.
     if schema_version >= 31:
         run_start_ts = process_timestamp(run_start).timestamp()
         utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time)
-        return (
+        stmt += lambda q: q.join(
+            (
+                most_recent_states_for_entities_by_date := (
                     select(
                         States.entity_id.label("max_entity_id"),
                         # https://github.com/sqlalchemy/sqlalchemy/issues/9189
@@ -617,8 +624,18 @@ def _generate_most_recent_states_for_entities_by_date(
                     .group_by(States.entity_id)
                     .subquery()
                 )
-    return (
-        select(
+            ),
+            and_(
+                States.entity_id
+                == most_recent_states_for_entities_by_date.c.max_entity_id,
+                States.last_updated_ts
+                == most_recent_states_for_entities_by_date.c.max_last_updated,
+            ),
+        )
+    else:
+        stmt += lambda q: q.join(
+            (
+                most_recent_states_for_entities_by_date := select(
                     States.entity_id.label("max_entity_id"),
                     # https://github.com/sqlalchemy/sqlalchemy/issues/9189
                     # pylint: disable-next=not-callable
@@ -631,40 +648,7 @@ def _generate_most_recent_states_for_entities_by_date(
                 .filter(States.entity_id.in_(entity_ids))
                 .group_by(States.entity_id)
                 .subquery()
-    )
-
-
-def _get_states_for_entities_stmt(
-    schema_version: int,
-    run_start: datetime,
-    utc_point_in_time: datetime,
-    entity_ids: list[str],
-    no_attributes: bool,
-) -> StatementLambdaElement:
-    """Baked query to get states for specific entities."""
-    stmt, join_attributes = lambda_stmt_and_join_attributes(
-        schema_version, no_attributes, include_last_changed=True
-    )
-    most_recent_states_for_entities_by_date = (
-        _generate_most_recent_states_for_entities_by_date(
-            schema_version, run_start, utc_point_in_time, entity_ids
-        )
-    )
-    # We got an include-list of entities, accelerate the query by filtering already
-    # in the inner query.
-    if schema_version >= 31:
-        stmt += lambda q: q.join(
-            most_recent_states_for_entities_by_date,
-            and_(
-                States.entity_id
-                == most_recent_states_for_entities_by_date.c.max_entity_id,
-                States.last_updated_ts
-                == most_recent_states_for_entities_by_date.c.max_last_updated,
             ),
-        )
-    else:
-        stmt += lambda q: q.join(
-            most_recent_states_for_entities_by_date,
             and_(
                 States.entity_id
                 == most_recent_states_for_entities_by_date.c.max_entity_id,
@@ -679,45 +663,6 @@ def _get_states_for_entities_stmt(
     return stmt


-def _generate_most_recent_states_by_date(
-    schema_version: int,
-    run_start: datetime,
-    utc_point_in_time: datetime,
-) -> Subquery:
-    """Generate the sub query for the most recent states by date."""
-    if schema_version >= 31:
-        run_start_ts = process_timestamp(run_start).timestamp()
-        utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time)
-        return (
-            select(
-                States.entity_id.label("max_entity_id"),
-                # https://github.com/sqlalchemy/sqlalchemy/issues/9189
-                # pylint: disable-next=not-callable
-                func.max(States.last_updated_ts).label("max_last_updated"),
-            )
-            .filter(
-                (States.last_updated_ts >= run_start_ts)
-                & (States.last_updated_ts < utc_point_in_time_ts)
-            )
-            .group_by(States.entity_id)
-            .subquery()
-        )
-    return (
-        select(
-            States.entity_id.label("max_entity_id"),
-            # https://github.com/sqlalchemy/sqlalchemy/issues/9189
-            # pylint: disable-next=not-callable
-            func.max(States.last_updated).label("max_last_updated"),
-        )
-        .filter(
-            (States.last_updated >= run_start)
-            & (States.last_updated < utc_point_in_time)
-        )
-        .group_by(States.entity_id)
-        .subquery()
-    )
-
-
 def _get_states_for_all_stmt(
     schema_version: int,
     run_start: datetime,
@@ -733,12 +678,26 @@ def _get_states_for_all_stmt(
     # query, then filter out unwanted domains as well as applying the custom filter.
     # This filtering can't be done in the inner query because the domain column is
     # not indexed and we can't control what's in the custom filter.
-    most_recent_states_by_date = _generate_most_recent_states_by_date(
-        schema_version, run_start, utc_point_in_time
-    )
     if schema_version >= 31:
+        run_start_ts = process_timestamp(run_start).timestamp()
+        utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time)
         stmt += lambda q: q.join(
-            most_recent_states_by_date,
+            (
+                most_recent_states_by_date := (
+                    select(
+                        States.entity_id.label("max_entity_id"),
+                        # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                        # pylint: disable-next=not-callable
+                        func.max(States.last_updated_ts).label("max_last_updated"),
+                    )
+                    .filter(
+                        (States.last_updated_ts >= run_start_ts)
+                        & (States.last_updated_ts < utc_point_in_time_ts)
+                    )
+                    .group_by(States.entity_id)
+                    .subquery()
+                )
+            ),
             and_(
                 States.entity_id == most_recent_states_by_date.c.max_entity_id,
                 States.last_updated_ts == most_recent_states_by_date.c.max_last_updated,
@@ -746,7 +705,22 @@ def _get_states_for_all_stmt(
         )
     else:
         stmt += lambda q: q.join(
-            most_recent_states_by_date,
+            (
+                most_recent_states_by_date := (
+                    select(
+                        States.entity_id.label("max_entity_id"),
+                        # https://github.com/sqlalchemy/sqlalchemy/issues/9189
+                        # pylint: disable-next=not-callable
+                        func.max(States.last_updated).label("max_last_updated"),
+                    )
+                    .filter(
+                        (States.last_updated >= run_start)
+                        & (States.last_updated < utc_point_in_time)
+                    )
+                    .group_by(States.entity_id)
+                    .subquery()
+                )
+            ),
             and_(
                 States.entity_id == most_recent_states_by_date.c.max_entity_id,
                 States.last_updated == most_recent_states_by_date.c.max_last_updated,
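The history rewrite above folds the subquery construction into the lambda passed to the statement instead of building it in a separate helper and closing over it. For readers unfamiliar with SQLAlchemy's cached lambda statements, here is a minimal, self-contained sketch of the lambda_stmt composition pattern the recorder relies on (the table and data are illustrative, not the recorder schema):

from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, lambda_stmt, select

metadata = MetaData()
states = Table(  # toy stand-in for the recorder's States table
    "states",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("entity_id", String(255)),
    Column("state", String(255)),
)


def states_for_entity_stmt(entity_id: str):
    """Compose a cached lambda statement; only the bound parameter varies between calls."""
    stmt = lambda_stmt(lambda: select(states))
    stmt += lambda s: s.where(states.c.entity_id == entity_id)
    return stmt


engine = create_engine("sqlite://")
metadata.create_all(engine)
with engine.connect() as conn:
    conn.execute(states.insert(), [{"entity_id": "sensor.a", "state": "1"}])
    print(conn.execute(states_for_entity_stmt("sensor.a")).all())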


@@ -6,5 +6,5 @@
   "integration_type": "system",
   "iot_class": "local_push",
   "quality_scale": "internal",
-  "requirements": ["sqlalchemy==2.0.4", "fnvhash==0.1.0"]
+  "requirements": ["sqlalchemy==2.0.5.post1", "fnvhash==0.1.0"]
 }


@@ -50,7 +50,7 @@ from .tasks import (
     PostSchemaMigrationTask,
     StatisticsTimestampMigrationCleanupTask,
 )
-from .util import session_scope
+from .util import database_job_retry_wrapper, session_scope

 if TYPE_CHECKING:
     from . import Recorder
@@ -158,7 +158,9 @@ def migrate_schema(
         hass.add_job(instance.async_set_db_ready)
         new_version = version + 1
         _LOGGER.info("Upgrading recorder db schema to version %s", new_version)
-        _apply_update(hass, engine, session_maker, new_version, current_version)
+        _apply_update(
+            instance, hass, engine, session_maker, new_version, current_version
+        )
         with session_scope(session=session_maker()) as session:
             session.add(SchemaChanges(schema_version=new_version))
@@ -508,7 +510,9 @@ def _drop_foreign_key_constraints(
     )


+@database_job_retry_wrapper("Apply migration update", 10)
 def _apply_update(  # noqa: C901
+    instance: Recorder,
     hass: HomeAssistant,
     engine: Engine,
     session_maker: Callable[[], Session],
@@ -922,7 +926,7 @@ def _apply_update(  # noqa: C901
             # There may be duplicated statistics entries, delete duplicates
             # and try again
             with session_scope(session=session_maker()) as session:
-                delete_statistics_duplicates(hass, session)
+                delete_statistics_duplicates(instance, hass, session)
             _migrate_statistics_columns_to_timestamp(session_maker, engine)
             # Log at error level to ensure the user sees this message in the log
             # since we logged the error above.
@@ -965,7 +969,7 @@ def post_schema_migration(
         # since they are no longer used and take up a significant amount of space.
         assert instance.event_session is not None
         assert instance.engine is not None
-        _wipe_old_string_time_columns(instance.engine, instance.event_session)
+        _wipe_old_string_time_columns(instance, instance.engine, instance.event_session)
     if old_version < 35 <= new_version:
         # In version 34 we migrated all the created, start, and last_reset
         # columns to be timestamps. In version 34 we need to wipe the old columns
@@ -978,7 +982,10 @@ def _wipe_old_string_statistics_columns(instance: Recorder) -> None:
     instance.queue_task(StatisticsTimestampMigrationCleanupTask())


-def _wipe_old_string_time_columns(engine: Engine, session: Session) -> None:
+@database_job_retry_wrapper("Wipe old string time columns", 3)
+def _wipe_old_string_time_columns(
+    instance: Recorder, engine: Engine, session: Session
+) -> None:
     """Wipe old string time columns to save space."""
     # Wipe Events.time_fired since its been replaced by Events.time_fired_ts
     # Wipe States.last_updated since its been replaced by States.last_updated_ts
@@ -1162,7 +1169,7 @@ def _migrate_statistics_columns_to_timestamp(
                     "last_reset_ts="
                     "UNIX_TIMESTAMP(last_reset) "
                     "where start_ts is NULL "
-                    "LIMIT 250000;"
+                    "LIMIT 100000;"
                 )
             )
     elif engine.dialect.name == SupportedDialect.POSTGRESQL:
@@ -1180,7 +1187,7 @@ def _migrate_statistics_columns_to_timestamp(
                         "created_ts=EXTRACT(EPOCH FROM created), "
                         "last_reset_ts=EXTRACT(EPOCH FROM last_reset) "
                         "where id IN ( "
-                        f"SELECT id FROM {table} where start_ts is NULL LIMIT 250000 "
+                        f"SELECT id FROM {table} where start_ts is NULL LIMIT 100000 "
                         " );"
                     )
                 )


@@ -16,14 +16,13 @@ import re
 from statistics import mean
 from typing import TYPE_CHECKING, Any, Literal, cast

-from sqlalchemy import and_, bindparam, func, lambda_stmt, select, text
+from sqlalchemy import Select, and_, bindparam, func, lambda_stmt, select, text
 from sqlalchemy.engine import Engine
 from sqlalchemy.engine.row import Row
 from sqlalchemy.exc import OperationalError, SQLAlchemyError, StatementError
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.expression import literal_column, true
 from sqlalchemy.sql.lambdas import StatementLambdaElement
-from sqlalchemy.sql.selectable import Subquery
 import voluptuous as vol

 from homeassistant.const import ATTR_UNIT_OF_MEASUREMENT
@@ -75,6 +74,7 @@ from .models import (
     datetime_to_timestamp_or_none,
 )
 from .util import (
+    database_job_retry_wrapper,
     execute,
     execute_stmt_lambda_element,
     get_instance,
@@ -515,7 +515,10 @@ def _delete_duplicates_from_table(
     return (total_deleted_rows, all_non_identical_duplicates)


-def delete_statistics_duplicates(hass: HomeAssistant, session: Session) -> None:
+@database_job_retry_wrapper("delete statistics duplicates", 3)
+def delete_statistics_duplicates(
+    instance: Recorder, hass: HomeAssistant, session: Session
+) -> None:
     """Identify and delete duplicated statistics.

     A backup will be made of duplicated statistics before it is deleted.
@@ -646,27 +649,19 @@ def _compile_hourly_statistics_summary_mean_stmt(
     )


-def _compile_hourly_statistics_last_sum_stmt_subquery(
+def _compile_hourly_statistics_last_sum_stmt(
     start_time_ts: float, end_time_ts: float
-) -> Subquery:
+) -> StatementLambdaElement:
     """Generate the summary mean statement for hourly statistics."""
-    return (
-        select(*QUERY_STATISTICS_SUMMARY_SUM)
-        .filter(StatisticsShortTerm.start_ts >= start_time_ts)
-        .filter(StatisticsShortTerm.start_ts < end_time_ts)
-        .subquery()
-    )
-
-
-def _compile_hourly_statistics_last_sum_stmt(
-    start_time_ts: float, end_time_ts: float
-) -> StatementLambdaElement:
-    """Generate the summary mean statement for hourly statistics."""
-    subquery = _compile_hourly_statistics_last_sum_stmt_subquery(
-        start_time_ts, end_time_ts
-    )
     return lambda_stmt(
-        lambda: select(subquery)
+        lambda: select(
+            subquery := (
+                select(*QUERY_STATISTICS_SUMMARY_SUM)
+                .filter(StatisticsShortTerm.start_ts >= start_time_ts)
+                .filter(StatisticsShortTerm.start_ts < end_time_ts)
+                .subquery()
+            )
+        )
         .filter(subquery.c.rownum == 1)
         .order_by(subquery.c.metadata_id)
     )
@@ -1263,7 +1258,8 @@ def _reduce_statistics_per_month(
     )


-def _statistics_during_period_stmt(
+def _generate_statistics_during_period_stmt(
+    columns: Select,
     start_time: datetime,
     end_time: datetime | None,
     metadata_ids: list[int] | None,
@@ -1275,21 +1271,6 @@ def _statistics_during_period_stmt(
     This prepares a lambda_stmt query, so we don't insert the parameters yet.
     """
     start_time_ts = start_time.timestamp()
-    columns = select(table.metadata_id, table.start_ts)
-    if "last_reset" in types:
-        columns = columns.add_columns(table.last_reset_ts)
-    if "max" in types:
-        columns = columns.add_columns(table.max)
-    if "mean" in types:
-        columns = columns.add_columns(table.mean)
-    if "min" in types:
-        columns = columns.add_columns(table.min)
-    if "state" in types:
-        columns = columns.add_columns(table.state)
-    if "sum" in types:
-        columns = columns.add_columns(table.sum)
     stmt = lambda_stmt(lambda: columns.filter(table.start_ts >= start_time_ts))
     if end_time is not None:
         end_time_ts = end_time.timestamp()
@@ -1303,6 +1284,23 @@ def _statistics_during_period_stmt(
     return stmt


+def _generate_max_mean_min_statistic_in_sub_period_stmt(
+    columns: Select,
+    start_time: datetime | None,
+    end_time: datetime | None,
+    table: type[StatisticsBase],
+    metadata_id: int,
+) -> StatementLambdaElement:
+    stmt = lambda_stmt(lambda: columns.filter(table.metadata_id == metadata_id))
+    if start_time is not None:
+        start_time_ts = start_time.timestamp()
+        stmt += lambda q: q.filter(table.start_ts >= start_time_ts)
+    if end_time is not None:
+        end_time_ts = end_time.timestamp()
+        stmt += lambda q: q.filter(table.start_ts < end_time_ts)
+    return stmt
+
+
 def _get_max_mean_min_statistic_in_sub_period(
     session: Session,
     result: dict[str, float],
@@ -1328,13 +1326,9 @@ def _get_max_mean_min_statistic_in_sub_period(
         # https://github.com/sqlalchemy/sqlalchemy/issues/9189
         # pylint: disable-next=not-callable
         columns = columns.add_columns(func.min(table.min))
-    stmt = lambda_stmt(lambda: columns.filter(table.metadata_id == metadata_id))
-    if start_time is not None:
-        start_time_ts = start_time.timestamp()
-        stmt += lambda q: q.filter(table.start_ts >= start_time_ts)
-    if end_time is not None:
-        end_time_ts = end_time.timestamp()
-        stmt += lambda q: q.filter(table.start_ts < end_time_ts)
+    stmt = _generate_max_mean_min_statistic_in_sub_period_stmt(
+        columns, start_time, end_time, table, metadata_id
+    )
     stats = cast(Sequence[Row[Any]], execute_stmt_lambda_element(session, stmt))
     if not stats:
         return
@@ -1749,8 +1743,21 @@ def _statistics_during_period_with_session(
     table: type[Statistics | StatisticsShortTerm] = (
         Statistics if period != "5minute" else StatisticsShortTerm
     )
-    stmt = _statistics_during_period_stmt(
-        start_time, end_time, metadata_ids, table, types
+    columns = select(table.metadata_id, table.start_ts)  # type: ignore[call-overload]
+    if "last_reset" in types:
+        columns = columns.add_columns(table.last_reset_ts)
+    if "max" in types:
+        columns = columns.add_columns(table.max)
+    if "mean" in types:
+        columns = columns.add_columns(table.mean)
+    if "min" in types:
+        columns = columns.add_columns(table.min)
+    if "state" in types:
+        columns = columns.add_columns(table.state)
+    if "sum" in types:
+        columns = columns.add_columns(table.sum)
+    stmt = _generate_statistics_during_period_stmt(
+        columns, start_time, end_time, metadata_ids, table, types
     )
     stats = cast(Sequence[Row], execute_stmt_lambda_element(session, stmt))
@@ -1915,9 +1922,14 @@ def get_last_short_term_statistics(
     )


-def _generate_most_recent_statistic_row(metadata_ids: list[int]) -> Subquery:
-    """Generate the subquery to find the most recent statistic row."""
-    return (
+def _latest_short_term_statistics_stmt(
+    metadata_ids: list[int],
+) -> StatementLambdaElement:
+    """Create the statement for finding the latest short term stat rows."""
+    stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
+    stmt += lambda s: s.join(
+        (
+            most_recent_statistic_row := (
                 select(
                     StatisticsShortTerm.metadata_id,
                     # https://github.com/sqlalchemy/sqlalchemy/issues/9189
@@ -1927,16 +1939,7 @@ def _generate_most_recent_statistic_row(metadata_ids: list[int]) -> Subquery:
                 .where(StatisticsShortTerm.metadata_id.in_(metadata_ids))
                 .group_by(StatisticsShortTerm.metadata_id)
             ).subquery()
-    )
-
-
-def _latest_short_term_statistics_stmt(
-    metadata_ids: list[int],
-) -> StatementLambdaElement:
-    """Create the statement for finding the latest short term stat rows."""
-    stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM))
-    most_recent_statistic_row = _generate_most_recent_statistic_row(metadata_ids)
-    stmt += lambda s: s.join(
-        most_recent_statistic_row,
+        ),
         (
             StatisticsShortTerm.metadata_id  # pylint: disable=comparison-with-callable
             == most_recent_statistic_row.c.metadata_id
@@ -1984,11 +1987,17 @@ def get_latest_short_term_statistics(
     )


-def _get_most_recent_statistics_subquery(
-    metadata_ids: set[int], table: type[StatisticsBase], start_time_ts: float
-) -> Subquery:
-    """Generate the subquery to find the most recent statistic row."""
-    return (
+def _generate_statistics_at_time_stmt(
+    columns: Select,
+    table: type[StatisticsBase],
+    metadata_ids: set[int],
+    start_time_ts: float,
+) -> StatementLambdaElement:
+    """Create the statement for finding the statistics for a given time."""
+    return lambda_stmt(
+        lambda: columns.join(
+            (
+                most_recent_statistic_ids := (
                     select(
                         # https://github.com/sqlalchemy/sqlalchemy/issues/9189
                         # pylint: disable-next=not-callable
@@ -2000,6 +2009,13 @@ def _get_most_recent_statistics_subquery(
                     .group_by(table.metadata_id)
                     .subquery()
                 )
+            ),
+            and_(
+                table.start_ts == most_recent_statistic_ids.c.max_start_ts,
+                table.metadata_id == most_recent_statistic_ids.c.max_metadata_id,
+            ),
+        )
+    )


 def _statistics_at_time(
@@ -2023,19 +2039,10 @@ def _statistics_at_time(
         columns = columns.add_columns(table.state)
     if "sum" in types:
         columns = columns.add_columns(table.sum)
     start_time_ts = start_time.timestamp()
-    most_recent_statistic_ids = _get_most_recent_statistics_subquery(
-        metadata_ids, table, start_time_ts
+    stmt = _generate_statistics_at_time_stmt(
+        columns, table, metadata_ids, start_time_ts
     )
-    stmt = lambda_stmt(lambda: columns).join(
-        most_recent_statistic_ids,
-        and_(
-            table.start_ts == most_recent_statistic_ids.c.max_start_ts,
-            table.metadata_id == most_recent_statistic_ids.c.max_metadata_id,
-        ),
-    )
     return cast(Sequence[Row], execute_stmt_lambda_element(session, stmt))
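Several of the rewritten statements above use the walrus operator to name an inline subquery (most_recent_statistic_ids := ...) so it can be referenced again in the join condition inside the same lambda. A toy illustration of that bind-and-reuse-within-one-expression pattern, without any SQLAlchemy:

values = ["  Foo  ", "Bar", "  baz "]
# Each element is stripped once, then the bound name is reused in the same expression.
summary = [f"{(s := v.strip())!r} has {len(s)} chars" for v in values]
print(summary)  # ["'Foo' has 3 chars", "'Bar' has 3 chars", "'baz' has 3 chars"]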


@@ -568,6 +568,17 @@ def end_incomplete_runs(session: Session, start_time: datetime) -> None:
         session.add(run)


+def _is_retryable_error(instance: Recorder, err: OperationalError) -> bool:
+    """Return True if the error is retryable."""
+    assert instance.engine is not None
+    return bool(
+        instance.engine.dialect.name == SupportedDialect.MYSQL
+        and isinstance(err.orig, BaseException)
+        and err.orig.args
+        and err.orig.args[0] in RETRYABLE_MYSQL_ERRORS
+    )
+
+
 _FuncType = Callable[Concatenate[_RecorderT, _P], bool]
@@ -585,12 +596,8 @@ def retryable_database_job(
             try:
                 return job(instance, *args, **kwargs)
             except OperationalError as err:
-                assert instance.engine is not None
-                if (
-                    instance.engine.dialect.name == SupportedDialect.MYSQL
-                    and err.orig
-                    and err.orig.args[0] in RETRYABLE_MYSQL_ERRORS
-                ):
+                if _is_retryable_error(instance, err):
+                    assert isinstance(err.orig, BaseException)
                     _LOGGER.info(
                         "%s; %s not completed, retrying", err.orig.args[1], description
                     )
@@ -608,6 +615,46 @@ def retryable_database_job(
     return decorator


+_WrappedFuncType = Callable[Concatenate[_RecorderT, _P], None]
+
+
+def database_job_retry_wrapper(
+    description: str, attempts: int = 5
+) -> Callable[[_WrappedFuncType[_RecorderT, _P]], _WrappedFuncType[_RecorderT, _P]]:
+    """Try to execute a database job multiple times.
+
+    This wrapper handles InnoDB deadlocks and lock timeouts.
+
+    This is different from retryable_database_job in that it will retry the job
+    attempts number of times instead of returning False if the job fails.
+    """
+
+    def decorator(
+        job: _WrappedFuncType[_RecorderT, _P]
+    ) -> _WrappedFuncType[_RecorderT, _P]:
+        @functools.wraps(job)
+        def wrapper(instance: _RecorderT, *args: _P.args, **kwargs: _P.kwargs) -> None:
+            for attempt in range(attempts):
+                try:
+                    job(instance, *args, **kwargs)
+                    return
+                except OperationalError as err:
+                    if attempt == attempts - 1 or not _is_retryable_error(
+                        instance, err
+                    ):
+                        raise
+                    assert isinstance(err.orig, BaseException)
+                    _LOGGER.info(
+                        "%s; %s failed, retrying", err.orig.args[1], description
+                    )
+                    time.sleep(instance.db_retry_wait)
+                    # Failed with retryable error
+
+        return wrapper
+
+    return decorator
+
+
 def periodic_db_cleanups(instance: Recorder) -> None:
     """Run any database cleanups that need to happen periodically.

@@ -64,7 +64,7 @@ NUMBER_ENTITIES = (
         get_max_value=lambda api, ch: api.zoom_range(ch)["focus"]["pos"]["max"],
         supported=lambda api, ch: api.zoom_supported(ch),
         value=lambda api, ch: api.get_focus(ch),
-        method=lambda api, ch, value: api.set_zoom(ch, int(value)),
+        method=lambda api, ch, value: api.set_focus(ch, int(value)),
     ),
 )


@@ -1,13 +1,11 @@
 """SFR Box."""
 from __future__ import annotations

-import asyncio
-
 from sfrbox_api.bridge import SFRBox
 from sfrbox_api.exceptions import SFRBoxAuthenticationError, SFRBoxError

 from homeassistant.config_entries import ConfigEntry
-from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
+from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME, Platform
 from homeassistant.core import HomeAssistant
 from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
 from homeassistant.helpers import device_registry as dr
@@ -40,15 +38,17 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
             hass, box, "system", lambda b: b.system_get_info()
         ),
     )
-    tasks = [
-        data.dsl.async_config_entry_first_refresh(),
-        data.system.async_config_entry_first_refresh(),
-    ]
-    await asyncio.gather(*tasks)
+    await data.system.async_config_entry_first_refresh()
+    system_info = data.system.data
+
+    if system_info.net_infra == "adsl":
+        await data.dsl.async_config_entry_first_refresh()
+    else:
+        platforms = list(platforms)
+        platforms.remove(Platform.BINARY_SENSOR)

     hass.data.setdefault(DOMAIN, {})[entry.entry_id] = data

-    system_info = data.system.data
     device_registry = dr.async_get(hass)
     device_registry.async_get_or_create(
         config_entry_id=entry.entry_id,


@@ -1,7 +1,6 @@
 """SFR Box sensor platform."""
-from collections.abc import Callable, Iterable
+from collections.abc import Callable
 from dataclasses import dataclass
-from itertools import chain
 from typing import Generic, TypeVar

 from sfrbox_api.models import DslInfo, SystemInfo
@@ -204,15 +203,14 @@ async def async_setup_entry(
     """Set up the sensors."""
     data: DomainData = hass.data[DOMAIN][entry.entry_id]

-    entities: Iterable[SFRBoxSensor] = chain(
-        (
-            SFRBoxSensor(data.dsl, description, data.system.data)
-            for description in DSL_SENSOR_TYPES
-        ),
-        (
-            SFRBoxSensor(data.system, description, data.system.data)
-            for description in SYSTEM_SENSOR_TYPES
-        ),
-    )
+    entities: list[SFRBoxSensor] = [
+        SFRBoxSensor(data.system, description, data.system.data)
+        for description in SYSTEM_SENSOR_TYPES
+    ]
+    if data.system.data.net_infra == "adsl":
+        entities.extend(
+            SFRBoxSensor(data.dsl, description, data.system.data)
+            for description in DSL_SENSOR_TYPES
+        )

     async_add_entities(entities)


@@ -1,9 +1,9 @@
 {
   "domain": "snapcast",
   "name": "Snapcast",
-  "codeowners": [],
+  "codeowners": ["@luar123"],
   "documentation": "https://www.home-assistant.io/integrations/snapcast",
   "iot_class": "local_polling",
   "loggers": ["construct", "snapcast"],
-  "requirements": ["snapcast==2.3.0"]
+  "requirements": ["snapcast==2.3.2"]
 }


@@ -5,5 +5,5 @@
   "config_flow": true,
   "documentation": "https://www.home-assistant.io/integrations/sql",
   "iot_class": "local_polling",
-  "requirements": ["sqlalchemy==2.0.4"]
+  "requirements": ["sqlalchemy==2.0.5.post1"]
 }


@ -17,9 +17,8 @@ some of their thread accessories can't be pinged, but it's still a thread proble
from __future__ import annotations from __future__ import annotations
from typing import Any, TypedDict from typing import TYPE_CHECKING, Any, TypedDict
from pyroute2 import NDB # pylint: disable=no-name-in-module
from python_otbr_api.tlv_parser import MeshcopTLVType from python_otbr_api.tlv_parser import MeshcopTLVType
from homeassistant.components import zeroconf from homeassistant.components import zeroconf
@ -29,6 +28,9 @@ from homeassistant.core import HomeAssistant
from .dataset_store import async_get_store from .dataset_store import async_get_store
from .discovery import async_read_zeroconf_cache from .discovery import async_read_zeroconf_cache
if TYPE_CHECKING:
from pyroute2 import NDB # pylint: disable=no-name-in-module
class Neighbour(TypedDict): class Neighbour(TypedDict):
"""A neighbour cache entry (ip neigh).""" """A neighbour cache entry (ip neigh)."""
@ -67,16 +69,15 @@ class Network(TypedDict):
unexpected_routers: set[str] unexpected_routers: set[str]
def _get_possible_thread_routes() -> ( def _get_possible_thread_routes(
tuple[dict[str, dict[str, Route]], dict[str, set[str]]] ndb: NDB,
): ) -> tuple[dict[str, dict[str, Route]], dict[str, set[str]]]:
# Build a list of possible thread routes # Build a list of possible thread routes
# Right now, this is ipv6 /64's that have a gateway # Right now, this is ipv6 /64's that have a gateway
# We cross reference with zeroconf data to confirm which vias are known border routers # We cross reference with zeroconf data to confirm which vias are known border routers
routes: dict[str, dict[str, Route]] = {} routes: dict[str, dict[str, Route]] = {}
reverse_routes: dict[str, set[str]] = {} reverse_routes: dict[str, set[str]] = {}
with NDB() as ndb:
for record in ndb.routes: for record in ndb.routes:
# Limit to IPV6 routes # Limit to IPV6 routes
if record.family != 10: if record.family != 10:
@ -100,25 +101,37 @@ def _get_possible_thread_routes() -> (
return routes, reverse_routes return routes, reverse_routes
def _get_neighbours() -> dict[str, Neighbour]: def _get_neighbours(ndb: NDB) -> dict[str, Neighbour]:
neighbours: dict[str, Neighbour] = {} # Build a list of neighbours
neighbours: dict[str, Neighbour] = {
with NDB() as ndb: record.dst: {
for record in ndb.neighbours:
neighbours[record.dst] = {
"lladdr": record.lladdr, "lladdr": record.lladdr,
"state": record.state, "state": record.state,
"probes": record.probes, "probes": record.probes,
} }
for record in ndb.neighbours
}
return neighbours return neighbours
def _get_routes_and_neighbors():
"""Get the routes and neighbours from pyroute2."""
# Import in the executor since import NDB can take a while
from pyroute2 import ( # pylint: disable=no-name-in-module, import-outside-toplevel
NDB,
)
with NDB() as ndb: # pylint: disable=not-callable
routes, reverse_routes = _get_possible_thread_routes(ndb)
neighbours = _get_neighbours(ndb)
return routes, reverse_routes, neighbours
async def async_get_config_entry_diagnostics( async def async_get_config_entry_diagnostics(
hass: HomeAssistant, entry: ConfigEntry hass: HomeAssistant, entry: ConfigEntry
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Return diagnostics for all known thread networks.""" """Return diagnostics for all known thread networks."""
networks: dict[str, Network] = {} networks: dict[str, Network] = {}
# Start with all networks that HA knows about # Start with all networks that HA knows about
@ -140,13 +153,12 @@ async def async_get_config_entry_diagnostics(
# Find all routes currently act that might be thread related, so we can match them to # Find all routes currently act that might be thread related, so we can match them to
# border routers as we process the zeroconf data. # border routers as we process the zeroconf data.
routes, reverse_routes = await hass.async_add_executor_job( #
_get_possible_thread_routes # Also find all neighbours
routes, reverse_routes, neighbours = await hass.async_add_executor_job(
_get_routes_and_neighbors
) )
# Find all neighbours
neighbours = await hass.async_add_executor_job(_get_neighbours)
aiozc = await zeroconf.async_get_async_instance(hass) aiozc = await zeroconf.async_get_async_instance(hass)
for data in async_read_zeroconf_cache(aiozc): for data in async_read_zeroconf_cache(aiozc):
if not data.extended_pan_id: if not data.extended_pan_id:
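
The diagnostics refactor above defers the pyroute2 import to the executor (importing `NDB` can take a while) and gathers routes and neighbours in a single executor job instead of two. A minimal sketch of that pattern using only the standard library; `heavy_module` and the fetched data are placeholders, and `loop.run_in_executor` stands in for `hass.async_add_executor_job`:

```python
# Sketch: keep a slow-to-import library out of the event loop by importing
# it inside the function that runs in the executor, and do all blocking
# work in one executor call.
import asyncio


def _collect_blocking_data():
    """Runs in a worker thread: import and use the heavy library here."""
    import json as heavy_module  # pretend this import is expensive

    routes = heavy_module.loads('{"eth0": ["fd00::/64"]}')
    neighbours = heavy_module.loads('{"fd00::1": "reachable"}')
    return routes, neighbours


async def main():
    loop = asyncio.get_running_loop()
    routes, neighbours = await loop.run_in_executor(None, _collect_blocking_data)
    print(routes, neighbours)


asyncio.run(main())
```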

View File

@ -3,9 +3,12 @@ from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
import logging import logging
import re
from types import MappingProxyType from types import MappingProxyType
from typing import Any, NamedTuple from typing import Any, NamedTuple
from urllib.parse import urlsplit
from aiohttp import CookieJar
from tplink_omada_client.exceptions import ( from tplink_omada_client.exceptions import (
ConnectionFailed, ConnectionFailed,
LoginFailed, LoginFailed,
@ -20,7 +23,10 @@ from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME, CONF_VE
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers import selector from homeassistant.helpers import selector
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import (
async_create_clientsession,
async_get_clientsession,
)
from .const import DOMAIN from .const import DOMAIN
@ -42,11 +48,26 @@ async def create_omada_client(
hass: HomeAssistant, data: MappingProxyType[str, Any] hass: HomeAssistant, data: MappingProxyType[str, Any]
) -> OmadaClient: ) -> OmadaClient:
"""Create a TP-Link Omada client API for the given config entry.""" """Create a TP-Link Omada client API for the given config entry."""
host = data[CONF_HOST]
host: str = data[CONF_HOST]
verify_ssl = bool(data[CONF_VERIFY_SSL]) verify_ssl = bool(data[CONF_VERIFY_SSL])
if not host.lower().startswith(("http://", "https://")):
host = "https://" + host
host_parts = urlsplit(host)
if (
host_parts.hostname
and re.fullmatch(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", host_parts.hostname)
is not None
):
# TP-Link API uses cookies for login session, so an unsafe cookie jar is required for IP addresses
websession = async_create_clientsession(hass, cookie_jar=CookieJar(unsafe=True))
else:
websession = async_get_clientsession(hass, verify_ssl=verify_ssl)
username = data[CONF_USERNAME] username = data[CONF_USERNAME]
password = data[CONF_PASSWORD] password = data[CONF_PASSWORD]
websession = async_get_clientsession(hass, verify_ssl=verify_ssl)
return OmadaClient(host, username, password, websession=websession) return OmadaClient(host, username, password, websession=websession)
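
The config-flow change above normalizes the host, detects a bare IPv4 address, and only then creates a session with an unsafe cookie jar, because aiohttp's default `CookieJar` does not keep cookies set by IP-address hosts. A minimal sketch of just the jar selection (no session is created here):

```python
# Sketch: pick an unsafe cookie jar for bare IPv4 hosts so login cookies
# from the controller are retained; use a regular jar otherwise.
import asyncio
import re
from urllib.parse import urlsplit

from aiohttp import CookieJar


def pick_cookie_jar(host: str) -> CookieJar:
    """Return an unsafe cookie jar for bare IP hosts, a normal one otherwise."""
    if not host.lower().startswith(("http://", "https://")):
        host = "https://" + host
    hostname = urlsplit(host).hostname or ""
    if re.fullmatch(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", hostname):
        # Cookies set by IP-address hosts are only stored by an "unsafe" jar.
        return CookieJar(unsafe=True)
    return CookieJar()


async def main() -> None:
    pick_cookie_jar("192.168.0.2")   # unsafe jar for the bare IP
    pick_cookie_jar("omada.local")   # regular jar for a hostname


asyncio.run(main())
```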

View File

@ -1,7 +1,7 @@
"""Support for the Tuya lights.""" """Support for the Tuya lights."""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass, field
import json import json
from typing import Any, cast from typing import Any, cast
@ -59,7 +59,9 @@ class TuyaLightEntityDescription(LightEntityDescription):
color_data: DPCode | tuple[DPCode, ...] | None = None color_data: DPCode | tuple[DPCode, ...] | None = None
color_mode: DPCode | None = None color_mode: DPCode | None = None
color_temp: DPCode | tuple[DPCode, ...] | None = None color_temp: DPCode | tuple[DPCode, ...] | None = None
default_color_type: ColorTypeData = DEFAULT_COLOR_TYPE_DATA default_color_type: ColorTypeData = field(
default_factory=lambda: DEFAULT_COLOR_TYPE_DATA
)
LIGHTS: dict[str, tuple[TuyaLightEntityDescription, ...]] = { LIGHTS: dict[str, tuple[TuyaLightEntityDescription, ...]] = {
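
The switch to `field(default_factory=...)` above is likely needed because `DEFAULT_COLOR_TYPE_DATA` is an instance of a regular (non-frozen, hence unhashable) dataclass, and Python 3.11 tightened the dataclass mutable-default check to reject such objects as direct field defaults. A minimal sketch with stand-in classes:

```python
# Sketch: supply a shared dataclass-instance default via default_factory.
from dataclasses import dataclass, field


@dataclass
class ColorType:
    max_brightness: int = 255


DEFAULT_COLOR_TYPE = ColorType()


@dataclass
class LightDescription:
    key: str
    # `color_type: ColorType = DEFAULT_COLOR_TYPE` raises ValueError on
    # Python 3.11+, so route the default through default_factory instead.
    color_type: ColorType = field(default_factory=lambda: DEFAULT_COLOR_TYPE)


print(LightDescription(key="switch_led").color_type.max_brightness)  # 255
```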

View File

@ -4,6 +4,7 @@ from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from contextlib import suppress from contextlib import suppress
import datetime as dt import datetime as dt
from functools import lru_cache
import json import json
from typing import Any, cast from typing import Any, cast
@ -424,6 +425,12 @@ def handle_ping(
connection.send_message(pong_message(msg["id"])) connection.send_message(pong_message(msg["id"]))
@lru_cache
def _cached_template(template_str: str, hass: HomeAssistant) -> template.Template:
"""Return a cached template."""
return template.Template(template_str, hass)
@decorators.websocket_command( @decorators.websocket_command(
{ {
vol.Required("type"): "render_template", vol.Required("type"): "render_template",
@ -440,7 +447,7 @@ async def handle_render_template(
) -> None: ) -> None:
"""Handle render_template command.""" """Handle render_template command."""
template_str = msg["template"] template_str = msg["template"]
template_obj = template.Template(template_str, hass) template_obj = _cached_template(template_str, hass)
variables = msg.get("variables") variables = msg.get("variables")
timeout = msg.get("timeout") timeout = msg.get("timeout")
info = None info = None
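
The `_cached_template` helper above memoizes the compiled template object per template string (and hass instance), so repeated `render_template` calls skip re-parsing while each render still evaluates against the current state. A minimal sketch of the idea, using `string.Template` in place of Home Assistant's template engine:

```python
# Sketch: cache the compiled template object, not the rendered result.
from functools import lru_cache
from string import Template


@lru_cache
def cached_template(source: str) -> Template:
    return Template(source)


state = {"temperature": "30"}
print(cached_template("$temperature").substitute(state))  # 30
state["temperature"] = "40"
print(cached_template("$temperature").substitute(state))  # 40, same cached object
print(cached_template.cache_info().hits)                  # 1
```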

View File

@ -1,5 +1,6 @@
"""Support for Zigbee Home Automation devices.""" """Support for Zigbee Home Automation devices."""
import asyncio import asyncio
import copy
import logging import logging
import os import os
@ -90,6 +91,15 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
Will automatically load components to support devices found on the network. Will automatically load components to support devices found on the network.
""" """
# Strip whitespace around `socket://` URIs, this is no longer accepted by zigpy
# This will be removed in 2023.7.0
path = config_entry.data[CONF_DEVICE][CONF_DEVICE_PATH]
data = copy.deepcopy(dict(config_entry.data))
if path.startswith("socket://") and path != path.strip():
data[CONF_DEVICE][CONF_DEVICE_PATH] = path.strip()
hass.config_entries.async_update_entry(config_entry, data=data)
zha_data = hass.data.setdefault(DATA_ZHA, {}) zha_data = hass.data.setdefault(DATA_ZHA, {})
config = zha_data.get(DATA_ZHA_CONFIG, {}) config = zha_data.get(DATA_ZHA_CONFIG, {})
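
The ZHA setup change above repairs stored config entries whose `socket://` device path was saved with stray whitespace, rewriting the entry only when stripping actually changes the path. A minimal sketch of that migration on a plain dict (the `"device"`/`"path"` keys are stand-ins for the real `CONF_DEVICE`/`CONF_DEVICE_PATH` constants):

```python
# Sketch: strip whitespace from a socket:// path and only rewrite the
# stored data when something actually changed.
import copy


def migrate_device_path(entry_data: dict) -> dict:
    """Strip stray whitespace from a socket:// device path, if any."""
    path = entry_data["device"]["path"]
    data = copy.deepcopy(entry_data)
    if path.startswith("socket://") and path != path.strip():
        data["device"]["path"] = path.strip()
    return data


print(migrate_device_path({"device": {"path": "socket://192.168.1.50:6638 "}}))
# {'device': {'path': 'socket://192.168.1.50:6638'}}
```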

View File

@ -8,7 +8,7 @@ from .backports.enum import StrEnum
APPLICATION_NAME: Final = "HomeAssistant" APPLICATION_NAME: Final = "HomeAssistant"
MAJOR_VERSION: Final = 2023 MAJOR_VERSION: Final = 2023
MINOR_VERSION: Final = 3 MINOR_VERSION: Final = 3
PATCH_VERSION: Final = "1" PATCH_VERSION: Final = "2"
__short_version__: Final = f"{MAJOR_VERSION}.{MINOR_VERSION}" __short_version__: Final = f"{MAJOR_VERSION}.{MINOR_VERSION}"
__version__: Final = f"{__short_version__}.{PATCH_VERSION}" __version__: Final = f"{__short_version__}.{PATCH_VERSION}"
REQUIRED_PYTHON_VER: Final[tuple[int, int, int]] = (3, 10, 0) REQUIRED_PYTHON_VER: Final[tuple[int, int, int]] = (3, 10, 0)

View File

@ -23,7 +23,7 @@ fnvhash==0.1.0
hass-nabucasa==0.61.0 hass-nabucasa==0.61.0
hassil==1.0.6 hassil==1.0.6
home-assistant-bluetooth==1.9.3 home-assistant-bluetooth==1.9.3
home-assistant-frontend==20230302.0 home-assistant-frontend==20230306.0
home-assistant-intents==2023.2.28 home-assistant-intents==2023.2.28
httpx==0.23.3 httpx==0.23.3
ifaddr==0.1.7 ifaddr==0.1.7
@ -42,7 +42,7 @@ pyudev==0.23.2
pyyaml==6.0 pyyaml==6.0
requests==2.28.2 requests==2.28.2
scapy==2.5.0 scapy==2.5.0
sqlalchemy==2.0.4 sqlalchemy==2.0.5.post1
typing-extensions>=4.5.0,<5.0 typing-extensions>=4.5.0,<5.0
voluptuous-serialize==2.6.0 voluptuous-serialize==2.6.0
voluptuous==0.13.1 voluptuous==0.13.1

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "homeassistant" name = "homeassistant"
version = "2023.3.1" version = "2023.3.2"
license = {text = "Apache-2.0"} license = {text = "Apache-2.0"}
description = "Open-source home automation platform running on Python 3." description = "Open-source home automation platform running on Python 3."
readme = "README.rst" readme = "README.rst"

View File

@ -156,7 +156,7 @@ aioecowitt==2023.01.0
aioemonitor==1.0.5 aioemonitor==1.0.5
# homeassistant.components.esphome # homeassistant.components.esphome
aioesphomeapi==13.4.1 aioesphomeapi==13.4.2
# homeassistant.components.flo # homeassistant.components.flo
aioflo==2021.11.0 aioflo==2021.11.0
@ -907,7 +907,7 @@ hole==0.8.0
holidays==0.18.0 holidays==0.18.0
# homeassistant.components.frontend # homeassistant.components.frontend
home-assistant-frontend==20230302.0 home-assistant-frontend==20230306.0
# homeassistant.components.conversation # homeassistant.components.conversation
home-assistant-intents==2023.2.28 home-assistant-intents==2023.2.28
@ -979,7 +979,7 @@ influxdb==5.3.1
inkbird-ble==0.5.6 inkbird-ble==0.5.6
# homeassistant.components.insteon # homeassistant.components.insteon
insteon-frontend-home-assistant==0.3.2 insteon-frontend-home-assistant==0.3.3
# homeassistant.components.intellifire # homeassistant.components.intellifire
intellifire4py==2.2.2 intellifire4py==2.2.2
@ -1621,7 +1621,7 @@ pyevilgenius==2.0.0
pyezviz==0.2.0.9 pyezviz==0.2.0.9
# homeassistant.components.fibaro # homeassistant.components.fibaro
pyfibaro==0.6.8 pyfibaro==0.6.9
# homeassistant.components.fido # homeassistant.components.fido
pyfido==2.1.2 pyfido==2.1.2
@ -1687,7 +1687,7 @@ pyialarm==2.2.0
pyicloud==1.0.0 pyicloud==1.0.0
# homeassistant.components.insteon # homeassistant.components.insteon
pyinsteon==1.3.3 pyinsteon==1.3.4
# homeassistant.components.intesishome # homeassistant.components.intesishome
pyintesishome==1.8.0 pyintesishome==1.8.0
@ -2367,7 +2367,7 @@ smart-meter-texas==0.4.7
smhi-pkg==1.0.16 smhi-pkg==1.0.16
# homeassistant.components.snapcast # homeassistant.components.snapcast
snapcast==2.3.0 snapcast==2.3.2
# homeassistant.components.sonos # homeassistant.components.sonos
soco==0.29.1 soco==0.29.1
@ -2398,7 +2398,7 @@ spotipy==2.22.1
# homeassistant.components.recorder # homeassistant.components.recorder
# homeassistant.components.sql # homeassistant.components.sql
sqlalchemy==2.0.4 sqlalchemy==2.0.5.post1
# homeassistant.components.srp_energy # homeassistant.components.srp_energy
srpenergy==1.3.6 srpenergy==1.3.6

View File

@ -143,7 +143,7 @@ aioecowitt==2023.01.0
aioemonitor==1.0.5 aioemonitor==1.0.5
# homeassistant.components.esphome # homeassistant.components.esphome
aioesphomeapi==13.4.1 aioesphomeapi==13.4.2
# homeassistant.components.flo # homeassistant.components.flo
aioflo==2021.11.0 aioflo==2021.11.0
@ -690,7 +690,7 @@ hole==0.8.0
holidays==0.18.0 holidays==0.18.0
# homeassistant.components.frontend # homeassistant.components.frontend
home-assistant-frontend==20230302.0 home-assistant-frontend==20230306.0
# homeassistant.components.conversation # homeassistant.components.conversation
home-assistant-intents==2023.2.28 home-assistant-intents==2023.2.28
@ -738,7 +738,7 @@ influxdb==5.3.1
inkbird-ble==0.5.6 inkbird-ble==0.5.6
# homeassistant.components.insteon # homeassistant.components.insteon
insteon-frontend-home-assistant==0.3.2 insteon-frontend-home-assistant==0.3.3
# homeassistant.components.intellifire # homeassistant.components.intellifire
intellifire4py==2.2.2 intellifire4py==2.2.2
@ -1161,7 +1161,7 @@ pyevilgenius==2.0.0
pyezviz==0.2.0.9 pyezviz==0.2.0.9
# homeassistant.components.fibaro # homeassistant.components.fibaro
pyfibaro==0.6.8 pyfibaro==0.6.9
# homeassistant.components.fido # homeassistant.components.fido
pyfido==2.1.2 pyfido==2.1.2
@ -1212,7 +1212,7 @@ pyialarm==2.2.0
pyicloud==1.0.0 pyicloud==1.0.0
# homeassistant.components.insteon # homeassistant.components.insteon
pyinsteon==1.3.3 pyinsteon==1.3.4
# homeassistant.components.ipma # homeassistant.components.ipma
pyipma==3.0.6 pyipma==3.0.6
@ -1698,7 +1698,7 @@ spotipy==2.22.1
# homeassistant.components.recorder # homeassistant.components.recorder
# homeassistant.components.sql # homeassistant.components.sql
sqlalchemy==2.0.4 sqlalchemy==2.0.5.post1
# homeassistant.components.srp_energy # homeassistant.components.srp_energy
srpenergy==1.3.6 srpenergy==1.3.6

View File

@ -349,6 +349,52 @@ async def test_api_template(hass: HomeAssistant, mock_api_client: TestClient) ->
assert body == "10" assert body == "10"
hass.states.async_set("sensor.temperature", 20)
resp = await mock_api_client.post(
const.URL_API_TEMPLATE,
json={"template": "{{ states.sensor.temperature.state }}"},
)
body = await resp.text()
assert body == "20"
hass.states.async_remove("sensor.temperature")
resp = await mock_api_client.post(
const.URL_API_TEMPLATE,
json={"template": "{{ states.sensor.temperature.state }}"},
)
body = await resp.text()
assert body == ""
async def test_api_template_cached(
hass: HomeAssistant, mock_api_client: TestClient
) -> None:
"""Test the template API uses the cache."""
hass.states.async_set("sensor.temperature", 30)
resp = await mock_api_client.post(
const.URL_API_TEMPLATE,
json={"template": "{{ states.sensor.temperature.state }}"},
)
body = await resp.text()
assert body == "30"
hass.states.async_set("sensor.temperature", 40)
resp = await mock_api_client.post(
const.URL_API_TEMPLATE,
json={"template": "{{ states.sensor.temperature.state }}"},
)
body = await resp.text()
assert body == "40"
async def test_api_template_error( async def test_api_template_error(
hass: HomeAssistant, mock_api_client: TestClient hass: HomeAssistant, mock_api_client: TestClient

View File

@ -1,5 +1,6 @@
"""Fixtures for Hass.io.""" """Fixtures for Hass.io."""
import os import os
import re
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
@ -12,6 +13,16 @@ from homeassistant.setup import async_setup_component
from . import SUPERVISOR_TOKEN from . import SUPERVISOR_TOKEN
@pytest.fixture(autouse=True)
def disable_security_filter():
"""Disable the security filter to ensure the integration is secure."""
with patch(
"homeassistant.components.http.security_filter.FILTERS",
re.compile("not-matching-anything"),
):
yield
@pytest.fixture @pytest.fixture
def hassio_env(): def hassio_env():
"""Fixture to inject hassio env.""" """Fixture to inject hassio env."""
@ -37,6 +48,13 @@ def hassio_stubs(hassio_env, hass, hass_client, aioclient_mock):
), patch( ), patch(
"homeassistant.components.hassio.HassIO.get_info", "homeassistant.components.hassio.HassIO.get_info",
side_effect=HassioAPIError(), side_effect=HassioAPIError(),
), patch(
"homeassistant.components.hassio.HassIO.get_ingress_panels",
return_value={"panels": []},
), patch(
"homeassistant.components.hassio.repairs.SupervisorRepairs.setup"
), patch(
"homeassistant.components.hassio.HassIO.refresh_updates"
): ):
hass.state = CoreState.starting hass.state = CoreState.starting
hass.loop.run_until_complete(async_setup_component(hass, "hassio", {})) hass.loop.run_until_complete(async_setup_component(hass, "hassio", {}))
@ -67,13 +85,7 @@ async def hassio_client_supervisor(hass, aiohttp_client, hassio_stubs):
@pytest.fixture @pytest.fixture
def hassio_handler(hass, aioclient_mock): async def hassio_handler(hass, aioclient_mock):
"""Create mock hassio handler.""" """Create mock hassio handler."""
async def get_client_session():
return async_get_clientsession(hass)
websession = hass.loop.run_until_complete(get_client_session())
with patch.dict(os.environ, {"SUPERVISOR_TOKEN": SUPERVISOR_TOKEN}): with patch.dict(os.environ, {"SUPERVISOR_TOKEN": SUPERVISOR_TOKEN}):
yield HassIO(hass.loop, websession, "127.0.0.1") yield HassIO(hass.loop, async_get_clientsession(hass), "127.0.0.1")
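
The new `disable_security_filter` fixture above is an autouse fixture: every test in the module automatically runs with the HTTP security filter patched to a regex that matches nothing, so the tests exercise the hassio component's own path handling rather than relying on the global filter. A minimal, self-contained sketch of the autouse-patch pattern (the `security` namespace is a hypothetical stand-in for the module owning `FILTERS`):

```python
# Sketch: an autouse pytest fixture that patches a module-level constant
# for the duration of every test in the module.
import re
from types import SimpleNamespace
from unittest.mock import patch

import pytest

# Stand-in for the module that owns the filter constant.
security = SimpleNamespace(FILTERS=re.compile(r"%2e|%252e"))


@pytest.fixture(autouse=True)
def relaxed_filter():
    """Every test automatically runs with a filter that never matches."""
    with patch.object(security, "FILTERS", re.compile("not-matching-anything")):
        yield


def test_double_encoded_path_passes():
    assert security.FILTERS.search("/app/%252e./entrypoint.js") is None
```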

View File

@ -1,13 +1,21 @@
"""The tests for the hassio component.""" """The tests for the hassio component."""
from __future__ import annotations
from typing import Any, Literal
import aiohttp import aiohttp
from aiohttp import hdrs, web
import pytest import pytest
from homeassistant.components.hassio.handler import HassioAPIError from homeassistant.components.hassio.handler import HassIO, HassioAPIError
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from tests.test_util.aiohttp import AiohttpClientMocker from tests.test_util.aiohttp import AiohttpClientMocker
async def test_api_ping(hassio_handler, aioclient_mock: AiohttpClientMocker) -> None: async def test_api_ping(
hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker
) -> None:
"""Test setup with API ping.""" """Test setup with API ping."""
aioclient_mock.get("http://127.0.0.1/supervisor/ping", json={"result": "ok"}) aioclient_mock.get("http://127.0.0.1/supervisor/ping", json={"result": "ok"})
@ -16,7 +24,7 @@ async def test_api_ping(hassio_handler, aioclient_mock: AiohttpClientMocker) ->
async def test_api_ping_error( async def test_api_ping_error(
hassio_handler, aioclient_mock: AiohttpClientMocker hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test setup with API ping error.""" """Test setup with API ping error."""
aioclient_mock.get("http://127.0.0.1/supervisor/ping", json={"result": "error"}) aioclient_mock.get("http://127.0.0.1/supervisor/ping", json={"result": "error"})
@ -26,7 +34,7 @@ async def test_api_ping_error(
async def test_api_ping_exeption( async def test_api_ping_exeption(
hassio_handler, aioclient_mock: AiohttpClientMocker hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test setup with API ping exception.""" """Test setup with API ping exception."""
aioclient_mock.get("http://127.0.0.1/supervisor/ping", exc=aiohttp.ClientError()) aioclient_mock.get("http://127.0.0.1/supervisor/ping", exc=aiohttp.ClientError())
@ -35,7 +43,9 @@ async def test_api_ping_exeption(
assert aioclient_mock.call_count == 1 assert aioclient_mock.call_count == 1
async def test_api_info(hassio_handler, aioclient_mock: AiohttpClientMocker) -> None: async def test_api_info(
hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker
) -> None:
"""Test setup with API generic info.""" """Test setup with API generic info."""
aioclient_mock.get( aioclient_mock.get(
"http://127.0.0.1/info", "http://127.0.0.1/info",
@ -53,7 +63,7 @@ async def test_api_info(hassio_handler, aioclient_mock: AiohttpClientMocker) ->
async def test_api_info_error( async def test_api_info_error(
hassio_handler, aioclient_mock: AiohttpClientMocker hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test setup with API Home Assistant info error.""" """Test setup with API Home Assistant info error."""
aioclient_mock.get( aioclient_mock.get(
@ -67,7 +77,7 @@ async def test_api_info_error(
async def test_api_host_info( async def test_api_host_info(
hassio_handler, aioclient_mock: AiohttpClientMocker hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test setup with API Host info.""" """Test setup with API Host info."""
aioclient_mock.get( aioclient_mock.get(
@ -90,7 +100,7 @@ async def test_api_host_info(
async def test_api_supervisor_info( async def test_api_supervisor_info(
hassio_handler, aioclient_mock: AiohttpClientMocker hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test setup with API Supervisor info.""" """Test setup with API Supervisor info."""
aioclient_mock.get( aioclient_mock.get(
@ -108,7 +118,9 @@ async def test_api_supervisor_info(
assert data["channel"] == "stable" assert data["channel"] == "stable"
async def test_api_os_info(hassio_handler, aioclient_mock: AiohttpClientMocker) -> None: async def test_api_os_info(
hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker
) -> None:
"""Test setup with API OS info.""" """Test setup with API OS info."""
aioclient_mock.get( aioclient_mock.get(
"http://127.0.0.1/os/info", "http://127.0.0.1/os/info",
@ -125,7 +137,7 @@ async def test_api_os_info(hassio_handler, aioclient_mock: AiohttpClientMocker)
async def test_api_host_info_error( async def test_api_host_info_error(
hassio_handler, aioclient_mock: AiohttpClientMocker hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test setup with API Home Assistant info error.""" """Test setup with API Home Assistant info error."""
aioclient_mock.get( aioclient_mock.get(
@ -139,7 +151,7 @@ async def test_api_host_info_error(
async def test_api_core_info( async def test_api_core_info(
hassio_handler, aioclient_mock: AiohttpClientMocker hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test setup with API Home Assistant Core info.""" """Test setup with API Home Assistant Core info."""
aioclient_mock.get( aioclient_mock.get(
@ -153,7 +165,7 @@ async def test_api_core_info(
async def test_api_core_info_error( async def test_api_core_info_error(
hassio_handler, aioclient_mock: AiohttpClientMocker hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test setup with API Home Assistant Core info error.""" """Test setup with API Home Assistant Core info error."""
aioclient_mock.get( aioclient_mock.get(
@ -167,7 +179,7 @@ async def test_api_core_info_error(
async def test_api_homeassistant_stop( async def test_api_homeassistant_stop(
hassio_handler, aioclient_mock: AiohttpClientMocker hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test setup with API Home Assistant stop.""" """Test setup with API Home Assistant stop."""
aioclient_mock.post("http://127.0.0.1/homeassistant/stop", json={"result": "ok"}) aioclient_mock.post("http://127.0.0.1/homeassistant/stop", json={"result": "ok"})
@ -177,7 +189,7 @@ async def test_api_homeassistant_stop(
async def test_api_homeassistant_restart( async def test_api_homeassistant_restart(
hassio_handler, aioclient_mock: AiohttpClientMocker hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test setup with API Home Assistant restart.""" """Test setup with API Home Assistant restart."""
aioclient_mock.post("http://127.0.0.1/homeassistant/restart", json={"result": "ok"}) aioclient_mock.post("http://127.0.0.1/homeassistant/restart", json={"result": "ok"})
@ -187,7 +199,7 @@ async def test_api_homeassistant_restart(
async def test_api_addon_info( async def test_api_addon_info(
hassio_handler, aioclient_mock: AiohttpClientMocker hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test setup with API Add-on info.""" """Test setup with API Add-on info."""
aioclient_mock.get( aioclient_mock.get(
@ -201,7 +213,7 @@ async def test_api_addon_info(
async def test_api_addon_stats( async def test_api_addon_stats(
hassio_handler, aioclient_mock: AiohttpClientMocker hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test setup with API Add-on stats.""" """Test setup with API Add-on stats."""
aioclient_mock.get( aioclient_mock.get(
@ -215,7 +227,7 @@ async def test_api_addon_stats(
async def test_api_discovery_message( async def test_api_discovery_message(
hassio_handler, aioclient_mock: AiohttpClientMocker hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test setup with API discovery message.""" """Test setup with API discovery message."""
aioclient_mock.get( aioclient_mock.get(
@ -229,7 +241,7 @@ async def test_api_discovery_message(
async def test_api_retrieve_discovery( async def test_api_retrieve_discovery(
hassio_handler, aioclient_mock: AiohttpClientMocker hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test setup with API discovery message.""" """Test setup with API discovery message."""
aioclient_mock.get( aioclient_mock.get(
@ -243,7 +255,7 @@ async def test_api_retrieve_discovery(
async def test_api_ingress_panels( async def test_api_ingress_panels(
hassio_handler, aioclient_mock: AiohttpClientMocker hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test setup with API Ingress panels.""" """Test setup with API Ingress panels."""
aioclient_mock.get( aioclient_mock.get(
@ -267,3 +279,56 @@ async def test_api_ingress_panels(
assert aioclient_mock.call_count == 1 assert aioclient_mock.call_count == 1
assert data["panels"] assert data["panels"]
assert "slug" in data["panels"] assert "slug" in data["panels"]
@pytest.mark.parametrize(
("api_call", "method", "payload"),
[
["retrieve_discovery_messages", "GET", None],
["refresh_updates", "POST", None],
["update_diagnostics", "POST", True],
],
)
async def test_api_headers(
hass,
aiohttp_raw_server,
socket_enabled,
api_call: str,
method: Literal["GET", "POST"],
payload: Any,
) -> None:
"""Test headers are forwarded correctly."""
received_request = None
async def mock_handler(request):
"""Return OK."""
nonlocal received_request
received_request = request
return web.json_response({"result": "ok", "data": None})
server = await aiohttp_raw_server(mock_handler)
hassio_handler = HassIO(
hass.loop,
async_get_clientsession(hass),
f"{server.host}:{server.port}",
)
api_func = getattr(hassio_handler, api_call)
if payload:
await api_func(payload)
else:
await api_func()
assert received_request is not None
assert received_request.method == method
assert received_request.headers.get("X-Hass-Source") == "core.handler"
if method == "GET":
assert hdrs.CONTENT_TYPE not in received_request.headers
return
assert hdrs.CONTENT_TYPE in received_request.headers
if payload:
assert received_request.headers[hdrs.CONTENT_TYPE] == "application/json"
else:
assert received_request.headers[hdrs.CONTENT_TYPE] == "application/octet-stream"
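
The `test_api_headers` test above spins up a raw aiohttp server whose handler captures the incoming request so the forwarded headers can be asserted afterwards. A minimal, self-contained sketch of that capture technique with plain aiohttp (no pytest fixtures; the path and header value are illustrative):

```python
# Sketch: run a throwaway aiohttp server, record what it receives, and
# assert on the captured headers after making a client request.
import asyncio

from aiohttp import ClientSession, web


async def main():
    captured = {}

    async def handler(request: web.Request) -> web.Response:
        captured["source"] = request.headers.get("X-Hass-Source")
        return web.json_response({"result": "ok"})

    app = web.Application()
    app.router.add_get("/supervisor/ping", handler)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", 0)  # port 0 = pick a free port
    await site.start()
    host, port = runner.addresses[0][:2]

    async with ClientSession() as session:
        resp = await session.get(
            f"http://{host}:{port}/supervisor/ping",
            headers={"X-Hass-Source": "core.handler"},
        )
        assert await resp.json() == {"result": "ok"}

    assert captured["source"] == "core.handler"
    await runner.cleanup()


asyncio.run(main())
```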

View File

@ -1,63 +1,45 @@
"""The tests for the hassio component.""" """The tests for the hassio component."""
import asyncio import asyncio
from http import HTTPStatus from http import HTTPStatus
from unittest.mock import patch
from aiohttp import StreamReader from aiohttp import StreamReader
import pytest import pytest
from homeassistant.components.hassio.http import _need_auth
from homeassistant.core import HomeAssistant
from tests.common import MockUser
from tests.test_util.aiohttp import AiohttpClientMocker from tests.test_util.aiohttp import AiohttpClientMocker
async def test_forward_request( @pytest.fixture
hassio_client, aioclient_mock: AiohttpClientMocker def mock_not_onboarded():
) -> None: """Mock that we're not onboarded."""
"""Test fetching normal path.""" with patch(
aioclient_mock.post("http://127.0.0.1/beer", text="response") "homeassistant.components.hassio.http.async_is_onboarded", return_value=False
):
yield
resp = await hassio_client.post("/api/hassio/beer")
# Check we got right response @pytest.fixture
assert resp.status == HTTPStatus.OK def hassio_user_client(hassio_client, hass_admin_user):
body = await resp.text() """Return a Hass.io HTTP client tied to a non-admin user."""
assert body == "response" hass_admin_user.groups = []
return hassio_client
# Check we forwarded command
assert len(aioclient_mock.mock_calls) == 1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"build_type", ["supervisor/info", "homeassistant/update", "host/info"] "path",
)
async def test_auth_required_forward_request(hassio_noauth_client, build_type) -> None:
"""Test auth required for normal request."""
resp = await hassio_noauth_client.post(f"/api/hassio/{build_type}")
# Check we got right response
assert resp.status == HTTPStatus.UNAUTHORIZED
@pytest.mark.parametrize(
"build_type",
[ [
"app/index.html", "app/entrypoint.js",
"app/hassio-app.html", "addons/bl_b392/logo",
"app/index.html", "addons/bl_b392/icon",
"app/hassio-app.html",
"app/some-chunk.js",
"app/app.js",
], ],
) )
async def test_forward_request_no_auth_for_panel( async def test_forward_request_onboarded_user_get(
hassio_client, build_type, aioclient_mock: AiohttpClientMocker hassio_user_client, aioclient_mock: AiohttpClientMocker, path: str
) -> None: ) -> None:
"""Test no auth needed for .""" """Test fetching normal path."""
aioclient_mock.get(f"http://127.0.0.1/{build_type}", text="response") aioclient_mock.get(f"http://127.0.0.1/{path}", text="response")
resp = await hassio_client.get(f"/api/hassio/{build_type}") resp = await hassio_user_client.get(f"/api/hassio/{path}")
# Check we got right response # Check we got right response
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
@ -66,15 +48,68 @@ async def test_forward_request_no_auth_for_panel(
# Check we forwarded command # Check we forwarded command
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
# We only expect a single header.
assert aioclient_mock.mock_calls[0][3] == {"X-Hass-Source": "core.http"}
async def test_forward_request_no_auth_for_logo( @pytest.mark.parametrize("method", ["POST", "PUT", "DELETE", "RANDOM"])
hassio_client, aioclient_mock: AiohttpClientMocker async def test_forward_request_onboarded_user_unallowed_methods(
hassio_user_client, aioclient_mock: AiohttpClientMocker, method: str
) -> None: ) -> None:
"""Test no auth needed for logo.""" """Test fetching normal path."""
aioclient_mock.get("http://127.0.0.1/addons/bl_b392/logo", text="response") resp = await hassio_user_client.post("/api/hassio/app/entrypoint.js")
resp = await hassio_client.get("/api/hassio/addons/bl_b392/logo") # Check we got right response
assert resp.status == HTTPStatus.METHOD_NOT_ALLOWED
# Check we did not forward command
assert len(aioclient_mock.mock_calls) == 0
@pytest.mark.parametrize(
("bad_path", "expected_status"),
[
# Caught by bullshit filter
("app/%252E./entrypoint.js", HTTPStatus.BAD_REQUEST),
# The .. is processed, making it an unauthenticated path
("app/../entrypoint.js", HTTPStatus.UNAUTHORIZED),
("app/%2E%2E/entrypoint.js", HTTPStatus.UNAUTHORIZED),
# Unauthenticated path
("supervisor/info", HTTPStatus.UNAUTHORIZED),
("supervisor/logs", HTTPStatus.UNAUTHORIZED),
("addons/bl_b392/logs", HTTPStatus.UNAUTHORIZED),
],
)
async def test_forward_request_onboarded_user_unallowed_paths(
hassio_user_client,
aioclient_mock: AiohttpClientMocker,
bad_path: str,
expected_status: int,
) -> None:
"""Test fetching normal path."""
resp = await hassio_user_client.get(f"/api/hassio/{bad_path}")
# Check we got right response
assert resp.status == expected_status
# Check we didn't forward command
assert len(aioclient_mock.mock_calls) == 0
@pytest.mark.parametrize(
"path",
[
"app/entrypoint.js",
"addons/bl_b392/logo",
"addons/bl_b392/icon",
],
)
async def test_forward_request_onboarded_noauth_get(
hassio_noauth_client, aioclient_mock: AiohttpClientMocker, path: str
) -> None:
"""Test fetching normal path."""
aioclient_mock.get(f"http://127.0.0.1/{path}", text="response")
resp = await hassio_noauth_client.get(f"/api/hassio/{path}")
# Check we got right response # Check we got right response
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
@ -83,15 +118,73 @@ async def test_forward_request_no_auth_for_logo(
# Check we forwarded command # Check we forwarded command
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
# We only expect a single header.
assert aioclient_mock.mock_calls[0][3] == {"X-Hass-Source": "core.http"}
async def test_forward_request_no_auth_for_icon( @pytest.mark.parametrize("method", ["POST", "PUT", "DELETE", "RANDOM"])
hassio_client, aioclient_mock: AiohttpClientMocker async def test_forward_request_onboarded_noauth_unallowed_methods(
hassio_noauth_client, aioclient_mock: AiohttpClientMocker, method: str
) -> None: ) -> None:
"""Test no auth needed for icon.""" """Test fetching normal path."""
aioclient_mock.get("http://127.0.0.1/addons/bl_b392/icon", text="response") resp = await hassio_noauth_client.post("/api/hassio/app/entrypoint.js")
resp = await hassio_client.get("/api/hassio/addons/bl_b392/icon") # Check we got right response
assert resp.status == HTTPStatus.METHOD_NOT_ALLOWED
# Check we did not forward command
assert len(aioclient_mock.mock_calls) == 0
@pytest.mark.parametrize(
("bad_path", "expected_status"),
[
# Caught by bullshit filter
("app/%252E./entrypoint.js", HTTPStatus.BAD_REQUEST),
# The .. is processed, making it an unauthenticated path
("app/../entrypoint.js", HTTPStatus.UNAUTHORIZED),
("app/%2E%2E/entrypoint.js", HTTPStatus.UNAUTHORIZED),
# Unauthenticated path
("supervisor/info", HTTPStatus.UNAUTHORIZED),
("supervisor/logs", HTTPStatus.UNAUTHORIZED),
("addons/bl_b392/logs", HTTPStatus.UNAUTHORIZED),
],
)
async def test_forward_request_onboarded_noauth_unallowed_paths(
hassio_noauth_client,
aioclient_mock: AiohttpClientMocker,
bad_path: str,
expected_status: int,
) -> None:
"""Test fetching normal path."""
resp = await hassio_noauth_client.get(f"/api/hassio/{bad_path}")
# Check we got right response
assert resp.status == expected_status
# Check we didn't forward command
assert len(aioclient_mock.mock_calls) == 0
@pytest.mark.parametrize(
("path", "authenticated"),
[
("app/entrypoint.js", False),
("addons/bl_b392/logo", False),
("addons/bl_b392/icon", False),
("backups/1234abcd/info", True),
],
)
async def test_forward_request_not_onboarded_get(
hassio_noauth_client,
aioclient_mock: AiohttpClientMocker,
path: str,
authenticated: bool,
mock_not_onboarded,
) -> None:
"""Test fetching normal path."""
aioclient_mock.get(f"http://127.0.0.1/{path}", text="response")
resp = await hassio_noauth_client.get(f"/api/hassio/{path}")
# Check we got right response # Check we got right response
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
@ -100,61 +193,224 @@ async def test_forward_request_no_auth_for_icon(
# Check we forwarded command # Check we forwarded command
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
expected_headers = {
"X-Hass-Source": "core.http",
}
if authenticated:
expected_headers["Authorization"] = "Bearer 123456"
assert aioclient_mock.mock_calls[0][3] == expected_headers
async def test_forward_log_request( @pytest.mark.parametrize(
hassio_client, aioclient_mock: AiohttpClientMocker "path",
[
"backups/new/upload",
"backups/1234abcd/restore/full",
"backups/1234abcd/restore/partial",
],
)
async def test_forward_request_not_onboarded_post(
hassio_noauth_client,
aioclient_mock: AiohttpClientMocker,
path: str,
mock_not_onboarded,
) -> None: ) -> None:
"""Test fetching normal log path doesn't remove ANSI color escape codes.""" """Test fetching normal path."""
aioclient_mock.get("http://127.0.0.1/beer/logs", text="\033[32mresponse\033[0m") aioclient_mock.get(f"http://127.0.0.1/{path}", text="response")
resp = await hassio_client.get("/api/hassio/beer/logs") resp = await hassio_noauth_client.get(f"/api/hassio/{path}")
# Check we got right response # Check we got right response
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
body = await resp.text() body = await resp.text()
assert body == "\033[32mresponse\033[0m" assert body == "response"
# Check we forwarded command # Check we forwarded command
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
# We only expect a single header.
assert aioclient_mock.mock_calls[0][3] == {
"X-Hass-Source": "core.http",
"Authorization": "Bearer 123456",
}
@pytest.mark.parametrize("method", ["POST", "PUT", "DELETE", "RANDOM"])
async def test_forward_request_not_onboarded_unallowed_methods(
hassio_noauth_client, aioclient_mock: AiohttpClientMocker, method: str
) -> None:
"""Test fetching normal path."""
resp = await hassio_noauth_client.post("/api/hassio/app/entrypoint.js")
# Check we got right response
assert resp.status == HTTPStatus.METHOD_NOT_ALLOWED
# Check we did not forward command
assert len(aioclient_mock.mock_calls) == 0
@pytest.mark.parametrize(
("bad_path", "expected_status"),
[
# Caught by bullshit filter
("app/%252E./entrypoint.js", HTTPStatus.BAD_REQUEST),
# The .. is processed, making it an unauthenticated path
("app/../entrypoint.js", HTTPStatus.UNAUTHORIZED),
("app/%2E%2E/entrypoint.js", HTTPStatus.UNAUTHORIZED),
# Unauthenticated path
("supervisor/info", HTTPStatus.UNAUTHORIZED),
("supervisor/logs", HTTPStatus.UNAUTHORIZED),
("addons/bl_b392/logs", HTTPStatus.UNAUTHORIZED),
],
)
async def test_forward_request_not_onboarded_unallowed_paths(
hassio_noauth_client,
aioclient_mock: AiohttpClientMocker,
bad_path: str,
expected_status: int,
mock_not_onboarded,
) -> None:
"""Test fetching normal path."""
resp = await hassio_noauth_client.get(f"/api/hassio/{bad_path}")
# Check we got right response
assert resp.status == expected_status
# Check we didn't forward command
assert len(aioclient_mock.mock_calls) == 0
@pytest.mark.parametrize(
("path", "authenticated"),
[
("app/entrypoint.js", False),
("addons/bl_b392/logo", False),
("addons/bl_b392/icon", False),
("backups/1234abcd/info", True),
("supervisor/logs", True),
("addons/bl_b392/logs", True),
],
)
async def test_forward_request_admin_get(
hassio_client,
aioclient_mock: AiohttpClientMocker,
path: str,
authenticated: bool,
) -> None:
"""Test fetching normal path."""
aioclient_mock.get(f"http://127.0.0.1/{path}", text="response")
resp = await hassio_client.get(f"/api/hassio/{path}")
# Check we got right response
assert resp.status == HTTPStatus.OK
body = await resp.text()
assert body == "response"
# Check we forwarded command
assert len(aioclient_mock.mock_calls) == 1
expected_headers = {
"X-Hass-Source": "core.http",
}
if authenticated:
expected_headers["Authorization"] = "Bearer 123456"
assert aioclient_mock.mock_calls[0][3] == expected_headers
@pytest.mark.parametrize(
"path",
[
"backups/new/upload",
"backups/1234abcd/restore/full",
"backups/1234abcd/restore/partial",
],
)
async def test_forward_request_admin_post(
hassio_client,
aioclient_mock: AiohttpClientMocker,
path: str,
) -> None:
"""Test fetching normal path."""
aioclient_mock.get(f"http://127.0.0.1/{path}", text="response")
resp = await hassio_client.get(f"/api/hassio/{path}")
# Check we got right response
assert resp.status == HTTPStatus.OK
body = await resp.text()
assert body == "response"
# Check we forwarded command
assert len(aioclient_mock.mock_calls) == 1
# We only expect a single header.
assert aioclient_mock.mock_calls[0][3] == {
"X-Hass-Source": "core.http",
"Authorization": "Bearer 123456",
}
@pytest.mark.parametrize("method", ["POST", "PUT", "DELETE", "RANDOM"])
async def test_forward_request_admin_unallowed_methods(
hassio_client, aioclient_mock: AiohttpClientMocker, method: str
) -> None:
"""Test fetching normal path."""
resp = await hassio_client.post("/api/hassio/app/entrypoint.js")
# Check we got right response
assert resp.status == HTTPStatus.METHOD_NOT_ALLOWED
# Check we did not forward command
assert len(aioclient_mock.mock_calls) == 0
@pytest.mark.parametrize(
("bad_path", "expected_status"),
[
# Caught by bullshit filter
("app/%252E./entrypoint.js", HTTPStatus.BAD_REQUEST),
# The .. is processed, making it an unauthenticated path
("app/../entrypoint.js", HTTPStatus.UNAUTHORIZED),
("app/%2E%2E/entrypoint.js", HTTPStatus.UNAUTHORIZED),
# Unauthenticated path
("supervisor/info", HTTPStatus.UNAUTHORIZED),
],
)
async def test_forward_request_admin_unallowed_paths(
hassio_client,
aioclient_mock: AiohttpClientMocker,
bad_path: str,
expected_status: int,
) -> None:
"""Test fetching normal path."""
resp = await hassio_client.get(f"/api/hassio/{bad_path}")
# Check we got right response
assert resp.status == expected_status
# Check we didn't forward command
assert len(aioclient_mock.mock_calls) == 0
async def test_bad_gateway_when_cannot_find_supervisor( async def test_bad_gateway_when_cannot_find_supervisor(
hassio_client, aioclient_mock: AiohttpClientMocker hassio_client, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test we get a bad gateway error if we can't find supervisor.""" """Test we get a bad gateway error if we can't find supervisor."""
aioclient_mock.get("http://127.0.0.1/addons/test/info", exc=asyncio.TimeoutError) aioclient_mock.get("http://127.0.0.1/app/entrypoint.js", exc=asyncio.TimeoutError)
resp = await hassio_client.get("/api/hassio/addons/test/info") resp = await hassio_client.get("/api/hassio/app/entrypoint.js")
assert resp.status == HTTPStatus.BAD_GATEWAY assert resp.status == HTTPStatus.BAD_GATEWAY
async def test_forwarding_user_info(
hassio_client, hass_admin_user: MockUser, aioclient_mock: AiohttpClientMocker
) -> None:
"""Test that we forward user info correctly."""
aioclient_mock.get("http://127.0.0.1/hello")
resp = await hassio_client.get("/api/hassio/hello")
# Check we got right response
assert resp.status == HTTPStatus.OK
assert len(aioclient_mock.mock_calls) == 1
req_headers = aioclient_mock.mock_calls[0][-1]
assert req_headers["X-Hass-User-ID"] == hass_admin_user.id
assert req_headers["X-Hass-Is-Admin"] == "1"
async def test_backup_upload_headers( async def test_backup_upload_headers(
hassio_client, aioclient_mock: AiohttpClientMocker, caplog: pytest.LogCaptureFixture hassio_client,
aioclient_mock: AiohttpClientMocker,
caplog: pytest.LogCaptureFixture,
mock_not_onboarded,
) -> None: ) -> None:
"""Test that we forward the full header for backup upload.""" """Test that we forward the full header for backup upload."""
content_type = "multipart/form-data; boundary='--webkit'" content_type = "multipart/form-data; boundary='--webkit'"
aioclient_mock.get("http://127.0.0.1/backups/new/upload") aioclient_mock.post("http://127.0.0.1/backups/new/upload")
resp = await hassio_client.get( resp = await hassio_client.post(
"/api/hassio/backups/new/upload", headers={"Content-Type": content_type} "/api/hassio/backups/new/upload", headers={"Content-Type": content_type}
) )
@ -168,19 +424,19 @@ async def test_backup_upload_headers(
async def test_backup_download_headers( async def test_backup_download_headers(
hassio_client, aioclient_mock: AiohttpClientMocker hassio_client, aioclient_mock: AiohttpClientMocker, mock_not_onboarded
) -> None: ) -> None:
"""Test that we forward the full header for backup download.""" """Test that we forward the full header for backup download."""
content_disposition = "attachment; filename=test.tar" content_disposition = "attachment; filename=test.tar"
aioclient_mock.get( aioclient_mock.get(
"http://127.0.0.1/backups/slug/download", "http://127.0.0.1/backups/1234abcd/download",
headers={ headers={
"Content-Length": "50000000", "Content-Length": "50000000",
"Content-Disposition": content_disposition, "Content-Disposition": content_disposition,
}, },
) )
resp = await hassio_client.get("/api/hassio/backups/slug/download") resp = await hassio_client.get("/api/hassio/backups/1234abcd/download")
# Check we got right response # Check we got right response
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
@ -190,21 +446,10 @@ async def test_backup_download_headers(
assert resp.headers["Content-Disposition"] == content_disposition assert resp.headers["Content-Disposition"] == content_disposition
def test_need_auth(hass: HomeAssistant) -> None:
"""Test if the requested path needs authentication."""
assert not _need_auth(hass, "addons/test/logo")
assert _need_auth(hass, "backups/new/upload")
assert _need_auth(hass, "supervisor/logs")
hass.data["onboarding"] = False
assert not _need_auth(hass, "backups/new/upload")
assert not _need_auth(hass, "supervisor/logs")
async def test_stream(hassio_client, aioclient_mock: AiohttpClientMocker) -> None: async def test_stream(hassio_client, aioclient_mock: AiohttpClientMocker) -> None:
"""Verify that the request is a stream.""" """Verify that the request is a stream."""
aioclient_mock.get("http://127.0.0.1/test") aioclient_mock.get("http://127.0.0.1/app/entrypoint.js")
await hassio_client.get("/api/hassio/test", data="test") await hassio_client.get("/api/hassio/app/entrypoint.js", data="test")
assert isinstance(aioclient_mock.mock_calls[-1][2], StreamReader) assert isinstance(aioclient_mock.mock_calls[-1][2], StreamReader)

View File

@ -21,7 +21,7 @@ from tests.test_util.aiohttp import AiohttpClientMocker
], ],
) )
async def test_ingress_request_get( async def test_ingress_request_get(
hassio_client, build_type, aioclient_mock: AiohttpClientMocker hassio_noauth_client, build_type, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test no auth needed for .""" """Test no auth needed for ."""
aioclient_mock.get( aioclient_mock.get(
@ -29,7 +29,7 @@ async def test_ingress_request_get(
text="test", text="test",
) )
resp = await hassio_client.get( resp = await hassio_noauth_client.get(
f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}", f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}",
headers={"X-Test-Header": "beer"}, headers={"X-Test-Header": "beer"},
) )
@ -41,7 +41,8 @@ async def test_ingress_request_get(
# Check we forwarded command # Check we forwarded command
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[-1][3][X_AUTH_TOKEN] == "123456" assert X_AUTH_TOKEN not in aioclient_mock.mock_calls[-1][3]
assert aioclient_mock.mock_calls[-1][3]["X-Hass-Source"] == "core.ingress"
assert ( assert (
aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"] aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"]
== f"/api/hassio_ingress/{build_type[0]}" == f"/api/hassio_ingress/{build_type[0]}"
@ -63,7 +64,7 @@ async def test_ingress_request_get(
], ],
) )
async def test_ingress_request_post( async def test_ingress_request_post(
hassio_client, build_type, aioclient_mock: AiohttpClientMocker hassio_noauth_client, build_type, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test no auth needed for .""" """Test no auth needed for ."""
aioclient_mock.post( aioclient_mock.post(
@ -71,7 +72,7 @@ async def test_ingress_request_post(
text="test", text="test",
) )
resp = await hassio_client.post( resp = await hassio_noauth_client.post(
f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}", f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}",
headers={"X-Test-Header": "beer"}, headers={"X-Test-Header": "beer"},
) )
@ -83,7 +84,8 @@ async def test_ingress_request_post(
# Check we forwarded command # Check we forwarded command
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[-1][3][X_AUTH_TOKEN] == "123456" assert X_AUTH_TOKEN not in aioclient_mock.mock_calls[-1][3]
assert aioclient_mock.mock_calls[-1][3]["X-Hass-Source"] == "core.ingress"
assert ( assert (
aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"] aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"]
== f"/api/hassio_ingress/{build_type[0]}" == f"/api/hassio_ingress/{build_type[0]}"
@ -105,7 +107,7 @@ async def test_ingress_request_post(
], ],
) )
async def test_ingress_request_put( async def test_ingress_request_put(
hassio_client, build_type, aioclient_mock: AiohttpClientMocker hassio_noauth_client, build_type, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test no auth needed for .""" """Test no auth needed for ."""
aioclient_mock.put( aioclient_mock.put(
@ -113,7 +115,7 @@ async def test_ingress_request_put(
text="test", text="test",
) )
resp = await hassio_client.put( resp = await hassio_noauth_client.put(
f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}", f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}",
headers={"X-Test-Header": "beer"}, headers={"X-Test-Header": "beer"},
) )
@ -125,7 +127,8 @@ async def test_ingress_request_put(
# Check we forwarded command # Check we forwarded command
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[-1][3][X_AUTH_TOKEN] == "123456" assert X_AUTH_TOKEN not in aioclient_mock.mock_calls[-1][3]
assert aioclient_mock.mock_calls[-1][3]["X-Hass-Source"] == "core.ingress"
assert ( assert (
aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"] aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"]
== f"/api/hassio_ingress/{build_type[0]}" == f"/api/hassio_ingress/{build_type[0]}"
@ -147,7 +150,7 @@ async def test_ingress_request_put(
], ],
) )
async def test_ingress_request_delete( async def test_ingress_request_delete(
hassio_client, build_type, aioclient_mock: AiohttpClientMocker hassio_noauth_client, build_type, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test no auth needed for .""" """Test no auth needed for ."""
aioclient_mock.delete( aioclient_mock.delete(
@ -155,7 +158,7 @@ async def test_ingress_request_delete(
text="test", text="test",
) )
resp = await hassio_client.delete( resp = await hassio_noauth_client.delete(
f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}", f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}",
headers={"X-Test-Header": "beer"}, headers={"X-Test-Header": "beer"},
) )
@ -167,7 +170,8 @@ async def test_ingress_request_delete(
# Check we forwarded command # Check we forwarded command
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[-1][3][X_AUTH_TOKEN] == "123456" assert X_AUTH_TOKEN not in aioclient_mock.mock_calls[-1][3]
assert aioclient_mock.mock_calls[-1][3]["X-Hass-Source"] == "core.ingress"
assert ( assert (
aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"] aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"]
== f"/api/hassio_ingress/{build_type[0]}" == f"/api/hassio_ingress/{build_type[0]}"
@ -189,7 +193,7 @@ async def test_ingress_request_delete(
], ],
) )
async def test_ingress_request_patch( async def test_ingress_request_patch(
hassio_client, build_type, aioclient_mock: AiohttpClientMocker hassio_noauth_client, build_type, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test no auth needed for .""" """Test no auth needed for ."""
aioclient_mock.patch( aioclient_mock.patch(
@ -197,7 +201,7 @@ async def test_ingress_request_patch(
text="test", text="test",
) )
resp = await hassio_client.patch( resp = await hassio_noauth_client.patch(
f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}", f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}",
headers={"X-Test-Header": "beer"}, headers={"X-Test-Header": "beer"},
) )
@ -209,7 +213,8 @@ async def test_ingress_request_patch(
# Check we forwarded command # Check we forwarded command
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[-1][3][X_AUTH_TOKEN] == "123456" assert X_AUTH_TOKEN not in aioclient_mock.mock_calls[-1][3]
assert aioclient_mock.mock_calls[-1][3]["X-Hass-Source"] == "core.ingress"
assert ( assert (
aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"] aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"]
== f"/api/hassio_ingress/{build_type[0]}" == f"/api/hassio_ingress/{build_type[0]}"
@ -231,7 +236,7 @@ async def test_ingress_request_patch(
], ],
) )
async def test_ingress_request_options( async def test_ingress_request_options(
hassio_client, build_type, aioclient_mock: AiohttpClientMocker hassio_noauth_client, build_type, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test no auth needed for .""" """Test no auth needed for ."""
aioclient_mock.options( aioclient_mock.options(
@ -239,7 +244,7 @@ async def test_ingress_request_options(
text="test", text="test",
) )
resp = await hassio_client.options( resp = await hassio_noauth_client.options(
f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}", f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}",
headers={"X-Test-Header": "beer"}, headers={"X-Test-Header": "beer"},
) )
@ -251,7 +256,8 @@ async def test_ingress_request_options(
# Check we forwarded command # Check we forwarded command
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[-1][3][X_AUTH_TOKEN] == "123456" assert X_AUTH_TOKEN not in aioclient_mock.mock_calls[-1][3]
assert aioclient_mock.mock_calls[-1][3]["X-Hass-Source"] == "core.ingress"
assert ( assert (
aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"] aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"]
== f"/api/hassio_ingress/{build_type[0]}" == f"/api/hassio_ingress/{build_type[0]}"
@ -273,20 +279,21 @@ async def test_ingress_request_options(
], ],
) )
async def test_ingress_websocket( async def test_ingress_websocket(
hassio_client, build_type, aioclient_mock: AiohttpClientMocker hassio_noauth_client, build_type, aioclient_mock: AiohttpClientMocker
) -> None: ) -> None:
"""Test no auth needed for .""" """Test no auth needed for ."""
aioclient_mock.get(f"http://127.0.0.1/ingress/{build_type[0]}/{build_type[1]}") aioclient_mock.get(f"http://127.0.0.1/ingress/{build_type[0]}/{build_type[1]}")
# Ignore error because we can't set up a full IO infrastructure # Ignore error because we can't set up a full IO infrastructure
await hassio_client.ws_connect( await hassio_noauth_client.ws_connect(
f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}", f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}",
headers={"X-Test-Header": "beer"}, headers={"X-Test-Header": "beer"},
) )
# Check we forwarded command # Check we forwarded command
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
assert aioclient_mock.mock_calls[-1][3][X_AUTH_TOKEN] == "123456" assert X_AUTH_TOKEN not in aioclient_mock.mock_calls[-1][3]
assert aioclient_mock.mock_calls[-1][3]["X-Hass-Source"] == "core.ingress"
assert ( assert (
aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"] aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"]
== f"/api/hassio_ingress/{build_type[0]}" == f"/api/hassio_ingress/{build_type[0]}"
@ -298,7 +305,9 @@ async def test_ingress_websocket(
async def test_ingress_missing_peername( async def test_ingress_missing_peername(
hassio_client, aioclient_mock: AiohttpClientMocker, caplog: pytest.LogCaptureFixture hassio_noauth_client,
aioclient_mock: AiohttpClientMocker,
caplog: pytest.LogCaptureFixture,
) -> None: ) -> None:
"""Test hadnling of missing peername.""" """Test hadnling of missing peername."""
aioclient_mock.get( aioclient_mock.get(
@ -314,7 +323,7 @@ async def test_ingress_missing_peername(
return_value=MagicMock(), return_value=MagicMock(),
) as transport_mock: ) as transport_mock:
transport_mock.get_extra_info = get_extra_info transport_mock.get_extra_info = get_extra_info
resp = await hassio_client.get( resp = await hassio_noauth_client.get(
"/api/hassio_ingress/lorem/ipsum", "/api/hassio_ingress/lorem/ipsum",
headers={"X-Test-Header": "beer"}, headers={"X-Test-Header": "beer"},
) )
@ -323,3 +332,19 @@ async def test_ingress_missing_peername(
# Check we got right response # Check we got right response
assert resp.status == HTTPStatus.BAD_REQUEST assert resp.status == HTTPStatus.BAD_REQUEST
async def test_forwarding_paths_as_requested(
hassio_noauth_client, aioclient_mock
) -> None:
"""Test incomnig URLs with double encoding go out as dobule encoded."""
# This double encoded string should be forwarded double-encoded too.
aioclient_mock.get(
"http://127.0.0.1/ingress/mock-token/hello/%252e./world",
text="test",
)
resp = await hassio_noauth_client.get(
"/api/hassio_ingress/mock-token/hello/%252e./world",
)
assert await resp.text() == "test"
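The `%252e` fixture is double-encoded on purpose: decoding it once yields `%2e`, and decoding a second time would yield `.`, turning `hello/%252e./world` into the traversal-looking `hello/../world`. A quick standard-library illustration of why the ingress proxy must forward the path exactly as received instead of re-decoding it:

from urllib.parse import unquote

raw = "hello/%252e./world"
once = unquote(raw)    # 'hello/%2e./world'
twice = unquote(once)  # 'hello/../world'  <- would escape the add-on's ingress root
assert once == "hello/%2e./world"
assert twice == "hello/../world"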

View File

@ -153,6 +153,11 @@ async def test_websocket_supervisor_api(
msg = await websocket_client.receive_json() msg = await websocket_client.receive_json()
assert msg["result"]["version_latest"] == "1.0.0" assert msg["result"]["version_latest"] == "1.0.0"
assert aioclient_mock.mock_calls[-1][3] == {
"X-Hass-Source": "core.websocket_api",
"Authorization": "Bearer 123456",
}
async def test_websocket_supervisor_api_error( async def test_websocket_supervisor_api_error(
hassio_env, hassio_env,

View File

@ -69,7 +69,7 @@ async def test_schema_update_calls(recorder_db_url: str, hass: HomeAssistant) ->
session_maker = instance.get_session session_maker = instance.get_session
update.assert_has_calls( update.assert_has_calls(
[ [
call(hass, engine, session_maker, version + 1, 0) call(instance, hass, engine, session_maker, version + 1, 0)
for version in range(0, db_schema.SCHEMA_VERSION) for version in range(0, db_schema.SCHEMA_VERSION)
] ]
) )
@ -304,6 +304,8 @@ async def test_schema_migrate(
migration_version = None migration_version = None
real_migrate_schema = recorder.migration.migrate_schema real_migrate_schema = recorder.migration.migrate_schema
real_apply_update = recorder.migration._apply_update real_apply_update = recorder.migration._apply_update
real_create_index = recorder.migration._create_index
create_calls = 0
def _create_engine_test(*args, **kwargs): def _create_engine_test(*args, **kwargs):
"""Test version of create_engine that initializes with old schema. """Test version of create_engine that initializes with old schema.
@ -355,6 +357,17 @@ async def test_schema_migrate(
migration_stall.wait() migration_stall.wait()
real_apply_update(*args) real_apply_update(*args)
def _sometimes_failing_create_index(*args):
"""Make the first index create raise a retryable error to ensure we retry."""
if recorder_db_url.startswith("mysql://"):
nonlocal create_calls
if create_calls < 1:
create_calls += 1
mysql_exception = OperationalError("statement", {}, [])
mysql_exception.orig = Exception(1205, "retryable")
raise mysql_exception
real_create_index(*args)
with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch( with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch(
"homeassistant.components.recorder.core.create_engine", "homeassistant.components.recorder.core.create_engine",
new=_create_engine_test, new=_create_engine_test,
@ -368,6 +381,11 @@ async def test_schema_migrate(
), patch( ), patch(
"homeassistant.components.recorder.migration._apply_update", "homeassistant.components.recorder.migration._apply_update",
wraps=_instrument_apply_update, wraps=_instrument_apply_update,
) as apply_update_mock, patch(
"homeassistant.components.recorder.util.time.sleep"
), patch(
"homeassistant.components.recorder.migration._create_index",
wraps=_sometimes_failing_create_index,
), patch( ), patch(
"homeassistant.components.recorder.Recorder._schedule_compile_missing_statistics", "homeassistant.components.recorder.Recorder._schedule_compile_missing_statistics",
), patch( ), patch(
@ -394,12 +412,13 @@ async def test_schema_migrate(
assert migration_version == db_schema.SCHEMA_VERSION assert migration_version == db_schema.SCHEMA_VERSION
assert setup_run.called assert setup_run.called
assert recorder.util.async_migration_in_progress(hass) is not True assert recorder.util.async_migration_in_progress(hass) is not True
assert apply_update_mock.called
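The patched `_sometimes_failing_create_index` above fabricates a MySQL "lock wait timeout" (error code 1205) by attaching a plain `Exception(1205, ...)` as the `orig` of a SQLAlchemy `OperationalError`, so the recorder's retry path can read the driver error code from `exc.orig.args[0]`. A small self-contained sketch of that pattern (the retry check and the error-code set below are illustrative assumptions, not the recorder's actual implementation):

from sqlalchemy.exc import OperationalError

RETRYABLE_MYSQL_ERRORS = {1205, 1206, 1213}  # assumed set: lock timeouts and deadlocks

def is_retryable(exc: OperationalError) -> bool:
    # Driver error code is the first element of exc.orig.args for MySQL errors.
    return bool(exc.orig and exc.orig.args and exc.orig.args[0] in RETRYABLE_MYSQL_ERRORS)

err = OperationalError("statement", {}, [])
err.orig = Exception(1205, "Lock wait timeout exceeded; try restarting transaction")
assert is_retryable(err)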
def test_invalid_update(hass: HomeAssistant) -> None: def test_invalid_update(hass: HomeAssistant) -> None:
"""Test that an invalid new version raises an exception.""" """Test that an invalid new version raises an exception."""
with pytest.raises(ValueError): with pytest.raises(ValueError):
migration._apply_update(hass, Mock(), Mock(), -1, 0) migration._apply_update(Mock(), hass, Mock(), Mock(), -1, 0)
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@ -2,7 +2,7 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
import json import json
import sqlite3 import sqlite3
from unittest.mock import MagicMock, patch from unittest.mock import patch
import pytest import pytest
from sqlalchemy.exc import DatabaseError, OperationalError from sqlalchemy.exc import DatabaseError, OperationalError
@ -192,7 +192,7 @@ async def test_purge_old_states_encounters_temporary_mysql_error(
await async_wait_recording_done(hass) await async_wait_recording_done(hass)
mysql_exception = OperationalError("statement", {}, []) mysql_exception = OperationalError("statement", {}, [])
mysql_exception.orig = MagicMock(args=(1205, "retryable")) mysql_exception.orig = Exception(1205, "retryable")
with patch( with patch(
"homeassistant.components.recorder.util.time.sleep" "homeassistant.components.recorder.util.time.sleep"

View File

@ -8,7 +8,7 @@ import sys
from unittest.mock import ANY, DEFAULT, MagicMock, patch, sentinel from unittest.mock import ANY, DEFAULT, MagicMock, patch, sentinel
import pytest import pytest
from sqlalchemy import create_engine from sqlalchemy import create_engine, select
from sqlalchemy.exc import OperationalError from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -22,6 +22,10 @@ from homeassistant.components.recorder.models import (
) )
from homeassistant.components.recorder.statistics import ( from homeassistant.components.recorder.statistics import (
STATISTIC_UNIT_TO_UNIT_CONVERTER, STATISTIC_UNIT_TO_UNIT_CONVERTER,
_generate_get_metadata_stmt,
_generate_max_mean_min_statistic_in_sub_period_stmt,
_generate_statistics_at_time_stmt,
_generate_statistics_during_period_stmt,
_statistics_during_period_with_session, _statistics_during_period_with_session,
_update_or_add_metadata, _update_or_add_metadata,
async_add_external_statistics, async_add_external_statistics,
@ -1231,8 +1235,9 @@ def test_delete_duplicates_no_duplicates(
"""Test removal of duplicated statistics.""" """Test removal of duplicated statistics."""
hass = hass_recorder() hass = hass_recorder()
wait_recording_done(hass) wait_recording_done(hass)
instance = recorder.get_instance(hass)
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
delete_statistics_duplicates(hass, session) delete_statistics_duplicates(instance, hass, session)
assert "duplicated statistics rows" not in caplog.text assert "duplicated statistics rows" not in caplog.text
assert "Found non identical" not in caplog.text assert "Found non identical" not in caplog.text
assert "Found duplicated" not in caplog.text assert "Found duplicated" not in caplog.text
@ -1798,3 +1803,100 @@ def record_states(hass):
states[sns4].append(set_state(sns4, "20", attributes=sns4_attr)) states[sns4].append(set_state(sns4, "20", attributes=sns4_attr))
return zero, four, states return zero, four, states
def test_cache_key_for_generate_statistics_during_period_stmt():
"""Test cache key for _generate_statistics_during_period_stmt."""
columns = select(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts)
stmt = _generate_statistics_during_period_stmt(
columns, dt_util.utcnow(), dt_util.utcnow(), [0], StatisticsShortTerm, {}
)
cache_key_1 = stmt._generate_cache_key()
stmt2 = _generate_statistics_during_period_stmt(
columns, dt_util.utcnow(), dt_util.utcnow(), [0], StatisticsShortTerm, {}
)
cache_key_2 = stmt2._generate_cache_key()
assert cache_key_1 == cache_key_2
columns2 = select(
StatisticsShortTerm.metadata_id,
StatisticsShortTerm.start_ts,
StatisticsShortTerm.sum,
StatisticsShortTerm.mean,
)
stmt3 = _generate_statistics_during_period_stmt(
columns2,
dt_util.utcnow(),
dt_util.utcnow(),
[0],
StatisticsShortTerm,
{"max", "mean"},
)
cache_key_3 = stmt3._generate_cache_key()
assert cache_key_1 != cache_key_3
def test_cache_key_for_generate_get_metadata_stmt():
"""Test cache key for _generate_get_metadata_stmt."""
stmt_mean = _generate_get_metadata_stmt([0], "mean")
stmt_mean2 = _generate_get_metadata_stmt([1], "mean")
stmt_sum = _generate_get_metadata_stmt([0], "sum")
stmt_none = _generate_get_metadata_stmt()
assert stmt_mean._generate_cache_key() == stmt_mean2._generate_cache_key()
assert stmt_mean._generate_cache_key() != stmt_sum._generate_cache_key()
assert stmt_mean._generate_cache_key() != stmt_none._generate_cache_key()
def test_cache_key_for_generate_max_mean_min_statistic_in_sub_period_stmt():
"""Test cache key for _generate_max_mean_min_statistic_in_sub_period_stmt."""
columns = select(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts)
stmt = _generate_max_mean_min_statistic_in_sub_period_stmt(
columns,
dt_util.utcnow(),
dt_util.utcnow(),
StatisticsShortTerm,
[0],
)
cache_key_1 = stmt._generate_cache_key()
stmt2 = _generate_max_mean_min_statistic_in_sub_period_stmt(
columns,
dt_util.utcnow(),
dt_util.utcnow(),
StatisticsShortTerm,
[0],
)
cache_key_2 = stmt2._generate_cache_key()
assert cache_key_1 == cache_key_2
columns2 = select(
StatisticsShortTerm.metadata_id,
StatisticsShortTerm.start_ts,
StatisticsShortTerm.sum,
StatisticsShortTerm.mean,
)
stmt3 = _generate_max_mean_min_statistic_in_sub_period_stmt(
columns2,
dt_util.utcnow(),
dt_util.utcnow(),
StatisticsShortTerm,
[0],
)
cache_key_3 = stmt3._generate_cache_key()
assert cache_key_1 != cache_key_3
def test_cache_key_for_generate_statistics_at_time_stmt():
"""Test cache key for _generate_statistics_at_time_stmt."""
columns = select(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts)
stmt = _generate_statistics_at_time_stmt(columns, StatisticsShortTerm, {0}, 0.0)
cache_key_1 = stmt._generate_cache_key()
stmt2 = _generate_statistics_at_time_stmt(columns, StatisticsShortTerm, {0}, 0.0)
cache_key_2 = stmt2._generate_cache_key()
assert cache_key_1 == cache_key_2
columns2 = select(
StatisticsShortTerm.metadata_id,
StatisticsShortTerm.start_ts,
StatisticsShortTerm.sum,
StatisticsShortTerm.mean,
)
stmt3 = _generate_statistics_at_time_stmt(columns2, StatisticsShortTerm, {0}, 0.0)
cache_key_3 = stmt3._generate_cache_key()
assert cache_key_1 != cache_key_3
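These tests compare `_generate_cache_key()` results to confirm the generated lambda statements produce stable SQL-compilation cache keys: the same statement shape must hash identically regardless of the bound values, while selecting different columns must change the key. A standalone sketch of the same idea with a toy table (`_generate_cache_key()` is private SQLAlchemy API, used here purely for illustration):

from sqlalchemy import Column, Integer, MetaData, Table, lambda_stmt, select

demo = Table("demo", MetaData(), Column("id", Integer), Column("value", Integer))

def make_stmt(target_id: int):
    # Simple closure variables become bound parameters, so they do not
    # influence the statement's cache key.
    stmt = lambda_stmt(lambda: select(demo.c.id))
    stmt += lambda s: s.where(demo.c.id == target_id)
    return stmt

# Same shape, different parameter -> identical cache keys.
assert make_stmt(1)._generate_cache_key() == make_stmt(2)._generate_cache_key()

# Different selected columns -> different cache key.
other = lambda_stmt(lambda: select(demo.c.id, demo.c.value))
assert other._generate_cache_key() != make_stmt(1)._generate_cache_key()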

View File

@ -133,9 +133,7 @@ class MockNeighbour:
@pytest.fixture @pytest.fixture
def ndb() -> Mock: def ndb() -> Mock:
"""Prevent NDB poking the OS route tables.""" """Prevent NDB poking the OS route tables."""
with patch( with patch("pyroute2.NDB") as ndb, ndb() as instance:
"homeassistant.components.thread.diagnostics.NDB"
) as ndb, ndb() as instance:
instance.neighbours = [] instance.neighbours = []
instance.routes = [] instance.routes = []
yield instance yield instance

View File

@ -22,14 +22,14 @@ from homeassistant.data_entry_flow import FlowResultType
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
MOCK_USER_DATA = { MOCK_USER_DATA = {
"host": "1.1.1.1", "host": "https://fake.omada.host",
"verify_ssl": True, "verify_ssl": True,
"username": "test-username", "username": "test-username",
"password": "test-password", "password": "test-password",
} }
MOCK_ENTRY_DATA = { MOCK_ENTRY_DATA = {
"host": "1.1.1.1", "host": "https://fake.omada.host",
"verify_ssl": True, "verify_ssl": True,
"site": "SiteId", "site": "SiteId",
"username": "test-username", "username": "test-username",
@ -111,7 +111,7 @@ async def test_form_multiple_sites(hass: HomeAssistant) -> None:
assert result3["type"] == FlowResultType.CREATE_ENTRY assert result3["type"] == FlowResultType.CREATE_ENTRY
assert result3["title"] == "OC200 (Site 2)" assert result3["title"] == "OC200 (Site 2)"
assert result3["data"] == { assert result3["data"] == {
"host": "1.1.1.1", "host": "https://fake.omada.host",
"verify_ssl": True, "verify_ssl": True,
"site": "second", "site": "second",
"username": "test-username", "username": "test-username",
@ -272,7 +272,7 @@ async def test_async_step_reauth_success(hass: HomeAssistant) -> None:
mocked_validate.assert_called_once_with( mocked_validate.assert_called_once_with(
hass, hass,
{ {
"host": "1.1.1.1", "host": "https://fake.omada.host",
"verify_ssl": True, "verify_ssl": True,
"site": "SiteId", "site": "SiteId",
"username": "new_uname", "username": "new_uname",
@ -353,6 +353,64 @@ async def test_create_omada_client_parses_args(hass: HomeAssistant) -> None:
assert result is not None assert result is not None
mock_client.assert_called_once_with( mock_client.assert_called_once_with(
"1.1.1.1", "test-username", "test-password", "ws" "https://fake.omada.host", "test-username", "test-password", "ws"
) )
mock_clientsession.assert_called_once_with(hass, verify_ssl=True) mock_clientsession.assert_called_once_with(hass, verify_ssl=True)
async def test_create_omada_client_adds_missing_scheme(hass: HomeAssistant) -> None:
"""Test config arguments are passed to Omada client."""
with patch(
"homeassistant.components.tplink_omada.config_flow.OmadaClient", autospec=True
) as mock_client, patch(
"homeassistant.components.tplink_omada.config_flow.async_get_clientsession",
return_value="ws",
) as mock_clientsession:
result = await create_omada_client(
hass,
{
"host": "fake.omada.host",
"verify_ssl": True,
"username": "test-username",
"password": "test-password",
},
)
assert result is not None
mock_client.assert_called_once_with(
"https://fake.omada.host", "test-username", "test-password", "ws"
)
mock_clientsession.assert_called_once_with(hass, verify_ssl=True)
async def test_create_omada_client_with_ip_creates_clientsession(
hass: HomeAssistant,
) -> None:
"""Test config arguments are passed to Omada client."""
with patch(
"homeassistant.components.tplink_omada.config_flow.OmadaClient", autospec=True
) as mock_client, patch(
"homeassistant.components.tplink_omada.config_flow.CookieJar", autospec=True
) as mock_jar, patch(
"homeassistant.components.tplink_omada.config_flow.async_create_clientsession",
return_value="ws",
) as mock_create_clientsession:
result = await create_omada_client(
hass,
{
"host": "10.10.10.10",
"verify_ssl": True, # Verify is meaningless for IP
"username": "test-username",
"password": "test-password",
},
)
assert result is not None
mock_client.assert_called_once_with(
"https://10.10.10.10", "test-username", "test-password", "ws"
)
mock_create_clientsession.assert_called_once_with(
hass, cookie_jar=mock_jar.return_value
)
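The two new tests pin down the host handling expected of `create_omada_client`: a bare hostname gets an `https://` scheme prepended, and a bare IP gets a dedicated client session with its own cookie jar, since aiohttp's default jar refuses to store cookies for IP hosts. A rough sketch of the behaviour the assertions imply (the helper names below are illustrative, not the integration's real functions):

from ipaddress import ip_address
from urllib.parse import urlparse

def normalize_host(host: str) -> str:
    # Prepend https:// when the user typed just a hostname or IP.
    return host if "://" in host else f"https://{host}"

def needs_unsafe_cookie_jar(host: str) -> bool:
    # Controllers addressed by bare IP need a CookieJar(unsafe=True)-style session,
    # because the default aiohttp jar drops cookies from IP hosts.
    hostname = urlparse(normalize_host(host)).hostname or ""
    try:
        ip_address(hostname)
    except ValueError:
        return False
    return True

assert normalize_host("fake.omada.host") == "https://fake.omada.host"
assert needs_unsafe_cookie_jar("10.10.10.10")
assert not needs_unsafe_cookie_jar("fake.omada.host")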

View File

@ -1,9 +1,10 @@
"""Tests for ZHA integration init.""" """Tests for ZHA integration init."""
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, Mock, patch
import pytest import pytest
from zigpy.config import CONF_DEVICE, CONF_DEVICE_PATH from zigpy.config import CONF_DEVICE, CONF_DEVICE_PATH
from homeassistant.components.zha import async_setup_entry
from homeassistant.components.zha.core.const import ( from homeassistant.components.zha.core.const import (
CONF_BAUDRATE, CONF_BAUDRATE,
CONF_RADIO_TYPE, CONF_RADIO_TYPE,
@ -108,3 +109,41 @@ async def test_config_depreciation(hass: HomeAssistant, zha_config) -> None:
) as setup_mock: ) as setup_mock:
assert await async_setup_component(hass, DOMAIN, {DOMAIN: zha_config}) assert await async_setup_component(hass, DOMAIN, {DOMAIN: zha_config})
assert setup_mock.call_count == 1 assert setup_mock.call_count == 1
@pytest.mark.parametrize(
("path", "cleaned_path"),
[
("/dev/path1", "/dev/path1"),
("/dev/path1 ", "/dev/path1 "),
("socket://dev/path1 ", "socket://dev/path1"),
],
)
@patch("homeassistant.components.zha.setup_quirks", Mock(return_value=True))
@patch("homeassistant.components.zha.api.async_load_api", Mock(return_value=True))
async def test_setup_with_v3_spaces_in_uri(
hass: HomeAssistant, path: str, cleaned_path: str
) -> None:
"""Test migration of config entry from v3 with spaces after `socket://` URI."""
config_entry_v3 = MockConfigEntry(
domain=DOMAIN,
data={
CONF_RADIO_TYPE: DATA_RADIO_TYPE,
CONF_DEVICE: {CONF_DEVICE_PATH: path, CONF_BAUDRATE: 115200},
},
version=3,
)
config_entry_v3.add_to_hass(hass)
with patch(
"homeassistant.components.zha.ZHAGateway", return_value=AsyncMock()
) as mock_gateway:
mock_gateway.return_value.coordinator_ieee = "mock_ieee"
mock_gateway.return_value.radio_description = "mock_radio"
assert await async_setup_entry(hass, config_entry_v3)
hass.data[DOMAIN]["zha_gateway"] = mock_gateway.return_value
assert config_entry_v3.data[CONF_RADIO_TYPE] == DATA_RADIO_TYPE
assert config_entry_v3.data[CONF_DEVICE][CONF_DEVICE_PATH] == cleaned_path
assert config_entry_v3.version == 3
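The parametrized cases encode the expected clean-up rule: trailing whitespace is stripped only from `socket://` URIs, while plain device paths are forwarded untouched, presumably because a filesystem path is taken verbatim. A hypothetical helper expressing that rule (not the integration's actual migration code):

def clean_device_path(path: str) -> str:
    # Only socket:// URIs get stray whitespace stripped; real device
    # paths are left exactly as the user entered them.
    return path.strip() if path.startswith("socket://") else path

assert clean_device_path("/dev/path1") == "/dev/path1"
assert clean_device_path("/dev/path1 ") == "/dev/path1 "
assert clean_device_path("socket://dev/path1 ") == "socket://dev/path1"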