diff --git a/CODEOWNERS b/CODEOWNERS index cb559a7d7bb..e782f050926 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1100,6 +1100,7 @@ build.json @home-assistant/supervisor /homeassistant/components/smhi/ @gjohansson-ST /tests/components/smhi/ @gjohansson-ST /homeassistant/components/sms/ @ocalvo +/homeassistant/components/snapcast/ @luar123 /homeassistant/components/snooz/ @AustinBrunkhorst /tests/components/snooz/ @AustinBrunkhorst /homeassistant/components/solaredge/ @frenck diff --git a/homeassistant/components/airq/sensor.py b/homeassistant/components/airq/sensor.py index e46893e8d79..a47c308279d 100644 --- a/homeassistant/components/airq/sensor.py +++ b/homeassistant/components/airq/sensor.py @@ -68,7 +68,6 @@ SENSOR_TYPES: list[AirQEntityDescription] = [ AirQEntityDescription( key="co", name="CO", - device_class=SensorDeviceClass.CO, native_unit_of_measurement=CONCENTRATION_MILLIGRAMS_PER_CUBIC_METER, state_class=SensorStateClass.MEASUREMENT, value=lambda data: data.get("co"), @@ -289,7 +288,6 @@ SENSOR_TYPES: list[AirQEntityDescription] = [ AirQEntityDescription( key="tvoc", name="VOC", - device_class=SensorDeviceClass.VOLATILE_ORGANIC_COMPOUNDS, native_unit_of_measurement=CONCENTRATION_PARTS_PER_BILLION, state_class=SensorStateClass.MEASUREMENT, value=lambda data: data.get("tvoc"), @@ -297,7 +295,6 @@ SENSOR_TYPES: list[AirQEntityDescription] = [ AirQEntityDescription( key="tvoc_ionsc", name="VOC (Industrial)", - device_class=SensorDeviceClass.VOLATILE_ORGANIC_COMPOUNDS, native_unit_of_measurement=CONCENTRATION_PARTS_PER_BILLION, state_class=SensorStateClass.MEASUREMENT, value=lambda data: data.get("tvoc_ionsc"), diff --git a/homeassistant/components/api/__init__.py b/homeassistant/components/api/__init__.py index 56a07a6bcf0..5c0a60ecef7 100644 --- a/homeassistant/components/api/__init__.py +++ b/homeassistant/components/api/__init__.py @@ -1,5 +1,6 @@ """Rest API for Home Assistant.""" import asyncio +from functools import lru_cache from http import HTTPStatus import logging @@ -350,6 +351,12 @@ class APIComponentsView(HomeAssistantView): return self.json(request.app["hass"].config.components) +@lru_cache +def _cached_template(template_str: str, hass: ha.HomeAssistant) -> template.Template: + """Return a cached template.""" + return template.Template(template_str, hass) + + class APITemplateView(HomeAssistantView): """View to handle Template requests.""" @@ -362,7 +369,7 @@ class APITemplateView(HomeAssistantView): raise Unauthorized() try: data = await request.json() - tpl = template.Template(data["template"], request.app["hass"]) + tpl = _cached_template(data["template"], request.app["hass"]) return tpl.async_render(variables=data.get("variables"), parse_result=False) except (ValueError, TemplateError) as ex: return self.json_message( diff --git a/homeassistant/components/bluetooth/base_scanner.py b/homeassistant/components/bluetooth/base_scanner.py index 00cc9fff0fe..903f14a9227 100644 --- a/homeassistant/components/bluetooth/base_scanner.py +++ b/homeassistant/components/bluetooth/base_scanner.py @@ -227,20 +227,21 @@ class BaseHaRemoteScanner(BaseHaScanner): self.hass, self._async_expire_devices, timedelta(seconds=30) ) cancel_stop = self.hass.bus.async_listen( - EVENT_HOMEASSISTANT_STOP, self._save_history + EVENT_HOMEASSISTANT_STOP, self._async_save_history ) self._async_setup_scanner_watchdog() @hass_callback def _cancel() -> None: - self._save_history() + self._async_save_history() self._async_stop_scanner_watchdog() cancel_track() cancel_stop() return _cancel - def _save_history(self, event: Event | None = None) -> None: + @hass_callback + def _async_save_history(self, event: Event | None = None) -> None: """Save the history.""" self._storage.async_set_advertisement_history( self.source, @@ -252,6 +253,7 @@ class BaseHaRemoteScanner(BaseHaScanner): ), ) + @hass_callback def _async_expire_devices(self, _datetime: datetime.datetime) -> None: """Expire old devices.""" now = MONOTONIC_TIME() diff --git a/homeassistant/components/esphome/manifest.json b/homeassistant/components/esphome/manifest.json index fde8c26ba5e..e8e4e4876f0 100644 --- a/homeassistant/components/esphome/manifest.json +++ b/homeassistant/components/esphome/manifest.json @@ -14,6 +14,6 @@ "integration_type": "device", "iot_class": "local_push", "loggers": ["aioesphomeapi", "noiseprotocol"], - "requirements": ["aioesphomeapi==13.4.1", "esphome-dashboard-api==1.2.3"], + "requirements": ["aioesphomeapi==13.4.2", "esphome-dashboard-api==1.2.3"], "zeroconf": ["_esphomelib._tcp.local."] } diff --git a/homeassistant/components/fibaro/manifest.json b/homeassistant/components/fibaro/manifest.json index 6522d3b06ed..6dd2104bd9b 100644 --- a/homeassistant/components/fibaro/manifest.json +++ b/homeassistant/components/fibaro/manifest.json @@ -7,5 +7,5 @@ "integration_type": "hub", "iot_class": "local_push", "loggers": ["pyfibaro"], - "requirements": ["pyfibaro==0.6.8"] + "requirements": ["pyfibaro==0.6.9"] } diff --git a/homeassistant/components/frontend/manifest.json b/homeassistant/components/frontend/manifest.json index c09f2d501c6..da68e48cc08 100644 --- a/homeassistant/components/frontend/manifest.json +++ b/homeassistant/components/frontend/manifest.json @@ -20,5 +20,5 @@ "documentation": "https://www.home-assistant.io/integrations/frontend", "integration_type": "system", "quality_scale": "internal", - "requirements": ["home-assistant-frontend==20230302.0"] + "requirements": ["home-assistant-frontend==20230306.0"] } diff --git a/homeassistant/components/geniushub/climate.py b/homeassistant/components/geniushub/climate.py index 21ef2809360..c2b32582cef 100644 --- a/homeassistant/components/geniushub/climate.py +++ b/homeassistant/components/geniushub/climate.py @@ -41,7 +41,7 @@ async def async_setup_platform( [ GeniusClimateZone(broker, z) for z in broker.client.zone_objs - if z.data["type"] in GH_ZONES + if z.data.get("type") in GH_ZONES ] ) diff --git a/homeassistant/components/geniushub/switch.py b/homeassistant/components/geniushub/switch.py index cf29d0ea802..79ba418d509 100644 --- a/homeassistant/components/geniushub/switch.py +++ b/homeassistant/components/geniushub/switch.py @@ -42,7 +42,7 @@ async def async_setup_platform( [ GeniusSwitch(broker, z) for z in broker.client.zone_objs - if z.data["type"] == GH_ON_OFF_ZONE + if z.data.get("type") == GH_ON_OFF_ZONE ] ) diff --git a/homeassistant/components/geniushub/water_heater.py b/homeassistant/components/geniushub/water_heater.py index ea8b1a43961..f8cf7288e57 100644 --- a/homeassistant/components/geniushub/water_heater.py +++ b/homeassistant/components/geniushub/water_heater.py @@ -48,7 +48,7 @@ async def async_setup_platform( [ GeniusWaterHeater(broker, z) for z in broker.client.zone_objs - if z.data["type"] in GH_HEATERS + if z.data.get("type") in GH_HEATERS ] ) diff --git a/homeassistant/components/hassio/const.py b/homeassistant/components/hassio/const.py index 64ef7a718a5..2710e146540 100644 --- a/homeassistant/components/hassio/const.py +++ b/homeassistant/components/hassio/const.py @@ -36,6 +36,7 @@ X_AUTH_TOKEN = "X-Supervisor-Token" X_INGRESS_PATH = "X-Ingress-Path" X_HASS_USER_ID = "X-Hass-User-ID" X_HASS_IS_ADMIN = "X-Hass-Is-Admin" +X_HASS_SOURCE = "X-Hass-Source" WS_TYPE = "type" WS_ID = "id" diff --git a/homeassistant/components/hassio/handler.py b/homeassistant/components/hassio/handler.py index 0d923075bf7..762df4f79ca 100644 --- a/homeassistant/components/hassio/handler.py +++ b/homeassistant/components/hassio/handler.py @@ -17,7 +17,7 @@ from homeassistant.const import SERVER_PORT from homeassistant.core import HomeAssistant from homeassistant.loader import bind_hass -from .const import ATTR_DISCOVERY, DOMAIN +from .const import ATTR_DISCOVERY, DOMAIN, X_HASS_SOURCE _LOGGER = logging.getLogger(__name__) @@ -445,6 +445,8 @@ class HassIO: payload=None, timeout=10, return_text=False, + *, + source="core.handler", ): """Send API command to Hass.io. @@ -458,7 +460,8 @@ class HassIO: headers={ aiohttp.hdrs.AUTHORIZATION: ( f"Bearer {os.environ.get('SUPERVISOR_TOKEN', '')}" - ) + ), + X_HASS_SOURCE: source, }, timeout=aiohttp.ClientTimeout(total=timeout), ) diff --git a/homeassistant/components/hassio/http.py b/homeassistant/components/hassio/http.py index 2b7145bdcaa..8a8583a7daf 100644 --- a/homeassistant/components/hassio/http.py +++ b/homeassistant/components/hassio/http.py @@ -6,6 +6,7 @@ from http import HTTPStatus import logging import os import re +from urllib.parse import quote, unquote import aiohttp from aiohttp import web @@ -19,13 +20,16 @@ from aiohttp.hdrs import ( TRANSFER_ENCODING, ) from aiohttp.web_exceptions import HTTPBadGateway -from multidict import istr -from homeassistant.components.http import KEY_AUTHENTICATED, HomeAssistantView +from homeassistant.components.http import ( + KEY_AUTHENTICATED, + KEY_HASS_USER, + HomeAssistantView, +) from homeassistant.components.onboarding import async_is_onboarded from homeassistant.core import HomeAssistant -from .const import X_HASS_IS_ADMIN, X_HASS_USER_ID +from .const import X_HASS_SOURCE _LOGGER = logging.getLogger(__name__) @@ -34,23 +38,53 @@ MAX_UPLOAD_SIZE = 1024 * 1024 * 1024 # pylint: disable=implicit-str-concat NO_TIMEOUT = re.compile( r"^(?:" - r"|homeassistant/update" - r"|hassos/update" - r"|hassos/update/cli" - r"|supervisor/update" - r"|addons/[^/]+/(?:update|install|rebuild)" r"|backups/.+/full" r"|backups/.+/partial" r"|backups/[^/]+/(?:upload|download)" r")$" ) -NO_AUTH_ONBOARDING = re.compile(r"^(?:" r"|supervisor/logs" r"|backups/[^/]+/.+" r")$") +# fmt: off +# Onboarding can upload backups and restore it +PATHS_NOT_ONBOARDED = re.compile( + r"^(?:" + r"|backups/[a-f0-9]{8}(/info|/new/upload|/download|/restore/full|/restore/partial)?" + r"|backups/new/upload" + r")$" +) -NO_AUTH = re.compile(r"^(?:" r"|app/.*" r"|[store\/]*addons/[^/]+/(logo|icon)" r")$") +# Authenticated users manage backups + download logs +PATHS_ADMIN = re.compile( + r"^(?:" + r"|backups/[a-f0-9]{8}(/info|/download|/restore/full|/restore/partial)?" + r"|backups/new/upload" + r"|audio/logs" + r"|cli/logs" + r"|core/logs" + r"|dns/logs" + r"|host/logs" + r"|multicast/logs" + r"|observer/logs" + r"|supervisor/logs" + r"|addons/[^/]+/logs" + r")$" +) -NO_STORE = re.compile(r"^(?:" r"|app/entrypoint.js" r")$") +# Unauthenticated requests come in for Supervisor panel + add-on images +PATHS_NO_AUTH = re.compile( + r"^(?:" + r"|app/.*" + r"|(store/)?addons/[^/]+/(logo|icon)" + r")$" +) + +NO_STORE = re.compile( + r"^(?:" + r"|app/entrypoint.js" + r")$" +) # pylint: enable=implicit-str-concat +# fmt: on class HassIOView(HomeAssistantView): @@ -65,38 +99,66 @@ class HassIOView(HomeAssistantView): self._host = host self._websession = websession - async def _handle( - self, request: web.Request, path: str - ) -> web.Response | web.StreamResponse: - """Route data to Hass.io.""" - hass = request.app["hass"] - if _need_auth(hass, path) and not request[KEY_AUTHENTICATED]: - return web.Response(status=HTTPStatus.UNAUTHORIZED) - - return await self._command_proxy(path, request) - - delete = _handle - get = _handle - post = _handle - - async def _command_proxy( - self, path: str, request: web.Request - ) -> web.StreamResponse: + async def _handle(self, request: web.Request, path: str) -> web.StreamResponse: """Return a client request with proxy origin for Hass.io supervisor. - This method is a coroutine. + Use cases: + - Onboarding allows restoring backups + - Load Supervisor panel and add-on logo unauthenticated + - User upload/restore backups """ - headers = _init_header(request) - if path == "backups/new/upload": - # We need to reuse the full content type that includes the boundary - headers[ - CONTENT_TYPE - ] = request._stored_content_type # pylint: disable=protected-access + # No bullshit + if path != unquote(path): + return web.Response(status=HTTPStatus.BAD_REQUEST) + + hass: HomeAssistant = request.app["hass"] + is_admin = request[KEY_AUTHENTICATED] and request[KEY_HASS_USER].is_admin + authorized = is_admin + + if is_admin: + allowed_paths = PATHS_ADMIN + + elif not async_is_onboarded(hass): + allowed_paths = PATHS_NOT_ONBOARDED + + # During onboarding we need the user to manage backups + authorized = True + + else: + # Either unauthenticated or not an admin + allowed_paths = PATHS_NO_AUTH + + no_auth_path = PATHS_NO_AUTH.match(path) + headers = { + X_HASS_SOURCE: "core.http", + } + + if no_auth_path: + if request.method != "GET": + return web.Response(status=HTTPStatus.METHOD_NOT_ALLOWED) + + else: + if not allowed_paths.match(path): + return web.Response(status=HTTPStatus.UNAUTHORIZED) + + if authorized: + headers[ + AUTHORIZATION + ] = f"Bearer {os.environ.get('SUPERVISOR_TOKEN', '')}" + + if request.method == "POST": + headers[CONTENT_TYPE] = request.content_type + # _stored_content_type is only computed once `content_type` is accessed + if path == "backups/new/upload": + # We need to reuse the full content type that includes the boundary + headers[ + CONTENT_TYPE + ] = request._stored_content_type # pylint: disable=protected-access try: client = await self._websession.request( method=request.method, - url=f"http://{self._host}/{path}", + url=f"http://{self._host}/{quote(path)}", params=request.query, data=request.content, headers=headers, @@ -123,20 +185,8 @@ class HassIOView(HomeAssistantView): raise HTTPBadGateway() - -def _init_header(request: web.Request) -> dict[istr, str]: - """Create initial header.""" - headers = { - AUTHORIZATION: f"Bearer {os.environ.get('SUPERVISOR_TOKEN', '')}", - CONTENT_TYPE: request.content_type, - } - - # Add user data - if request.get("hass_user") is not None: - headers[istr(X_HASS_USER_ID)] = request["hass_user"].id - headers[istr(X_HASS_IS_ADMIN)] = str(int(request["hass_user"].is_admin)) - - return headers + get = _handle + post = _handle def _response_header(response: aiohttp.ClientResponse, path: str) -> dict[str, str]: @@ -164,12 +214,3 @@ def _get_timeout(path: str) -> ClientTimeout: if NO_TIMEOUT.match(path): return ClientTimeout(connect=10, total=None) return ClientTimeout(connect=10, total=300) - - -def _need_auth(hass: HomeAssistant, path: str) -> bool: - """Return if a path need authentication.""" - if not async_is_onboarded(hass) and NO_AUTH_ONBOARDING.match(path): - return False - if NO_AUTH.match(path): - return False - return True diff --git a/homeassistant/components/hassio/ingress.py b/homeassistant/components/hassio/ingress.py index dceff75bca8..334c7cf719c 100644 --- a/homeassistant/components/hassio/ingress.py +++ b/homeassistant/components/hassio/ingress.py @@ -3,20 +3,22 @@ from __future__ import annotations import asyncio from collections.abc import Iterable +from functools import lru_cache from ipaddress import ip_address import logging -import os +from urllib.parse import quote import aiohttp from aiohttp import ClientTimeout, hdrs, web from aiohttp.web_exceptions import HTTPBadGateway, HTTPBadRequest from multidict import CIMultiDict +from yarl import URL from homeassistant.components.http import HomeAssistantView from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.aiohttp_client import async_get_clientsession -from .const import X_AUTH_TOKEN, X_INGRESS_PATH +from .const import X_HASS_SOURCE, X_INGRESS_PATH _LOGGER = logging.getLogger(__name__) @@ -42,9 +44,19 @@ class HassIOIngress(HomeAssistantView): self._host = host self._websession = websession + @lru_cache def _create_url(self, token: str, path: str) -> str: """Create URL to service.""" - return f"http://{self._host}/ingress/{token}/{path}" + base_path = f"/ingress/{token}/" + url = f"http://{self._host}{base_path}{quote(path)}" + + try: + if not URL(url).path.startswith(base_path): + raise HTTPBadRequest() + except ValueError as err: + raise HTTPBadRequest() from err + + return url async def _handle( self, request: web.Request, token: str, path: str @@ -185,10 +197,8 @@ def _init_header(request: web.Request, token: str) -> CIMultiDict | dict[str, st continue headers[name] = value - # Inject token / cleanup later on Supervisor - headers[X_AUTH_TOKEN] = os.environ.get("SUPERVISOR_TOKEN", "") - # Ingress information + headers[X_HASS_SOURCE] = "core.ingress" headers[X_INGRESS_PATH] = f"/api/hassio_ingress/{token}" # Set X-Forwarded-For diff --git a/homeassistant/components/hassio/websocket_api.py b/homeassistant/components/hassio/websocket_api.py index 3670d5ca1fd..8a9a145f2d6 100644 --- a/homeassistant/components/hassio/websocket_api.py +++ b/homeassistant/components/hassio/websocket_api.py @@ -116,6 +116,7 @@ async def websocket_supervisor_api( method=msg[ATTR_METHOD], timeout=msg.get(ATTR_TIMEOUT, 10), payload=msg.get(ATTR_DATA, {}), + source="core.websocket_api", ) if result.get(ATTR_RESULT) == "error": diff --git a/homeassistant/components/iaqualink/__init__.py b/homeassistant/components/iaqualink/__init__.py index cbdf909001a..225953035a2 100644 --- a/homeassistant/components/iaqualink/__init__.py +++ b/homeassistant/components/iaqualink/__init__.py @@ -153,6 +153,7 @@ async def async_setup_entry( # noqa: C901 system.serial, svc_exception, ) + await system.aqualink.close() else: cur = system.online if cur and not prev: diff --git a/homeassistant/components/iaqualink/utils.py b/homeassistant/components/iaqualink/utils.py index b047af5869c..87bc863a7f8 100644 --- a/homeassistant/components/iaqualink/utils.py +++ b/homeassistant/components/iaqualink/utils.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Awaitable +import httpx from iaqualink.exception import AqualinkServiceException from homeassistant.exceptions import HomeAssistantError @@ -12,5 +13,5 @@ async def await_or_reraise(awaitable: Awaitable) -> None: """Execute API call while catching service exceptions.""" try: await awaitable - except AqualinkServiceException as svc_exception: + except (AqualinkServiceException, httpx.HTTPError) as svc_exception: raise HomeAssistantError(f"Aqualink error: {svc_exception}") from svc_exception diff --git a/homeassistant/components/insteon/manifest.json b/homeassistant/components/insteon/manifest.json index 40316a6ba3e..743e7e4fa19 100644 --- a/homeassistant/components/insteon/manifest.json +++ b/homeassistant/components/insteon/manifest.json @@ -17,8 +17,8 @@ "iot_class": "local_push", "loggers": ["pyinsteon", "pypubsub"], "requirements": [ - "pyinsteon==1.3.3", - "insteon-frontend-home-assistant==0.3.2" + "pyinsteon==1.3.4", + "insteon-frontend-home-assistant==0.3.3" ], "usb": [ { diff --git a/homeassistant/components/insteon/utils.py b/homeassistant/components/insteon/utils.py index c5dbba9c25b..0df823e49b1 100644 --- a/homeassistant/components/insteon/utils.py +++ b/homeassistant/components/insteon/utils.py @@ -1,11 +1,13 @@ """Utilities used by insteon component.""" import asyncio +from collections.abc import Callable import logging from pyinsteon import devices from pyinsteon.address import Address from pyinsteon.constants import ALDBStatus, DeviceAction -from pyinsteon.events import OFF_EVENT, OFF_FAST_EVENT, ON_EVENT, ON_FAST_EVENT +from pyinsteon.device_types.device_base import Device +from pyinsteon.events import OFF_EVENT, OFF_FAST_EVENT, ON_EVENT, ON_FAST_EVENT, Event from pyinsteon.managers.link_manager import ( async_enter_linking_mode, async_enter_unlinking_mode, @@ -27,7 +29,7 @@ from homeassistant.const import ( CONF_PLATFORM, ENTITY_MATCH_ALL, ) -from homeassistant.core import ServiceCall, callback +from homeassistant.core import HomeAssistant, ServiceCall, callback from homeassistant.helpers import device_registry as dr from homeassistant.helpers.dispatcher import ( async_dispatcher_connect, @@ -89,49 +91,52 @@ from .schemas import ( _LOGGER = logging.getLogger(__name__) -def add_on_off_event_device(hass, device): +def _register_event(event: Event, listener: Callable) -> None: + """Register the events raised by a device.""" + _LOGGER.debug( + "Registering on/off event for %s %d %s", + str(event.address), + event.group, + event.name, + ) + event.subscribe(listener, force_strong_ref=True) + + +def add_on_off_event_device(hass: HomeAssistant, device: Device) -> None: """Register an Insteon device as an on/off event device.""" @callback - def async_fire_group_on_off_event(name, address, group, button): + def async_fire_group_on_off_event( + name: str, address: Address, group: int, button: str + ): # Firing an event when a button is pressed. if button and button[-2] == "_": button_id = button[-1].lower() else: button_id = None - schema = {CONF_ADDRESS: address} + schema = {CONF_ADDRESS: address, "group": group} if button_id: schema[EVENT_CONF_BUTTON] = button_id if name == ON_EVENT: event = EVENT_GROUP_ON - if name == OFF_EVENT: + elif name == OFF_EVENT: event = EVENT_GROUP_OFF - if name == ON_FAST_EVENT: + elif name == ON_FAST_EVENT: event = EVENT_GROUP_ON_FAST - if name == OFF_FAST_EVENT: + elif name == OFF_FAST_EVENT: event = EVENT_GROUP_OFF_FAST + else: + event = f"insteon.{name}" _LOGGER.debug("Firing event %s with %s", event, schema) hass.bus.async_fire(event, schema) - for group in device.events: - if isinstance(group, int): - for event in device.events[group]: - if event in [ - OFF_EVENT, - ON_EVENT, - OFF_FAST_EVENT, - ON_FAST_EVENT, - ]: - _LOGGER.debug( - "Registering on/off event for %s %d %s", - str(device.address), - group, - event, - ) - device.events[group][event].subscribe( - async_fire_group_on_off_event, force_strong_ref=True - ) + for name_or_group, event in device.events.items(): + if isinstance(name_or_group, int): + for _, event in device.events[name_or_group].items(): + _register_event(event, async_fire_group_on_off_event) + else: + _register_event(event, async_fire_group_on_off_event) def register_new_device_callback(hass): diff --git a/homeassistant/components/konnected/__init__.py b/homeassistant/components/konnected/__init__.py index bd629d53fc6..119c7c946a5 100644 --- a/homeassistant/components/konnected/__init__.py +++ b/homeassistant/components/konnected/__init__.py @@ -84,7 +84,7 @@ def ensure_zone(value): if value is None: raise vol.Invalid("zone value is None") - if str(value) not in ZONES is None: + if str(value) not in ZONES: raise vol.Invalid("zone not valid") return str(value) diff --git a/homeassistant/components/litterrobot/sensor.py b/homeassistant/components/litterrobot/sensor.py index 4c63f1c3fa8..e7aed366fa3 100644 --- a/homeassistant/components/litterrobot/sensor.py +++ b/homeassistant/components/litterrobot/sensor.py @@ -140,7 +140,7 @@ ROBOT_SENSOR_MAP: dict[type[Robot], list[RobotSensorEntityDescription]] = { name="Pet weight", native_unit_of_measurement=UnitOfMass.POUNDS, device_class=SensorDeviceClass.WEIGHT, - state_class=SensorStateClass.TOTAL, + state_class=SensorStateClass.MEASUREMENT, ), ], FeederRobot: [ diff --git a/homeassistant/components/mobile_app/webhook.py b/homeassistant/components/mobile_app/webhook.py index c7fc375008a..90e244aaf06 100644 --- a/homeassistant/components/mobile_app/webhook.py +++ b/homeassistant/components/mobile_app/webhook.py @@ -4,7 +4,7 @@ from __future__ import annotations import asyncio from collections.abc import Callable, Coroutine from contextlib import suppress -from functools import wraps +from functools import lru_cache, wraps from http import HTTPStatus import logging import secrets @@ -365,6 +365,12 @@ async def webhook_stream_camera( return webhook_response(resp, registration=config_entry.data) +@lru_cache +def _cached_template(template_str: str, hass: HomeAssistant) -> template.Template: + """Return a cached template.""" + return template.Template(template_str, hass) + + @WEBHOOK_COMMANDS.register("render_template") @validate_schema( { @@ -381,7 +387,7 @@ async def webhook_render_template( resp = {} for key, item in data.items(): try: - tpl = template.Template(item[ATTR_TEMPLATE], hass) + tpl = _cached_template(item[ATTR_TEMPLATE], hass) resp[key] = tpl.async_render(item.get(ATTR_TEMPLATE_VARIABLES)) except TemplateError as ex: resp[key] = {"error": str(ex)} diff --git a/homeassistant/components/recorder/history.py b/homeassistant/components/recorder/history.py index fb1a55cebfb..b67790f9a42 100644 --- a/homeassistant/components/recorder/history.py +++ b/homeassistant/components/recorder/history.py @@ -17,7 +17,6 @@ from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session from sqlalchemy.sql.expression import literal from sqlalchemy.sql.lambdas import StatementLambdaElement -from sqlalchemy.sql.selectable import Subquery from homeassistant.const import COMPRESSED_STATE_LAST_UPDATED, COMPRESSED_STATE_STATE from homeassistant.core import HomeAssistant, State, split_entity_id @@ -592,48 +591,6 @@ def get_last_state_changes( ) -def _generate_most_recent_states_for_entities_by_date( - schema_version: int, - run_start: datetime, - utc_point_in_time: datetime, - entity_ids: list[str], -) -> Subquery: - """Generate the sub query for the most recent states for specific entities by date.""" - if schema_version >= 31: - run_start_ts = process_timestamp(run_start).timestamp() - utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time) - return ( - select( - States.entity_id.label("max_entity_id"), - # https://github.com/sqlalchemy/sqlalchemy/issues/9189 - # pylint: disable-next=not-callable - func.max(States.last_updated_ts).label("max_last_updated"), - ) - .filter( - (States.last_updated_ts >= run_start_ts) - & (States.last_updated_ts < utc_point_in_time_ts) - ) - .filter(States.entity_id.in_(entity_ids)) - .group_by(States.entity_id) - .subquery() - ) - return ( - select( - States.entity_id.label("max_entity_id"), - # https://github.com/sqlalchemy/sqlalchemy/issues/9189 - # pylint: disable-next=not-callable - func.max(States.last_updated).label("max_last_updated"), - ) - .filter( - (States.last_updated >= run_start) - & (States.last_updated < utc_point_in_time) - ) - .filter(States.entity_id.in_(entity_ids)) - .group_by(States.entity_id) - .subquery() - ) - - def _get_states_for_entities_stmt( schema_version: int, run_start: datetime, @@ -645,16 +602,29 @@ def _get_states_for_entities_stmt( stmt, join_attributes = lambda_stmt_and_join_attributes( schema_version, no_attributes, include_last_changed=True ) - most_recent_states_for_entities_by_date = ( - _generate_most_recent_states_for_entities_by_date( - schema_version, run_start, utc_point_in_time, entity_ids - ) - ) # We got an include-list of entities, accelerate the query by filtering already # in the inner query. if schema_version >= 31: + run_start_ts = process_timestamp(run_start).timestamp() + utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time) stmt += lambda q: q.join( - most_recent_states_for_entities_by_date, + ( + most_recent_states_for_entities_by_date := ( + select( + States.entity_id.label("max_entity_id"), + # https://github.com/sqlalchemy/sqlalchemy/issues/9189 + # pylint: disable-next=not-callable + func.max(States.last_updated_ts).label("max_last_updated"), + ) + .filter( + (States.last_updated_ts >= run_start_ts) + & (States.last_updated_ts < utc_point_in_time_ts) + ) + .filter(States.entity_id.in_(entity_ids)) + .group_by(States.entity_id) + .subquery() + ) + ), and_( States.entity_id == most_recent_states_for_entities_by_date.c.max_entity_id, @@ -664,7 +634,21 @@ def _get_states_for_entities_stmt( ) else: stmt += lambda q: q.join( - most_recent_states_for_entities_by_date, + ( + most_recent_states_for_entities_by_date := select( + States.entity_id.label("max_entity_id"), + # https://github.com/sqlalchemy/sqlalchemy/issues/9189 + # pylint: disable-next=not-callable + func.max(States.last_updated).label("max_last_updated"), + ) + .filter( + (States.last_updated >= run_start) + & (States.last_updated < utc_point_in_time) + ) + .filter(States.entity_id.in_(entity_ids)) + .group_by(States.entity_id) + .subquery() + ), and_( States.entity_id == most_recent_states_for_entities_by_date.c.max_entity_id, @@ -679,45 +663,6 @@ def _get_states_for_entities_stmt( return stmt -def _generate_most_recent_states_by_date( - schema_version: int, - run_start: datetime, - utc_point_in_time: datetime, -) -> Subquery: - """Generate the sub query for the most recent states by date.""" - if schema_version >= 31: - run_start_ts = process_timestamp(run_start).timestamp() - utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time) - return ( - select( - States.entity_id.label("max_entity_id"), - # https://github.com/sqlalchemy/sqlalchemy/issues/9189 - # pylint: disable-next=not-callable - func.max(States.last_updated_ts).label("max_last_updated"), - ) - .filter( - (States.last_updated_ts >= run_start_ts) - & (States.last_updated_ts < utc_point_in_time_ts) - ) - .group_by(States.entity_id) - .subquery() - ) - return ( - select( - States.entity_id.label("max_entity_id"), - # https://github.com/sqlalchemy/sqlalchemy/issues/9189 - # pylint: disable-next=not-callable - func.max(States.last_updated).label("max_last_updated"), - ) - .filter( - (States.last_updated >= run_start) - & (States.last_updated < utc_point_in_time) - ) - .group_by(States.entity_id) - .subquery() - ) - - def _get_states_for_all_stmt( schema_version: int, run_start: datetime, @@ -733,12 +678,26 @@ def _get_states_for_all_stmt( # query, then filter out unwanted domains as well as applying the custom filter. # This filtering can't be done in the inner query because the domain column is # not indexed and we can't control what's in the custom filter. - most_recent_states_by_date = _generate_most_recent_states_by_date( - schema_version, run_start, utc_point_in_time - ) if schema_version >= 31: + run_start_ts = process_timestamp(run_start).timestamp() + utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time) stmt += lambda q: q.join( - most_recent_states_by_date, + ( + most_recent_states_by_date := ( + select( + States.entity_id.label("max_entity_id"), + # https://github.com/sqlalchemy/sqlalchemy/issues/9189 + # pylint: disable-next=not-callable + func.max(States.last_updated_ts).label("max_last_updated"), + ) + .filter( + (States.last_updated_ts >= run_start_ts) + & (States.last_updated_ts < utc_point_in_time_ts) + ) + .group_by(States.entity_id) + .subquery() + ) + ), and_( States.entity_id == most_recent_states_by_date.c.max_entity_id, States.last_updated_ts == most_recent_states_by_date.c.max_last_updated, @@ -746,7 +705,22 @@ def _get_states_for_all_stmt( ) else: stmt += lambda q: q.join( - most_recent_states_by_date, + ( + most_recent_states_by_date := ( + select( + States.entity_id.label("max_entity_id"), + # https://github.com/sqlalchemy/sqlalchemy/issues/9189 + # pylint: disable-next=not-callable + func.max(States.last_updated).label("max_last_updated"), + ) + .filter( + (States.last_updated >= run_start) + & (States.last_updated < utc_point_in_time) + ) + .group_by(States.entity_id) + .subquery() + ) + ), and_( States.entity_id == most_recent_states_by_date.c.max_entity_id, States.last_updated == most_recent_states_by_date.c.max_last_updated, diff --git a/homeassistant/components/recorder/manifest.json b/homeassistant/components/recorder/manifest.json index f40f866808c..ed885127b1b 100644 --- a/homeassistant/components/recorder/manifest.json +++ b/homeassistant/components/recorder/manifest.json @@ -6,5 +6,5 @@ "integration_type": "system", "iot_class": "local_push", "quality_scale": "internal", - "requirements": ["sqlalchemy==2.0.4", "fnvhash==0.1.0"] + "requirements": ["sqlalchemy==2.0.5.post1", "fnvhash==0.1.0"] } diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index 431bc78ba80..0b8fe9243ba 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -50,7 +50,7 @@ from .tasks import ( PostSchemaMigrationTask, StatisticsTimestampMigrationCleanupTask, ) -from .util import session_scope +from .util import database_job_retry_wrapper, session_scope if TYPE_CHECKING: from . import Recorder @@ -158,7 +158,9 @@ def migrate_schema( hass.add_job(instance.async_set_db_ready) new_version = version + 1 _LOGGER.info("Upgrading recorder db schema to version %s", new_version) - _apply_update(hass, engine, session_maker, new_version, current_version) + _apply_update( + instance, hass, engine, session_maker, new_version, current_version + ) with session_scope(session=session_maker()) as session: session.add(SchemaChanges(schema_version=new_version)) @@ -508,7 +510,9 @@ def _drop_foreign_key_constraints( ) +@database_job_retry_wrapper("Apply migration update", 10) def _apply_update( # noqa: C901 + instance: Recorder, hass: HomeAssistant, engine: Engine, session_maker: Callable[[], Session], @@ -922,7 +926,7 @@ def _apply_update( # noqa: C901 # There may be duplicated statistics entries, delete duplicates # and try again with session_scope(session=session_maker()) as session: - delete_statistics_duplicates(hass, session) + delete_statistics_duplicates(instance, hass, session) _migrate_statistics_columns_to_timestamp(session_maker, engine) # Log at error level to ensure the user sees this message in the log # since we logged the error above. @@ -965,7 +969,7 @@ def post_schema_migration( # since they are no longer used and take up a significant amount of space. assert instance.event_session is not None assert instance.engine is not None - _wipe_old_string_time_columns(instance.engine, instance.event_session) + _wipe_old_string_time_columns(instance, instance.engine, instance.event_session) if old_version < 35 <= new_version: # In version 34 we migrated all the created, start, and last_reset # columns to be timestamps. In version 34 we need to wipe the old columns @@ -978,7 +982,10 @@ def _wipe_old_string_statistics_columns(instance: Recorder) -> None: instance.queue_task(StatisticsTimestampMigrationCleanupTask()) -def _wipe_old_string_time_columns(engine: Engine, session: Session) -> None: +@database_job_retry_wrapper("Wipe old string time columns", 3) +def _wipe_old_string_time_columns( + instance: Recorder, engine: Engine, session: Session +) -> None: """Wipe old string time columns to save space.""" # Wipe Events.time_fired since its been replaced by Events.time_fired_ts # Wipe States.last_updated since its been replaced by States.last_updated_ts @@ -1162,7 +1169,7 @@ def _migrate_statistics_columns_to_timestamp( "last_reset_ts=" "UNIX_TIMESTAMP(last_reset) " "where start_ts is NULL " - "LIMIT 250000;" + "LIMIT 100000;" ) ) elif engine.dialect.name == SupportedDialect.POSTGRESQL: @@ -1180,7 +1187,7 @@ def _migrate_statistics_columns_to_timestamp( "created_ts=EXTRACT(EPOCH FROM created), " "last_reset_ts=EXTRACT(EPOCH FROM last_reset) " "where id IN ( " - f"SELECT id FROM {table} where start_ts is NULL LIMIT 250000 " + f"SELECT id FROM {table} where start_ts is NULL LIMIT 100000 " " );" ) ) diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index 294c5217623..bd11744ab09 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -16,14 +16,13 @@ import re from statistics import mean from typing import TYPE_CHECKING, Any, Literal, cast -from sqlalchemy import and_, bindparam, func, lambda_stmt, select, text +from sqlalchemy import Select, and_, bindparam, func, lambda_stmt, select, text from sqlalchemy.engine import Engine from sqlalchemy.engine.row import Row from sqlalchemy.exc import OperationalError, SQLAlchemyError, StatementError from sqlalchemy.orm.session import Session from sqlalchemy.sql.expression import literal_column, true from sqlalchemy.sql.lambdas import StatementLambdaElement -from sqlalchemy.sql.selectable import Subquery import voluptuous as vol from homeassistant.const import ATTR_UNIT_OF_MEASUREMENT @@ -75,6 +74,7 @@ from .models import ( datetime_to_timestamp_or_none, ) from .util import ( + database_job_retry_wrapper, execute, execute_stmt_lambda_element, get_instance, @@ -515,7 +515,10 @@ def _delete_duplicates_from_table( return (total_deleted_rows, all_non_identical_duplicates) -def delete_statistics_duplicates(hass: HomeAssistant, session: Session) -> None: +@database_job_retry_wrapper("delete statistics duplicates", 3) +def delete_statistics_duplicates( + instance: Recorder, hass: HomeAssistant, session: Session +) -> None: """Identify and delete duplicated statistics. A backup will be made of duplicated statistics before it is deleted. @@ -646,27 +649,19 @@ def _compile_hourly_statistics_summary_mean_stmt( ) -def _compile_hourly_statistics_last_sum_stmt_subquery( - start_time_ts: float, end_time_ts: float -) -> Subquery: - """Generate the summary mean statement for hourly statistics.""" - return ( - select(*QUERY_STATISTICS_SUMMARY_SUM) - .filter(StatisticsShortTerm.start_ts >= start_time_ts) - .filter(StatisticsShortTerm.start_ts < end_time_ts) - .subquery() - ) - - def _compile_hourly_statistics_last_sum_stmt( start_time_ts: float, end_time_ts: float ) -> StatementLambdaElement: """Generate the summary mean statement for hourly statistics.""" - subquery = _compile_hourly_statistics_last_sum_stmt_subquery( - start_time_ts, end_time_ts - ) return lambda_stmt( - lambda: select(subquery) + lambda: select( + subquery := ( + select(*QUERY_STATISTICS_SUMMARY_SUM) + .filter(StatisticsShortTerm.start_ts >= start_time_ts) + .filter(StatisticsShortTerm.start_ts < end_time_ts) + .subquery() + ) + ) .filter(subquery.c.rownum == 1) .order_by(subquery.c.metadata_id) ) @@ -1263,7 +1258,8 @@ def _reduce_statistics_per_month( ) -def _statistics_during_period_stmt( +def _generate_statistics_during_period_stmt( + columns: Select, start_time: datetime, end_time: datetime | None, metadata_ids: list[int] | None, @@ -1275,21 +1271,6 @@ def _statistics_during_period_stmt( This prepares a lambda_stmt query, so we don't insert the parameters yet. """ start_time_ts = start_time.timestamp() - - columns = select(table.metadata_id, table.start_ts) - if "last_reset" in types: - columns = columns.add_columns(table.last_reset_ts) - if "max" in types: - columns = columns.add_columns(table.max) - if "mean" in types: - columns = columns.add_columns(table.mean) - if "min" in types: - columns = columns.add_columns(table.min) - if "state" in types: - columns = columns.add_columns(table.state) - if "sum" in types: - columns = columns.add_columns(table.sum) - stmt = lambda_stmt(lambda: columns.filter(table.start_ts >= start_time_ts)) if end_time is not None: end_time_ts = end_time.timestamp() @@ -1303,6 +1284,23 @@ def _statistics_during_period_stmt( return stmt +def _generate_max_mean_min_statistic_in_sub_period_stmt( + columns: Select, + start_time: datetime | None, + end_time: datetime | None, + table: type[StatisticsBase], + metadata_id: int, +) -> StatementLambdaElement: + stmt = lambda_stmt(lambda: columns.filter(table.metadata_id == metadata_id)) + if start_time is not None: + start_time_ts = start_time.timestamp() + stmt += lambda q: q.filter(table.start_ts >= start_time_ts) + if end_time is not None: + end_time_ts = end_time.timestamp() + stmt += lambda q: q.filter(table.start_ts < end_time_ts) + return stmt + + def _get_max_mean_min_statistic_in_sub_period( session: Session, result: dict[str, float], @@ -1328,13 +1326,9 @@ def _get_max_mean_min_statistic_in_sub_period( # https://github.com/sqlalchemy/sqlalchemy/issues/9189 # pylint: disable-next=not-callable columns = columns.add_columns(func.min(table.min)) - stmt = lambda_stmt(lambda: columns.filter(table.metadata_id == metadata_id)) - if start_time is not None: - start_time_ts = start_time.timestamp() - stmt += lambda q: q.filter(table.start_ts >= start_time_ts) - if end_time is not None: - end_time_ts = end_time.timestamp() - stmt += lambda q: q.filter(table.start_ts < end_time_ts) + stmt = _generate_max_mean_min_statistic_in_sub_period_stmt( + columns, start_time, end_time, table, metadata_id + ) stats = cast(Sequence[Row[Any]], execute_stmt_lambda_element(session, stmt)) if not stats: return @@ -1749,8 +1743,21 @@ def _statistics_during_period_with_session( table: type[Statistics | StatisticsShortTerm] = ( Statistics if period != "5minute" else StatisticsShortTerm ) - stmt = _statistics_during_period_stmt( - start_time, end_time, metadata_ids, table, types + columns = select(table.metadata_id, table.start_ts) # type: ignore[call-overload] + if "last_reset" in types: + columns = columns.add_columns(table.last_reset_ts) + if "max" in types: + columns = columns.add_columns(table.max) + if "mean" in types: + columns = columns.add_columns(table.mean) + if "min" in types: + columns = columns.add_columns(table.min) + if "state" in types: + columns = columns.add_columns(table.state) + if "sum" in types: + columns = columns.add_columns(table.sum) + stmt = _generate_statistics_during_period_stmt( + columns, start_time, end_time, metadata_ids, table, types ) stats = cast(Sequence[Row], execute_stmt_lambda_element(session, stmt)) @@ -1915,28 +1922,24 @@ def get_last_short_term_statistics( ) -def _generate_most_recent_statistic_row(metadata_ids: list[int]) -> Subquery: - """Generate the subquery to find the most recent statistic row.""" - return ( - select( - StatisticsShortTerm.metadata_id, - # https://github.com/sqlalchemy/sqlalchemy/issues/9189 - # pylint: disable-next=not-callable - func.max(StatisticsShortTerm.start_ts).label("start_max"), - ) - .where(StatisticsShortTerm.metadata_id.in_(metadata_ids)) - .group_by(StatisticsShortTerm.metadata_id) - ).subquery() - - def _latest_short_term_statistics_stmt( metadata_ids: list[int], ) -> StatementLambdaElement: """Create the statement for finding the latest short term stat rows.""" stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SHORT_TERM)) - most_recent_statistic_row = _generate_most_recent_statistic_row(metadata_ids) stmt += lambda s: s.join( - most_recent_statistic_row, + ( + most_recent_statistic_row := ( + select( + StatisticsShortTerm.metadata_id, + # https://github.com/sqlalchemy/sqlalchemy/issues/9189 + # pylint: disable-next=not-callable + func.max(StatisticsShortTerm.start_ts).label("start_max"), + ) + .where(StatisticsShortTerm.metadata_id.in_(metadata_ids)) + .group_by(StatisticsShortTerm.metadata_id) + ).subquery() + ), ( StatisticsShortTerm.metadata_id # pylint: disable=comparison-with-callable == most_recent_statistic_row.c.metadata_id @@ -1984,21 +1987,34 @@ def get_latest_short_term_statistics( ) -def _get_most_recent_statistics_subquery( - metadata_ids: set[int], table: type[StatisticsBase], start_time_ts: float -) -> Subquery: - """Generate the subquery to find the most recent statistic row.""" - return ( - select( - # https://github.com/sqlalchemy/sqlalchemy/issues/9189 - # pylint: disable-next=not-callable - func.max(table.start_ts).label("max_start_ts"), - table.metadata_id.label("max_metadata_id"), +def _generate_statistics_at_time_stmt( + columns: Select, + table: type[StatisticsBase], + metadata_ids: set[int], + start_time_ts: float, +) -> StatementLambdaElement: + """Create the statement for finding the statistics for a given time.""" + return lambda_stmt( + lambda: columns.join( + ( + most_recent_statistic_ids := ( + select( + # https://github.com/sqlalchemy/sqlalchemy/issues/9189 + # pylint: disable-next=not-callable + func.max(table.start_ts).label("max_start_ts"), + table.metadata_id.label("max_metadata_id"), + ) + .filter(table.start_ts < start_time_ts) + .filter(table.metadata_id.in_(metadata_ids)) + .group_by(table.metadata_id) + .subquery() + ) + ), + and_( + table.start_ts == most_recent_statistic_ids.c.max_start_ts, + table.metadata_id == most_recent_statistic_ids.c.max_metadata_id, + ), ) - .filter(table.start_ts < start_time_ts) - .filter(table.metadata_id.in_(metadata_ids)) - .group_by(table.metadata_id) - .subquery() ) @@ -2023,19 +2039,10 @@ def _statistics_at_time( columns = columns.add_columns(table.state) if "sum" in types: columns = columns.add_columns(table.sum) - start_time_ts = start_time.timestamp() - most_recent_statistic_ids = _get_most_recent_statistics_subquery( - metadata_ids, table, start_time_ts + stmt = _generate_statistics_at_time_stmt( + columns, table, metadata_ids, start_time_ts ) - stmt = lambda_stmt(lambda: columns).join( - most_recent_statistic_ids, - and_( - table.start_ts == most_recent_statistic_ids.c.max_start_ts, - table.metadata_id == most_recent_statistic_ids.c.max_metadata_id, - ), - ) - return cast(Sequence[Row], execute_stmt_lambda_element(session, stmt)) diff --git a/homeassistant/components/recorder/util.py b/homeassistant/components/recorder/util.py index 3ff6b62b21e..bfdd8ff5b14 100644 --- a/homeassistant/components/recorder/util.py +++ b/homeassistant/components/recorder/util.py @@ -568,6 +568,17 @@ def end_incomplete_runs(session: Session, start_time: datetime) -> None: session.add(run) +def _is_retryable_error(instance: Recorder, err: OperationalError) -> bool: + """Return True if the error is retryable.""" + assert instance.engine is not None + return bool( + instance.engine.dialect.name == SupportedDialect.MYSQL + and isinstance(err.orig, BaseException) + and err.orig.args + and err.orig.args[0] in RETRYABLE_MYSQL_ERRORS + ) + + _FuncType = Callable[Concatenate[_RecorderT, _P], bool] @@ -585,12 +596,8 @@ def retryable_database_job( try: return job(instance, *args, **kwargs) except OperationalError as err: - assert instance.engine is not None - if ( - instance.engine.dialect.name == SupportedDialect.MYSQL - and err.orig - and err.orig.args[0] in RETRYABLE_MYSQL_ERRORS - ): + if _is_retryable_error(instance, err): + assert isinstance(err.orig, BaseException) _LOGGER.info( "%s; %s not completed, retrying", err.orig.args[1], description ) @@ -608,6 +615,46 @@ def retryable_database_job( return decorator +_WrappedFuncType = Callable[Concatenate[_RecorderT, _P], None] + + +def database_job_retry_wrapper( + description: str, attempts: int = 5 +) -> Callable[[_WrappedFuncType[_RecorderT, _P]], _WrappedFuncType[_RecorderT, _P]]: + """Try to execute a database job multiple times. + + This wrapper handles InnoDB deadlocks and lock timeouts. + + This is different from retryable_database_job in that it will retry the job + attempts number of times instead of returning False if the job fails. + """ + + def decorator( + job: _WrappedFuncType[_RecorderT, _P] + ) -> _WrappedFuncType[_RecorderT, _P]: + @functools.wraps(job) + def wrapper(instance: _RecorderT, *args: _P.args, **kwargs: _P.kwargs) -> None: + for attempt in range(attempts): + try: + job(instance, *args, **kwargs) + return + except OperationalError as err: + if attempt == attempts - 1 or not _is_retryable_error( + instance, err + ): + raise + assert isinstance(err.orig, BaseException) + _LOGGER.info( + "%s; %s failed, retrying", err.orig.args[1], description + ) + time.sleep(instance.db_retry_wait) + # Failed with retryable error + + return wrapper + + return decorator + + def periodic_db_cleanups(instance: Recorder) -> None: """Run any database cleanups that need to happen periodically. diff --git a/homeassistant/components/reolink/number.py b/homeassistant/components/reolink/number.py index e9b692fffe6..7c807ddadc3 100644 --- a/homeassistant/components/reolink/number.py +++ b/homeassistant/components/reolink/number.py @@ -64,7 +64,7 @@ NUMBER_ENTITIES = ( get_max_value=lambda api, ch: api.zoom_range(ch)["focus"]["pos"]["max"], supported=lambda api, ch: api.zoom_supported(ch), value=lambda api, ch: api.get_focus(ch), - method=lambda api, ch, value: api.set_zoom(ch, int(value)), + method=lambda api, ch, value: api.set_focus(ch, int(value)), ), ) diff --git a/homeassistant/components/sfr_box/__init__.py b/homeassistant/components/sfr_box/__init__.py index 07f122fa4b2..4873acf753e 100644 --- a/homeassistant/components/sfr_box/__init__.py +++ b/homeassistant/components/sfr_box/__init__.py @@ -1,13 +1,11 @@ """SFR Box.""" from __future__ import annotations -import asyncio - from sfrbox_api.bridge import SFRBox from sfrbox_api.exceptions import SFRBoxAuthenticationError, SFRBoxError from homeassistant.config_entries import ConfigEntry -from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME +from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME, Platform from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady from homeassistant.helpers import device_registry as dr @@ -40,15 +38,17 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: hass, box, "system", lambda b: b.system_get_info() ), ) - tasks = [ - data.dsl.async_config_entry_first_refresh(), - data.system.async_config_entry_first_refresh(), - ] - await asyncio.gather(*tasks) + await data.system.async_config_entry_first_refresh() + system_info = data.system.data + + if system_info.net_infra == "adsl": + await data.dsl.async_config_entry_first_refresh() + else: + platforms = list(platforms) + platforms.remove(Platform.BINARY_SENSOR) hass.data.setdefault(DOMAIN, {})[entry.entry_id] = data - system_info = data.system.data device_registry = dr.async_get(hass) device_registry.async_get_or_create( config_entry_id=entry.entry_id, diff --git a/homeassistant/components/sfr_box/sensor.py b/homeassistant/components/sfr_box/sensor.py index f84441d2491..5f4aadce7e2 100644 --- a/homeassistant/components/sfr_box/sensor.py +++ b/homeassistant/components/sfr_box/sensor.py @@ -1,7 +1,6 @@ """SFR Box sensor platform.""" -from collections.abc import Callable, Iterable +from collections.abc import Callable from dataclasses import dataclass -from itertools import chain from typing import Generic, TypeVar from sfrbox_api.models import DslInfo, SystemInfo @@ -204,16 +203,15 @@ async def async_setup_entry( """Set up the sensors.""" data: DomainData = hass.data[DOMAIN][entry.entry_id] - entities: Iterable[SFRBoxSensor] = chain( - ( + entities: list[SFRBoxSensor] = [ + SFRBoxSensor(data.system, description, data.system.data) + for description in SYSTEM_SENSOR_TYPES + ] + if data.system.data.net_infra == "adsl": + entities.extend( SFRBoxSensor(data.dsl, description, data.system.data) for description in DSL_SENSOR_TYPES - ), - ( - SFRBoxSensor(data.system, description, data.system.data) - for description in SYSTEM_SENSOR_TYPES - ), - ) + ) async_add_entities(entities) diff --git a/homeassistant/components/snapcast/manifest.json b/homeassistant/components/snapcast/manifest.json index d69f06f6983..bdcadc84e7c 100644 --- a/homeassistant/components/snapcast/manifest.json +++ b/homeassistant/components/snapcast/manifest.json @@ -1,9 +1,9 @@ { "domain": "snapcast", "name": "Snapcast", - "codeowners": [], + "codeowners": ["@luar123"], "documentation": "https://www.home-assistant.io/integrations/snapcast", "iot_class": "local_polling", "loggers": ["construct", "snapcast"], - "requirements": ["snapcast==2.3.0"] + "requirements": ["snapcast==2.3.2"] } diff --git a/homeassistant/components/sql/manifest.json b/homeassistant/components/sql/manifest.json index e3efa81e44a..bdedbb9b207 100644 --- a/homeassistant/components/sql/manifest.json +++ b/homeassistant/components/sql/manifest.json @@ -5,5 +5,5 @@ "config_flow": true, "documentation": "https://www.home-assistant.io/integrations/sql", "iot_class": "local_polling", - "requirements": ["sqlalchemy==2.0.4"] + "requirements": ["sqlalchemy==2.0.5.post1"] } diff --git a/homeassistant/components/thread/diagnostics.py b/homeassistant/components/thread/diagnostics.py index b945f818d00..eb1e2a5ef68 100644 --- a/homeassistant/components/thread/diagnostics.py +++ b/homeassistant/components/thread/diagnostics.py @@ -17,9 +17,8 @@ some of their thread accessories can't be pinged, but it's still a thread proble from __future__ import annotations -from typing import Any, TypedDict +from typing import TYPE_CHECKING, Any, TypedDict -from pyroute2 import NDB # pylint: disable=no-name-in-module from python_otbr_api.tlv_parser import MeshcopTLVType from homeassistant.components import zeroconf @@ -29,6 +28,9 @@ from homeassistant.core import HomeAssistant from .dataset_store import async_get_store from .discovery import async_read_zeroconf_cache +if TYPE_CHECKING: + from pyroute2 import NDB # pylint: disable=no-name-in-module + class Neighbour(TypedDict): """A neighbour cache entry (ip neigh).""" @@ -67,58 +69,69 @@ class Network(TypedDict): unexpected_routers: set[str] -def _get_possible_thread_routes() -> ( - tuple[dict[str, dict[str, Route]], dict[str, set[str]]] -): +def _get_possible_thread_routes( + ndb: NDB, +) -> tuple[dict[str, dict[str, Route]], dict[str, set[str]]]: # Build a list of possible thread routes # Right now, this is ipv6 /64's that have a gateway # We cross reference with zerconf data to confirm which via's are known border routers routes: dict[str, dict[str, Route]] = {} reverse_routes: dict[str, set[str]] = {} - with NDB() as ndb: - for record in ndb.routes: - # Limit to IPV6 routes - if record.family != 10: - continue - # Limit to /64 prefixes - if record.dst_len != 64: - continue - # Limit to routes with a via - if not record.gateway and not record.nh_gateway: - continue - gateway = record.gateway or record.nh_gateway - route = routes.setdefault(gateway, {}) - route[record.dst] = { - "metrics": record.metrics, - "priority": record.priority, - # NM creates "nexthop" routes - a single route with many via's - # Kernel creates many routes with a single via - "is_nexthop": record.nh_gateway is not None, - } - reverse_routes.setdefault(record.dst, set()).add(gateway) + for record in ndb.routes: + # Limit to IPV6 routes + if record.family != 10: + continue + # Limit to /64 prefixes + if record.dst_len != 64: + continue + # Limit to routes with a via + if not record.gateway and not record.nh_gateway: + continue + gateway = record.gateway or record.nh_gateway + route = routes.setdefault(gateway, {}) + route[record.dst] = { + "metrics": record.metrics, + "priority": record.priority, + # NM creates "nexthop" routes - a single route with many via's + # Kernel creates many routes with a single via + "is_nexthop": record.nh_gateway is not None, + } + reverse_routes.setdefault(record.dst, set()).add(gateway) return routes, reverse_routes -def _get_neighbours() -> dict[str, Neighbour]: - neighbours: dict[str, Neighbour] = {} - - with NDB() as ndb: - for record in ndb.neighbours: - neighbours[record.dst] = { - "lladdr": record.lladdr, - "state": record.state, - "probes": record.probes, - } - +def _get_neighbours(ndb: NDB) -> dict[str, Neighbour]: + # Build a list of neighbours + neighbours: dict[str, Neighbour] = { + record.dst: { + "lladdr": record.lladdr, + "state": record.state, + "probes": record.probes, + } + for record in ndb.neighbours + } return neighbours +def _get_routes_and_neighbors(): + """Get the routes and neighbours from pyroute2.""" + # Import in the executor since import NDB can take a while + from pyroute2 import ( # pylint: disable=no-name-in-module, import-outside-toplevel + NDB, + ) + + with NDB() as ndb: # pylint: disable=not-callable + routes, reverse_routes = _get_possible_thread_routes(ndb) + neighbours = _get_neighbours(ndb) + + return routes, reverse_routes, neighbours + + async def async_get_config_entry_diagnostics( hass: HomeAssistant, entry: ConfigEntry ) -> dict[str, Any]: """Return diagnostics for all known thread networks.""" - networks: dict[str, Network] = {} # Start with all networks that HA knows about @@ -140,13 +153,12 @@ async def async_get_config_entry_diagnostics( # Find all routes currently act that might be thread related, so we can match them to # border routers as we process the zeroconf data. - routes, reverse_routes = await hass.async_add_executor_job( - _get_possible_thread_routes + # + # Also find all neighbours + routes, reverse_routes, neighbours = await hass.async_add_executor_job( + _get_routes_and_neighbors ) - # Find all neighbours - neighbours = await hass.async_add_executor_job(_get_neighbours) - aiozc = await zeroconf.async_get_async_instance(hass) for data in async_read_zeroconf_cache(aiozc): if not data.extended_pan_id: diff --git a/homeassistant/components/tplink_omada/config_flow.py b/homeassistant/components/tplink_omada/config_flow.py index 6b958b7d258..f6a75abe6d8 100644 --- a/homeassistant/components/tplink_omada/config_flow.py +++ b/homeassistant/components/tplink_omada/config_flow.py @@ -3,9 +3,12 @@ from __future__ import annotations from collections.abc import Mapping import logging +import re from types import MappingProxyType from typing import Any, NamedTuple +from urllib.parse import urlsplit +from aiohttp import CookieJar from tplink_omada_client.exceptions import ( ConnectionFailed, LoginFailed, @@ -20,7 +23,10 @@ from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME, CONF_VE from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResult from homeassistant.helpers import selector -from homeassistant.helpers.aiohttp_client import async_get_clientsession +from homeassistant.helpers.aiohttp_client import ( + async_create_clientsession, + async_get_clientsession, +) from .const import DOMAIN @@ -42,11 +48,26 @@ async def create_omada_client( hass: HomeAssistant, data: MappingProxyType[str, Any] ) -> OmadaClient: """Create a TP-Link Omada client API for the given config entry.""" - host = data[CONF_HOST] + + host: str = data[CONF_HOST] verify_ssl = bool(data[CONF_VERIFY_SSL]) + + if not host.lower().startswith(("http://", "https://")): + host = "https://" + host + host_parts = urlsplit(host) + if ( + host_parts.hostname + and re.fullmatch(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", host_parts.hostname) + is not None + ): + # TP-Link API uses cookies for login session, so an unsafe cookie jar is required for IP addresses + websession = async_create_clientsession(hass, cookie_jar=CookieJar(unsafe=True)) + else: + websession = async_get_clientsession(hass, verify_ssl=verify_ssl) + username = data[CONF_USERNAME] password = data[CONF_PASSWORD] - websession = async_get_clientsession(hass, verify_ssl=verify_ssl) + return OmadaClient(host, username, password, websession=websession) diff --git a/homeassistant/components/tuya/light.py b/homeassistant/components/tuya/light.py index 1a2d0c526d0..ffc00e6f92c 100644 --- a/homeassistant/components/tuya/light.py +++ b/homeassistant/components/tuya/light.py @@ -1,7 +1,7 @@ """Support for the Tuya lights.""" from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field import json from typing import Any, cast @@ -59,7 +59,9 @@ class TuyaLightEntityDescription(LightEntityDescription): color_data: DPCode | tuple[DPCode, ...] | None = None color_mode: DPCode | None = None color_temp: DPCode | tuple[DPCode, ...] | None = None - default_color_type: ColorTypeData = DEFAULT_COLOR_TYPE_DATA + default_color_type: ColorTypeData = field( + default_factory=lambda: DEFAULT_COLOR_TYPE_DATA + ) LIGHTS: dict[str, tuple[TuyaLightEntityDescription, ...]] = { diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index e8008eb49b6..fa5c6aac294 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -4,6 +4,7 @@ from __future__ import annotations from collections.abc import Callable from contextlib import suppress import datetime as dt +from functools import lru_cache import json from typing import Any, cast @@ -424,6 +425,12 @@ def handle_ping( connection.send_message(pong_message(msg["id"])) +@lru_cache +def _cached_template(template_str: str, hass: HomeAssistant) -> template.Template: + """Return a cached template.""" + return template.Template(template_str, hass) + + @decorators.websocket_command( { vol.Required("type"): "render_template", @@ -440,7 +447,7 @@ async def handle_render_template( ) -> None: """Handle render_template command.""" template_str = msg["template"] - template_obj = template.Template(template_str, hass) + template_obj = _cached_template(template_str, hass) variables = msg.get("variables") timeout = msg.get("timeout") info = None diff --git a/homeassistant/components/zha/__init__.py b/homeassistant/components/zha/__init__.py index d32dcf0bda6..d0496fe7b60 100644 --- a/homeassistant/components/zha/__init__.py +++ b/homeassistant/components/zha/__init__.py @@ -1,5 +1,6 @@ """Support for Zigbee Home Automation devices.""" import asyncio +import copy import logging import os @@ -90,6 +91,15 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b Will automatically load components to support devices found on the network. """ + # Strip whitespace around `socket://` URIs, this is no longer accepted by zigpy + # This will be removed in 2023.7.0 + path = config_entry.data[CONF_DEVICE][CONF_DEVICE_PATH] + data = copy.deepcopy(dict(config_entry.data)) + + if path.startswith("socket://") and path != path.strip(): + data[CONF_DEVICE][CONF_DEVICE_PATH] = path.strip() + hass.config_entries.async_update_entry(config_entry, data=data) + zha_data = hass.data.setdefault(DATA_ZHA, {}) config = zha_data.get(DATA_ZHA_CONFIG, {}) diff --git a/homeassistant/const.py b/homeassistant/const.py index 1ec896a415f..e90ccad63e5 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -8,7 +8,7 @@ from .backports.enum import StrEnum APPLICATION_NAME: Final = "HomeAssistant" MAJOR_VERSION: Final = 2023 MINOR_VERSION: Final = 3 -PATCH_VERSION: Final = "1" +PATCH_VERSION: Final = "2" __short_version__: Final = f"{MAJOR_VERSION}.{MINOR_VERSION}" __version__: Final = f"{__short_version__}.{PATCH_VERSION}" REQUIRED_PYTHON_VER: Final[tuple[int, int, int]] = (3, 10, 0) diff --git a/homeassistant/package_constraints.txt b/homeassistant/package_constraints.txt index 3a199853634..4e4786e7edc 100644 --- a/homeassistant/package_constraints.txt +++ b/homeassistant/package_constraints.txt @@ -23,7 +23,7 @@ fnvhash==0.1.0 hass-nabucasa==0.61.0 hassil==1.0.6 home-assistant-bluetooth==1.9.3 -home-assistant-frontend==20230302.0 +home-assistant-frontend==20230306.0 home-assistant-intents==2023.2.28 httpx==0.23.3 ifaddr==0.1.7 @@ -42,7 +42,7 @@ pyudev==0.23.2 pyyaml==6.0 requests==2.28.2 scapy==2.5.0 -sqlalchemy==2.0.4 +sqlalchemy==2.0.5.post1 typing-extensions>=4.5.0,<5.0 voluptuous-serialize==2.6.0 voluptuous==0.13.1 diff --git a/pyproject.toml b/pyproject.toml index 1a81cc5f502..34e88267645 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "homeassistant" -version = "2023.3.1" +version = "2023.3.2" license = {text = "Apache-2.0"} description = "Open-source home automation platform running on Python 3." readme = "README.rst" diff --git a/requirements_all.txt b/requirements_all.txt index d11e4d66026..fb17dc6d4da 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -156,7 +156,7 @@ aioecowitt==2023.01.0 aioemonitor==1.0.5 # homeassistant.components.esphome -aioesphomeapi==13.4.1 +aioesphomeapi==13.4.2 # homeassistant.components.flo aioflo==2021.11.0 @@ -907,7 +907,7 @@ hole==0.8.0 holidays==0.18.0 # homeassistant.components.frontend -home-assistant-frontend==20230302.0 +home-assistant-frontend==20230306.0 # homeassistant.components.conversation home-assistant-intents==2023.2.28 @@ -979,7 +979,7 @@ influxdb==5.3.1 inkbird-ble==0.5.6 # homeassistant.components.insteon -insteon-frontend-home-assistant==0.3.2 +insteon-frontend-home-assistant==0.3.3 # homeassistant.components.intellifire intellifire4py==2.2.2 @@ -1621,7 +1621,7 @@ pyevilgenius==2.0.0 pyezviz==0.2.0.9 # homeassistant.components.fibaro -pyfibaro==0.6.8 +pyfibaro==0.6.9 # homeassistant.components.fido pyfido==2.1.2 @@ -1687,7 +1687,7 @@ pyialarm==2.2.0 pyicloud==1.0.0 # homeassistant.components.insteon -pyinsteon==1.3.3 +pyinsteon==1.3.4 # homeassistant.components.intesishome pyintesishome==1.8.0 @@ -2367,7 +2367,7 @@ smart-meter-texas==0.4.7 smhi-pkg==1.0.16 # homeassistant.components.snapcast -snapcast==2.3.0 +snapcast==2.3.2 # homeassistant.components.sonos soco==0.29.1 @@ -2398,7 +2398,7 @@ spotipy==2.22.1 # homeassistant.components.recorder # homeassistant.components.sql -sqlalchemy==2.0.4 +sqlalchemy==2.0.5.post1 # homeassistant.components.srp_energy srpenergy==1.3.6 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 22d6a64361e..bd5fe240e34 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -143,7 +143,7 @@ aioecowitt==2023.01.0 aioemonitor==1.0.5 # homeassistant.components.esphome -aioesphomeapi==13.4.1 +aioesphomeapi==13.4.2 # homeassistant.components.flo aioflo==2021.11.0 @@ -690,7 +690,7 @@ hole==0.8.0 holidays==0.18.0 # homeassistant.components.frontend -home-assistant-frontend==20230302.0 +home-assistant-frontend==20230306.0 # homeassistant.components.conversation home-assistant-intents==2023.2.28 @@ -738,7 +738,7 @@ influxdb==5.3.1 inkbird-ble==0.5.6 # homeassistant.components.insteon -insteon-frontend-home-assistant==0.3.2 +insteon-frontend-home-assistant==0.3.3 # homeassistant.components.intellifire intellifire4py==2.2.2 @@ -1161,7 +1161,7 @@ pyevilgenius==2.0.0 pyezviz==0.2.0.9 # homeassistant.components.fibaro -pyfibaro==0.6.8 +pyfibaro==0.6.9 # homeassistant.components.fido pyfido==2.1.2 @@ -1212,7 +1212,7 @@ pyialarm==2.2.0 pyicloud==1.0.0 # homeassistant.components.insteon -pyinsteon==1.3.3 +pyinsteon==1.3.4 # homeassistant.components.ipma pyipma==3.0.6 @@ -1698,7 +1698,7 @@ spotipy==2.22.1 # homeassistant.components.recorder # homeassistant.components.sql -sqlalchemy==2.0.4 +sqlalchemy==2.0.5.post1 # homeassistant.components.srp_energy srpenergy==1.3.6 diff --git a/tests/components/api/test_init.py b/tests/components/api/test_init.py index 570bb980aba..61da000fc07 100644 --- a/tests/components/api/test_init.py +++ b/tests/components/api/test_init.py @@ -349,6 +349,52 @@ async def test_api_template(hass: HomeAssistant, mock_api_client: TestClient) -> assert body == "10" + hass.states.async_set("sensor.temperature", 20) + resp = await mock_api_client.post( + const.URL_API_TEMPLATE, + json={"template": "{{ states.sensor.temperature.state }}"}, + ) + + body = await resp.text() + + assert body == "20" + + hass.states.async_remove("sensor.temperature") + resp = await mock_api_client.post( + const.URL_API_TEMPLATE, + json={"template": "{{ states.sensor.temperature.state }}"}, + ) + + body = await resp.text() + + assert body == "" + + +async def test_api_template_cached( + hass: HomeAssistant, mock_api_client: TestClient +) -> None: + """Test the template API uses the cache.""" + hass.states.async_set("sensor.temperature", 30) + + resp = await mock_api_client.post( + const.URL_API_TEMPLATE, + json={"template": "{{ states.sensor.temperature.state }}"}, + ) + + body = await resp.text() + + assert body == "30" + + hass.states.async_set("sensor.temperature", 40) + resp = await mock_api_client.post( + const.URL_API_TEMPLATE, + json={"template": "{{ states.sensor.temperature.state }}"}, + ) + + body = await resp.text() + + assert body == "40" + async def test_api_template_error( hass: HomeAssistant, mock_api_client: TestClient diff --git a/tests/components/hassio/conftest.py b/tests/components/hassio/conftest.py index a6cd956c95e..78ae9643d68 100644 --- a/tests/components/hassio/conftest.py +++ b/tests/components/hassio/conftest.py @@ -1,5 +1,6 @@ """Fixtures for Hass.io.""" import os +import re from unittest.mock import Mock, patch import pytest @@ -12,6 +13,16 @@ from homeassistant.setup import async_setup_component from . import SUPERVISOR_TOKEN +@pytest.fixture(autouse=True) +def disable_security_filter(): + """Disable the security filter to ensure the integration is secure.""" + with patch( + "homeassistant.components.http.security_filter.FILTERS", + re.compile("not-matching-anything"), + ): + yield + + @pytest.fixture def hassio_env(): """Fixture to inject hassio env.""" @@ -37,6 +48,13 @@ def hassio_stubs(hassio_env, hass, hass_client, aioclient_mock): ), patch( "homeassistant.components.hassio.HassIO.get_info", side_effect=HassioAPIError(), + ), patch( + "homeassistant.components.hassio.HassIO.get_ingress_panels", + return_value={"panels": []}, + ), patch( + "homeassistant.components.hassio.repairs.SupervisorRepairs.setup" + ), patch( + "homeassistant.components.hassio.HassIO.refresh_updates" ): hass.state = CoreState.starting hass.loop.run_until_complete(async_setup_component(hass, "hassio", {})) @@ -67,13 +85,7 @@ async def hassio_client_supervisor(hass, aiohttp_client, hassio_stubs): @pytest.fixture -def hassio_handler(hass, aioclient_mock): +async def hassio_handler(hass, aioclient_mock): """Create mock hassio handler.""" - - async def get_client_session(): - return async_get_clientsession(hass) - - websession = hass.loop.run_until_complete(get_client_session()) - with patch.dict(os.environ, {"SUPERVISOR_TOKEN": SUPERVISOR_TOKEN}): - yield HassIO(hass.loop, websession, "127.0.0.1") + yield HassIO(hass.loop, async_get_clientsession(hass), "127.0.0.1") diff --git a/tests/components/hassio/test_handler.py b/tests/components/hassio/test_handler.py index ee23d5d350e..64e9e1c31cc 100644 --- a/tests/components/hassio/test_handler.py +++ b/tests/components/hassio/test_handler.py @@ -1,13 +1,21 @@ """The tests for the hassio component.""" +from __future__ import annotations + +from typing import Any, Literal + import aiohttp +from aiohttp import hdrs, web import pytest -from homeassistant.components.hassio.handler import HassioAPIError +from homeassistant.components.hassio.handler import HassIO, HassioAPIError +from homeassistant.helpers.aiohttp_client import async_get_clientsession from tests.test_util.aiohttp import AiohttpClientMocker -async def test_api_ping(hassio_handler, aioclient_mock: AiohttpClientMocker) -> None: +async def test_api_ping( + hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker +) -> None: """Test setup with API ping.""" aioclient_mock.get("http://127.0.0.1/supervisor/ping", json={"result": "ok"}) @@ -16,7 +24,7 @@ async def test_api_ping(hassio_handler, aioclient_mock: AiohttpClientMocker) -> async def test_api_ping_error( - hassio_handler, aioclient_mock: AiohttpClientMocker + hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker ) -> None: """Test setup with API ping error.""" aioclient_mock.get("http://127.0.0.1/supervisor/ping", json={"result": "error"}) @@ -26,7 +34,7 @@ async def test_api_ping_error( async def test_api_ping_exeption( - hassio_handler, aioclient_mock: AiohttpClientMocker + hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker ) -> None: """Test setup with API ping exception.""" aioclient_mock.get("http://127.0.0.1/supervisor/ping", exc=aiohttp.ClientError()) @@ -35,7 +43,9 @@ async def test_api_ping_exeption( assert aioclient_mock.call_count == 1 -async def test_api_info(hassio_handler, aioclient_mock: AiohttpClientMocker) -> None: +async def test_api_info( + hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker +) -> None: """Test setup with API generic info.""" aioclient_mock.get( "http://127.0.0.1/info", @@ -53,7 +63,7 @@ async def test_api_info(hassio_handler, aioclient_mock: AiohttpClientMocker) -> async def test_api_info_error( - hassio_handler, aioclient_mock: AiohttpClientMocker + hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker ) -> None: """Test setup with API Home Assistant info error.""" aioclient_mock.get( @@ -67,7 +77,7 @@ async def test_api_info_error( async def test_api_host_info( - hassio_handler, aioclient_mock: AiohttpClientMocker + hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker ) -> None: """Test setup with API Host info.""" aioclient_mock.get( @@ -90,7 +100,7 @@ async def test_api_host_info( async def test_api_supervisor_info( - hassio_handler, aioclient_mock: AiohttpClientMocker + hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker ) -> None: """Test setup with API Supervisor info.""" aioclient_mock.get( @@ -108,7 +118,9 @@ async def test_api_supervisor_info( assert data["channel"] == "stable" -async def test_api_os_info(hassio_handler, aioclient_mock: AiohttpClientMocker) -> None: +async def test_api_os_info( + hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker +) -> None: """Test setup with API OS info.""" aioclient_mock.get( "http://127.0.0.1/os/info", @@ -125,7 +137,7 @@ async def test_api_os_info(hassio_handler, aioclient_mock: AiohttpClientMocker) async def test_api_host_info_error( - hassio_handler, aioclient_mock: AiohttpClientMocker + hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker ) -> None: """Test setup with API Home Assistant info error.""" aioclient_mock.get( @@ -139,7 +151,7 @@ async def test_api_host_info_error( async def test_api_core_info( - hassio_handler, aioclient_mock: AiohttpClientMocker + hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker ) -> None: """Test setup with API Home Assistant Core info.""" aioclient_mock.get( @@ -153,7 +165,7 @@ async def test_api_core_info( async def test_api_core_info_error( - hassio_handler, aioclient_mock: AiohttpClientMocker + hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker ) -> None: """Test setup with API Home Assistant Core info error.""" aioclient_mock.get( @@ -167,7 +179,7 @@ async def test_api_core_info_error( async def test_api_homeassistant_stop( - hassio_handler, aioclient_mock: AiohttpClientMocker + hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker ) -> None: """Test setup with API Home Assistant stop.""" aioclient_mock.post("http://127.0.0.1/homeassistant/stop", json={"result": "ok"}) @@ -177,7 +189,7 @@ async def test_api_homeassistant_stop( async def test_api_homeassistant_restart( - hassio_handler, aioclient_mock: AiohttpClientMocker + hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker ) -> None: """Test setup with API Home Assistant restart.""" aioclient_mock.post("http://127.0.0.1/homeassistant/restart", json={"result": "ok"}) @@ -187,7 +199,7 @@ async def test_api_homeassistant_restart( async def test_api_addon_info( - hassio_handler, aioclient_mock: AiohttpClientMocker + hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker ) -> None: """Test setup with API Add-on info.""" aioclient_mock.get( @@ -201,7 +213,7 @@ async def test_api_addon_info( async def test_api_addon_stats( - hassio_handler, aioclient_mock: AiohttpClientMocker + hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker ) -> None: """Test setup with API Add-on stats.""" aioclient_mock.get( @@ -215,7 +227,7 @@ async def test_api_addon_stats( async def test_api_discovery_message( - hassio_handler, aioclient_mock: AiohttpClientMocker + hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker ) -> None: """Test setup with API discovery message.""" aioclient_mock.get( @@ -229,7 +241,7 @@ async def test_api_discovery_message( async def test_api_retrieve_discovery( - hassio_handler, aioclient_mock: AiohttpClientMocker + hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker ) -> None: """Test setup with API discovery message.""" aioclient_mock.get( @@ -243,7 +255,7 @@ async def test_api_retrieve_discovery( async def test_api_ingress_panels( - hassio_handler, aioclient_mock: AiohttpClientMocker + hassio_handler: HassIO, aioclient_mock: AiohttpClientMocker ) -> None: """Test setup with API Ingress panels.""" aioclient_mock.get( @@ -267,3 +279,56 @@ async def test_api_ingress_panels( assert aioclient_mock.call_count == 1 assert data["panels"] assert "slug" in data["panels"] + + +@pytest.mark.parametrize( + ("api_call", "method", "payload"), + [ + ["retrieve_discovery_messages", "GET", None], + ["refresh_updates", "POST", None], + ["update_diagnostics", "POST", True], + ], +) +async def test_api_headers( + hass, + aiohttp_raw_server, + socket_enabled, + api_call: str, + method: Literal["GET", "POST"], + payload: Any, +) -> None: + """Test headers are forwarded correctly.""" + received_request = None + + async def mock_handler(request): + """Return OK.""" + nonlocal received_request + received_request = request + return web.json_response({"result": "ok", "data": None}) + + server = await aiohttp_raw_server(mock_handler) + hassio_handler = HassIO( + hass.loop, + async_get_clientsession(hass), + f"{server.host}:{server.port}", + ) + + api_func = getattr(hassio_handler, api_call) + if payload: + await api_func(payload) + else: + await api_func() + assert received_request is not None + + assert received_request.method == method + assert received_request.headers.get("X-Hass-Source") == "core.handler" + + if method == "GET": + assert hdrs.CONTENT_TYPE not in received_request.headers + return + + assert hdrs.CONTENT_TYPE in received_request.headers + if payload: + assert received_request.headers[hdrs.CONTENT_TYPE] == "application/json" + else: + assert received_request.headers[hdrs.CONTENT_TYPE] == "application/octet-stream" diff --git a/tests/components/hassio/test_http.py b/tests/components/hassio/test_http.py index 8ef6fa4001b..cb1dd639ec6 100644 --- a/tests/components/hassio/test_http.py +++ b/tests/components/hassio/test_http.py @@ -1,63 +1,45 @@ """The tests for the hassio component.""" import asyncio from http import HTTPStatus +from unittest.mock import patch from aiohttp import StreamReader import pytest -from homeassistant.components.hassio.http import _need_auth -from homeassistant.core import HomeAssistant - -from tests.common import MockUser from tests.test_util.aiohttp import AiohttpClientMocker -async def test_forward_request( - hassio_client, aioclient_mock: AiohttpClientMocker -) -> None: - """Test fetching normal path.""" - aioclient_mock.post("http://127.0.0.1/beer", text="response") +@pytest.fixture +def mock_not_onboarded(): + """Mock that we're not onboarded.""" + with patch( + "homeassistant.components.hassio.http.async_is_onboarded", return_value=False + ): + yield - resp = await hassio_client.post("/api/hassio/beer") - # Check we got right response - assert resp.status == HTTPStatus.OK - body = await resp.text() - assert body == "response" - - # Check we forwarded command - assert len(aioclient_mock.mock_calls) == 1 +@pytest.fixture +def hassio_user_client(hassio_client, hass_admin_user): + """Return a Hass.io HTTP client tied to a non-admin user.""" + hass_admin_user.groups = [] + return hassio_client @pytest.mark.parametrize( - "build_type", ["supervisor/info", "homeassistant/update", "host/info"] -) -async def test_auth_required_forward_request(hassio_noauth_client, build_type) -> None: - """Test auth required for normal request.""" - resp = await hassio_noauth_client.post(f"/api/hassio/{build_type}") - - # Check we got right response - assert resp.status == HTTPStatus.UNAUTHORIZED - - -@pytest.mark.parametrize( - "build_type", + "path", [ - "app/index.html", - "app/hassio-app.html", - "app/index.html", - "app/hassio-app.html", - "app/some-chunk.js", - "app/app.js", + "app/entrypoint.js", + "addons/bl_b392/logo", + "addons/bl_b392/icon", ], ) -async def test_forward_request_no_auth_for_panel( - hassio_client, build_type, aioclient_mock: AiohttpClientMocker +async def test_forward_request_onboarded_user_get( + hassio_user_client, aioclient_mock: AiohttpClientMocker, path: str ) -> None: - """Test no auth needed for .""" - aioclient_mock.get(f"http://127.0.0.1/{build_type}", text="response") + """Test fetching normal path.""" + aioclient_mock.get(f"http://127.0.0.1/{path}", text="response") - resp = await hassio_client.get(f"/api/hassio/{build_type}") + resp = await hassio_user_client.get(f"/api/hassio/{path}") # Check we got right response assert resp.status == HTTPStatus.OK @@ -66,15 +48,68 @@ async def test_forward_request_no_auth_for_panel( # Check we forwarded command assert len(aioclient_mock.mock_calls) == 1 + # We only expect a single header. + assert aioclient_mock.mock_calls[0][3] == {"X-Hass-Source": "core.http"} -async def test_forward_request_no_auth_for_logo( - hassio_client, aioclient_mock: AiohttpClientMocker +@pytest.mark.parametrize("method", ["POST", "PUT", "DELETE", "RANDOM"]) +async def test_forward_request_onboarded_user_unallowed_methods( + hassio_user_client, aioclient_mock: AiohttpClientMocker, method: str ) -> None: - """Test no auth needed for logo.""" - aioclient_mock.get("http://127.0.0.1/addons/bl_b392/logo", text="response") + """Test fetching normal path.""" + resp = await hassio_user_client.post("/api/hassio/app/entrypoint.js") - resp = await hassio_client.get("/api/hassio/addons/bl_b392/logo") + # Check we got right response + assert resp.status == HTTPStatus.METHOD_NOT_ALLOWED + + # Check we did not forward command + assert len(aioclient_mock.mock_calls) == 0 + + +@pytest.mark.parametrize( + ("bad_path", "expected_status"), + [ + # Caught by bullshit filter + ("app/%252E./entrypoint.js", HTTPStatus.BAD_REQUEST), + # The .. is processed, making it an unauthenticated path + ("app/../entrypoint.js", HTTPStatus.UNAUTHORIZED), + ("app/%2E%2E/entrypoint.js", HTTPStatus.UNAUTHORIZED), + # Unauthenticated path + ("supervisor/info", HTTPStatus.UNAUTHORIZED), + ("supervisor/logs", HTTPStatus.UNAUTHORIZED), + ("addons/bl_b392/logs", HTTPStatus.UNAUTHORIZED), + ], +) +async def test_forward_request_onboarded_user_unallowed_paths( + hassio_user_client, + aioclient_mock: AiohttpClientMocker, + bad_path: str, + expected_status: int, +) -> None: + """Test fetching normal path.""" + resp = await hassio_user_client.get(f"/api/hassio/{bad_path}") + + # Check we got right response + assert resp.status == expected_status + # Check we didn't forward command + assert len(aioclient_mock.mock_calls) == 0 + + +@pytest.mark.parametrize( + "path", + [ + "app/entrypoint.js", + "addons/bl_b392/logo", + "addons/bl_b392/icon", + ], +) +async def test_forward_request_onboarded_noauth_get( + hassio_noauth_client, aioclient_mock: AiohttpClientMocker, path: str +) -> None: + """Test fetching normal path.""" + aioclient_mock.get(f"http://127.0.0.1/{path}", text="response") + + resp = await hassio_noauth_client.get(f"/api/hassio/{path}") # Check we got right response assert resp.status == HTTPStatus.OK @@ -83,15 +118,73 @@ async def test_forward_request_no_auth_for_logo( # Check we forwarded command assert len(aioclient_mock.mock_calls) == 1 + # We only expect a single header. + assert aioclient_mock.mock_calls[0][3] == {"X-Hass-Source": "core.http"} -async def test_forward_request_no_auth_for_icon( - hassio_client, aioclient_mock: AiohttpClientMocker +@pytest.mark.parametrize("method", ["POST", "PUT", "DELETE", "RANDOM"]) +async def test_forward_request_onboarded_noauth_unallowed_methods( + hassio_noauth_client, aioclient_mock: AiohttpClientMocker, method: str ) -> None: - """Test no auth needed for icon.""" - aioclient_mock.get("http://127.0.0.1/addons/bl_b392/icon", text="response") + """Test fetching normal path.""" + resp = await hassio_noauth_client.post("/api/hassio/app/entrypoint.js") - resp = await hassio_client.get("/api/hassio/addons/bl_b392/icon") + # Check we got right response + assert resp.status == HTTPStatus.METHOD_NOT_ALLOWED + + # Check we did not forward command + assert len(aioclient_mock.mock_calls) == 0 + + +@pytest.mark.parametrize( + ("bad_path", "expected_status"), + [ + # Caught by bullshit filter + ("app/%252E./entrypoint.js", HTTPStatus.BAD_REQUEST), + # The .. is processed, making it an unauthenticated path + ("app/../entrypoint.js", HTTPStatus.UNAUTHORIZED), + ("app/%2E%2E/entrypoint.js", HTTPStatus.UNAUTHORIZED), + # Unauthenticated path + ("supervisor/info", HTTPStatus.UNAUTHORIZED), + ("supervisor/logs", HTTPStatus.UNAUTHORIZED), + ("addons/bl_b392/logs", HTTPStatus.UNAUTHORIZED), + ], +) +async def test_forward_request_onboarded_noauth_unallowed_paths( + hassio_noauth_client, + aioclient_mock: AiohttpClientMocker, + bad_path: str, + expected_status: int, +) -> None: + """Test fetching normal path.""" + resp = await hassio_noauth_client.get(f"/api/hassio/{bad_path}") + + # Check we got right response + assert resp.status == expected_status + # Check we didn't forward command + assert len(aioclient_mock.mock_calls) == 0 + + +@pytest.mark.parametrize( + ("path", "authenticated"), + [ + ("app/entrypoint.js", False), + ("addons/bl_b392/logo", False), + ("addons/bl_b392/icon", False), + ("backups/1234abcd/info", True), + ], +) +async def test_forward_request_not_onboarded_get( + hassio_noauth_client, + aioclient_mock: AiohttpClientMocker, + path: str, + authenticated: bool, + mock_not_onboarded, +) -> None: + """Test fetching normal path.""" + aioclient_mock.get(f"http://127.0.0.1/{path}", text="response") + + resp = await hassio_noauth_client.get(f"/api/hassio/{path}") # Check we got right response assert resp.status == HTTPStatus.OK @@ -100,61 +193,224 @@ async def test_forward_request_no_auth_for_icon( # Check we forwarded command assert len(aioclient_mock.mock_calls) == 1 + expected_headers = { + "X-Hass-Source": "core.http", + } + if authenticated: + expected_headers["Authorization"] = "Bearer 123456" + + assert aioclient_mock.mock_calls[0][3] == expected_headers -async def test_forward_log_request( - hassio_client, aioclient_mock: AiohttpClientMocker +@pytest.mark.parametrize( + "path", + [ + "backups/new/upload", + "backups/1234abcd/restore/full", + "backups/1234abcd/restore/partial", + ], +) +async def test_forward_request_not_onboarded_post( + hassio_noauth_client, + aioclient_mock: AiohttpClientMocker, + path: str, + mock_not_onboarded, ) -> None: - """Test fetching normal log path doesn't remove ANSI color escape codes.""" - aioclient_mock.get("http://127.0.0.1/beer/logs", text="\033[32mresponse\033[0m") + """Test fetching normal path.""" + aioclient_mock.get(f"http://127.0.0.1/{path}", text="response") - resp = await hassio_client.get("/api/hassio/beer/logs") + resp = await hassio_noauth_client.get(f"/api/hassio/{path}") # Check we got right response assert resp.status == HTTPStatus.OK body = await resp.text() - assert body == "\033[32mresponse\033[0m" + assert body == "response" # Check we forwarded command assert len(aioclient_mock.mock_calls) == 1 + # We only expect a single header. + assert aioclient_mock.mock_calls[0][3] == { + "X-Hass-Source": "core.http", + "Authorization": "Bearer 123456", + } + + +@pytest.mark.parametrize("method", ["POST", "PUT", "DELETE", "RANDOM"]) +async def test_forward_request_not_onboarded_unallowed_methods( + hassio_noauth_client, aioclient_mock: AiohttpClientMocker, method: str +) -> None: + """Test fetching normal path.""" + resp = await hassio_noauth_client.post("/api/hassio/app/entrypoint.js") + + # Check we got right response + assert resp.status == HTTPStatus.METHOD_NOT_ALLOWED + + # Check we did not forward command + assert len(aioclient_mock.mock_calls) == 0 + + +@pytest.mark.parametrize( + ("bad_path", "expected_status"), + [ + # Caught by bullshit filter + ("app/%252E./entrypoint.js", HTTPStatus.BAD_REQUEST), + # The .. is processed, making it an unauthenticated path + ("app/../entrypoint.js", HTTPStatus.UNAUTHORIZED), + ("app/%2E%2E/entrypoint.js", HTTPStatus.UNAUTHORIZED), + # Unauthenticated path + ("supervisor/info", HTTPStatus.UNAUTHORIZED), + ("supervisor/logs", HTTPStatus.UNAUTHORIZED), + ("addons/bl_b392/logs", HTTPStatus.UNAUTHORIZED), + ], +) +async def test_forward_request_not_onboarded_unallowed_paths( + hassio_noauth_client, + aioclient_mock: AiohttpClientMocker, + bad_path: str, + expected_status: int, + mock_not_onboarded, +) -> None: + """Test fetching normal path.""" + resp = await hassio_noauth_client.get(f"/api/hassio/{bad_path}") + + # Check we got right response + assert resp.status == expected_status + # Check we didn't forward command + assert len(aioclient_mock.mock_calls) == 0 + + +@pytest.mark.parametrize( + ("path", "authenticated"), + [ + ("app/entrypoint.js", False), + ("addons/bl_b392/logo", False), + ("addons/bl_b392/icon", False), + ("backups/1234abcd/info", True), + ("supervisor/logs", True), + ("addons/bl_b392/logs", True), + ], +) +async def test_forward_request_admin_get( + hassio_client, + aioclient_mock: AiohttpClientMocker, + path: str, + authenticated: bool, +) -> None: + """Test fetching normal path.""" + aioclient_mock.get(f"http://127.0.0.1/{path}", text="response") + + resp = await hassio_client.get(f"/api/hassio/{path}") + + # Check we got right response + assert resp.status == HTTPStatus.OK + body = await resp.text() + assert body == "response" + + # Check we forwarded command + assert len(aioclient_mock.mock_calls) == 1 + expected_headers = { + "X-Hass-Source": "core.http", + } + if authenticated: + expected_headers["Authorization"] = "Bearer 123456" + + assert aioclient_mock.mock_calls[0][3] == expected_headers + + +@pytest.mark.parametrize( + "path", + [ + "backups/new/upload", + "backups/1234abcd/restore/full", + "backups/1234abcd/restore/partial", + ], +) +async def test_forward_request_admin_post( + hassio_client, + aioclient_mock: AiohttpClientMocker, + path: str, +) -> None: + """Test fetching normal path.""" + aioclient_mock.get(f"http://127.0.0.1/{path}", text="response") + + resp = await hassio_client.get(f"/api/hassio/{path}") + + # Check we got right response + assert resp.status == HTTPStatus.OK + body = await resp.text() + assert body == "response" + + # Check we forwarded command + assert len(aioclient_mock.mock_calls) == 1 + # We only expect a single header. + assert aioclient_mock.mock_calls[0][3] == { + "X-Hass-Source": "core.http", + "Authorization": "Bearer 123456", + } + + +@pytest.mark.parametrize("method", ["POST", "PUT", "DELETE", "RANDOM"]) +async def test_forward_request_admin_unallowed_methods( + hassio_client, aioclient_mock: AiohttpClientMocker, method: str +) -> None: + """Test fetching normal path.""" + resp = await hassio_client.post("/api/hassio/app/entrypoint.js") + + # Check we got right response + assert resp.status == HTTPStatus.METHOD_NOT_ALLOWED + + # Check we did not forward command + assert len(aioclient_mock.mock_calls) == 0 + + +@pytest.mark.parametrize( + ("bad_path", "expected_status"), + [ + # Caught by bullshit filter + ("app/%252E./entrypoint.js", HTTPStatus.BAD_REQUEST), + # The .. is processed, making it an unauthenticated path + ("app/../entrypoint.js", HTTPStatus.UNAUTHORIZED), + ("app/%2E%2E/entrypoint.js", HTTPStatus.UNAUTHORIZED), + # Unauthenticated path + ("supervisor/info", HTTPStatus.UNAUTHORIZED), + ], +) +async def test_forward_request_admin_unallowed_paths( + hassio_client, + aioclient_mock: AiohttpClientMocker, + bad_path: str, + expected_status: int, +) -> None: + """Test fetching normal path.""" + resp = await hassio_client.get(f"/api/hassio/{bad_path}") + + # Check we got right response + assert resp.status == expected_status + # Check we didn't forward command + assert len(aioclient_mock.mock_calls) == 0 async def test_bad_gateway_when_cannot_find_supervisor( hassio_client, aioclient_mock: AiohttpClientMocker ) -> None: """Test we get a bad gateway error if we can't find supervisor.""" - aioclient_mock.get("http://127.0.0.1/addons/test/info", exc=asyncio.TimeoutError) + aioclient_mock.get("http://127.0.0.1/app/entrypoint.js", exc=asyncio.TimeoutError) - resp = await hassio_client.get("/api/hassio/addons/test/info") + resp = await hassio_client.get("/api/hassio/app/entrypoint.js") assert resp.status == HTTPStatus.BAD_GATEWAY -async def test_forwarding_user_info( - hassio_client, hass_admin_user: MockUser, aioclient_mock: AiohttpClientMocker -) -> None: - """Test that we forward user info correctly.""" - aioclient_mock.get("http://127.0.0.1/hello") - - resp = await hassio_client.get("/api/hassio/hello") - - # Check we got right response - assert resp.status == HTTPStatus.OK - - assert len(aioclient_mock.mock_calls) == 1 - - req_headers = aioclient_mock.mock_calls[0][-1] - assert req_headers["X-Hass-User-ID"] == hass_admin_user.id - assert req_headers["X-Hass-Is-Admin"] == "1" - - async def test_backup_upload_headers( - hassio_client, aioclient_mock: AiohttpClientMocker, caplog: pytest.LogCaptureFixture + hassio_client, + aioclient_mock: AiohttpClientMocker, + caplog: pytest.LogCaptureFixture, + mock_not_onboarded, ) -> None: """Test that we forward the full header for backup upload.""" content_type = "multipart/form-data; boundary='--webkit'" - aioclient_mock.get("http://127.0.0.1/backups/new/upload") + aioclient_mock.post("http://127.0.0.1/backups/new/upload") - resp = await hassio_client.get( + resp = await hassio_client.post( "/api/hassio/backups/new/upload", headers={"Content-Type": content_type} ) @@ -168,19 +424,19 @@ async def test_backup_upload_headers( async def test_backup_download_headers( - hassio_client, aioclient_mock: AiohttpClientMocker + hassio_client, aioclient_mock: AiohttpClientMocker, mock_not_onboarded ) -> None: """Test that we forward the full header for backup download.""" content_disposition = "attachment; filename=test.tar" aioclient_mock.get( - "http://127.0.0.1/backups/slug/download", + "http://127.0.0.1/backups/1234abcd/download", headers={ "Content-Length": "50000000", "Content-Disposition": content_disposition, }, ) - resp = await hassio_client.get("/api/hassio/backups/slug/download") + resp = await hassio_client.get("/api/hassio/backups/1234abcd/download") # Check we got right response assert resp.status == HTTPStatus.OK @@ -190,21 +446,10 @@ async def test_backup_download_headers( assert resp.headers["Content-Disposition"] == content_disposition -def test_need_auth(hass: HomeAssistant) -> None: - """Test if the requested path needs authentication.""" - assert not _need_auth(hass, "addons/test/logo") - assert _need_auth(hass, "backups/new/upload") - assert _need_auth(hass, "supervisor/logs") - - hass.data["onboarding"] = False - assert not _need_auth(hass, "backups/new/upload") - assert not _need_auth(hass, "supervisor/logs") - - async def test_stream(hassio_client, aioclient_mock: AiohttpClientMocker) -> None: """Verify that the request is a stream.""" - aioclient_mock.get("http://127.0.0.1/test") - await hassio_client.get("/api/hassio/test", data="test") + aioclient_mock.get("http://127.0.0.1/app/entrypoint.js") + await hassio_client.get("/api/hassio/app/entrypoint.js", data="test") assert isinstance(aioclient_mock.mock_calls[-1][2], StreamReader) diff --git a/tests/components/hassio/test_ingress.py b/tests/components/hassio/test_ingress.py index 52ca535516a..67548a19c2c 100644 --- a/tests/components/hassio/test_ingress.py +++ b/tests/components/hassio/test_ingress.py @@ -21,7 +21,7 @@ from tests.test_util.aiohttp import AiohttpClientMocker ], ) async def test_ingress_request_get( - hassio_client, build_type, aioclient_mock: AiohttpClientMocker + hassio_noauth_client, build_type, aioclient_mock: AiohttpClientMocker ) -> None: """Test no auth needed for .""" aioclient_mock.get( @@ -29,7 +29,7 @@ async def test_ingress_request_get( text="test", ) - resp = await hassio_client.get( + resp = await hassio_noauth_client.get( f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}", headers={"X-Test-Header": "beer"}, ) @@ -41,7 +41,8 @@ async def test_ingress_request_get( # Check we forwarded command assert len(aioclient_mock.mock_calls) == 1 - assert aioclient_mock.mock_calls[-1][3][X_AUTH_TOKEN] == "123456" + assert X_AUTH_TOKEN not in aioclient_mock.mock_calls[-1][3] + assert aioclient_mock.mock_calls[-1][3]["X-Hass-Source"] == "core.ingress" assert ( aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"] == f"/api/hassio_ingress/{build_type[0]}" @@ -63,7 +64,7 @@ async def test_ingress_request_get( ], ) async def test_ingress_request_post( - hassio_client, build_type, aioclient_mock: AiohttpClientMocker + hassio_noauth_client, build_type, aioclient_mock: AiohttpClientMocker ) -> None: """Test no auth needed for .""" aioclient_mock.post( @@ -71,7 +72,7 @@ async def test_ingress_request_post( text="test", ) - resp = await hassio_client.post( + resp = await hassio_noauth_client.post( f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}", headers={"X-Test-Header": "beer"}, ) @@ -83,7 +84,8 @@ async def test_ingress_request_post( # Check we forwarded command assert len(aioclient_mock.mock_calls) == 1 - assert aioclient_mock.mock_calls[-1][3][X_AUTH_TOKEN] == "123456" + assert X_AUTH_TOKEN not in aioclient_mock.mock_calls[-1][3] + assert aioclient_mock.mock_calls[-1][3]["X-Hass-Source"] == "core.ingress" assert ( aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"] == f"/api/hassio_ingress/{build_type[0]}" @@ -105,7 +107,7 @@ async def test_ingress_request_post( ], ) async def test_ingress_request_put( - hassio_client, build_type, aioclient_mock: AiohttpClientMocker + hassio_noauth_client, build_type, aioclient_mock: AiohttpClientMocker ) -> None: """Test no auth needed for .""" aioclient_mock.put( @@ -113,7 +115,7 @@ async def test_ingress_request_put( text="test", ) - resp = await hassio_client.put( + resp = await hassio_noauth_client.put( f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}", headers={"X-Test-Header": "beer"}, ) @@ -125,7 +127,8 @@ async def test_ingress_request_put( # Check we forwarded command assert len(aioclient_mock.mock_calls) == 1 - assert aioclient_mock.mock_calls[-1][3][X_AUTH_TOKEN] == "123456" + assert X_AUTH_TOKEN not in aioclient_mock.mock_calls[-1][3] + assert aioclient_mock.mock_calls[-1][3]["X-Hass-Source"] == "core.ingress" assert ( aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"] == f"/api/hassio_ingress/{build_type[0]}" @@ -147,7 +150,7 @@ async def test_ingress_request_put( ], ) async def test_ingress_request_delete( - hassio_client, build_type, aioclient_mock: AiohttpClientMocker + hassio_noauth_client, build_type, aioclient_mock: AiohttpClientMocker ) -> None: """Test no auth needed for .""" aioclient_mock.delete( @@ -155,7 +158,7 @@ async def test_ingress_request_delete( text="test", ) - resp = await hassio_client.delete( + resp = await hassio_noauth_client.delete( f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}", headers={"X-Test-Header": "beer"}, ) @@ -167,7 +170,8 @@ async def test_ingress_request_delete( # Check we forwarded command assert len(aioclient_mock.mock_calls) == 1 - assert aioclient_mock.mock_calls[-1][3][X_AUTH_TOKEN] == "123456" + assert X_AUTH_TOKEN not in aioclient_mock.mock_calls[-1][3] + assert aioclient_mock.mock_calls[-1][3]["X-Hass-Source"] == "core.ingress" assert ( aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"] == f"/api/hassio_ingress/{build_type[0]}" @@ -189,7 +193,7 @@ async def test_ingress_request_delete( ], ) async def test_ingress_request_patch( - hassio_client, build_type, aioclient_mock: AiohttpClientMocker + hassio_noauth_client, build_type, aioclient_mock: AiohttpClientMocker ) -> None: """Test no auth needed for .""" aioclient_mock.patch( @@ -197,7 +201,7 @@ async def test_ingress_request_patch( text="test", ) - resp = await hassio_client.patch( + resp = await hassio_noauth_client.patch( f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}", headers={"X-Test-Header": "beer"}, ) @@ -209,7 +213,8 @@ async def test_ingress_request_patch( # Check we forwarded command assert len(aioclient_mock.mock_calls) == 1 - assert aioclient_mock.mock_calls[-1][3][X_AUTH_TOKEN] == "123456" + assert X_AUTH_TOKEN not in aioclient_mock.mock_calls[-1][3] + assert aioclient_mock.mock_calls[-1][3]["X-Hass-Source"] == "core.ingress" assert ( aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"] == f"/api/hassio_ingress/{build_type[0]}" @@ -231,7 +236,7 @@ async def test_ingress_request_patch( ], ) async def test_ingress_request_options( - hassio_client, build_type, aioclient_mock: AiohttpClientMocker + hassio_noauth_client, build_type, aioclient_mock: AiohttpClientMocker ) -> None: """Test no auth needed for .""" aioclient_mock.options( @@ -239,7 +244,7 @@ async def test_ingress_request_options( text="test", ) - resp = await hassio_client.options( + resp = await hassio_noauth_client.options( f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}", headers={"X-Test-Header": "beer"}, ) @@ -251,7 +256,8 @@ async def test_ingress_request_options( # Check we forwarded command assert len(aioclient_mock.mock_calls) == 1 - assert aioclient_mock.mock_calls[-1][3][X_AUTH_TOKEN] == "123456" + assert X_AUTH_TOKEN not in aioclient_mock.mock_calls[-1][3] + assert aioclient_mock.mock_calls[-1][3]["X-Hass-Source"] == "core.ingress" assert ( aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"] == f"/api/hassio_ingress/{build_type[0]}" @@ -273,20 +279,21 @@ async def test_ingress_request_options( ], ) async def test_ingress_websocket( - hassio_client, build_type, aioclient_mock: AiohttpClientMocker + hassio_noauth_client, build_type, aioclient_mock: AiohttpClientMocker ) -> None: """Test no auth needed for .""" aioclient_mock.get(f"http://127.0.0.1/ingress/{build_type[0]}/{build_type[1]}") # Ignore error because we can setup a full IO infrastructure - await hassio_client.ws_connect( + await hassio_noauth_client.ws_connect( f"/api/hassio_ingress/{build_type[0]}/{build_type[1]}", headers={"X-Test-Header": "beer"}, ) # Check we forwarded command assert len(aioclient_mock.mock_calls) == 1 - assert aioclient_mock.mock_calls[-1][3][X_AUTH_TOKEN] == "123456" + assert X_AUTH_TOKEN not in aioclient_mock.mock_calls[-1][3] + assert aioclient_mock.mock_calls[-1][3]["X-Hass-Source"] == "core.ingress" assert ( aioclient_mock.mock_calls[-1][3]["X-Ingress-Path"] == f"/api/hassio_ingress/{build_type[0]}" @@ -298,7 +305,9 @@ async def test_ingress_websocket( async def test_ingress_missing_peername( - hassio_client, aioclient_mock: AiohttpClientMocker, caplog: pytest.LogCaptureFixture + hassio_noauth_client, + aioclient_mock: AiohttpClientMocker, + caplog: pytest.LogCaptureFixture, ) -> None: """Test hadnling of missing peername.""" aioclient_mock.get( @@ -314,7 +323,7 @@ async def test_ingress_missing_peername( return_value=MagicMock(), ) as transport_mock: transport_mock.get_extra_info = get_extra_info - resp = await hassio_client.get( + resp = await hassio_noauth_client.get( "/api/hassio_ingress/lorem/ipsum", headers={"X-Test-Header": "beer"}, ) @@ -323,3 +332,19 @@ async def test_ingress_missing_peername( # Check we got right response assert resp.status == HTTPStatus.BAD_REQUEST + + +async def test_forwarding_paths_as_requested( + hassio_noauth_client, aioclient_mock +) -> None: + """Test incomnig URLs with double encoding go out as dobule encoded.""" + # This double encoded string should be forwarded double-encoded too. + aioclient_mock.get( + "http://127.0.0.1/ingress/mock-token/hello/%252e./world", + text="test", + ) + + resp = await hassio_noauth_client.get( + "/api/hassio_ingress/mock-token/hello/%252e./world", + ) + assert await resp.text() == "test" diff --git a/tests/components/hassio/test_websocket_api.py b/tests/components/hassio/test_websocket_api.py index 611ada61814..b2f9e06cb43 100644 --- a/tests/components/hassio/test_websocket_api.py +++ b/tests/components/hassio/test_websocket_api.py @@ -153,6 +153,11 @@ async def test_websocket_supervisor_api( msg = await websocket_client.receive_json() assert msg["result"]["version_latest"] == "1.0.0" + assert aioclient_mock.mock_calls[-1][3] == { + "X-Hass-Source": "core.websocket_api", + "Authorization": "Bearer 123456", + } + async def test_websocket_supervisor_api_error( hassio_env, diff --git a/tests/components/recorder/test_migrate.py b/tests/components/recorder/test_migrate.py index 44c3ffac99e..19c7e6c6955 100644 --- a/tests/components/recorder/test_migrate.py +++ b/tests/components/recorder/test_migrate.py @@ -69,7 +69,7 @@ async def test_schema_update_calls(recorder_db_url: str, hass: HomeAssistant) -> session_maker = instance.get_session update.assert_has_calls( [ - call(hass, engine, session_maker, version + 1, 0) + call(instance, hass, engine, session_maker, version + 1, 0) for version in range(0, db_schema.SCHEMA_VERSION) ] ) @@ -304,6 +304,8 @@ async def test_schema_migrate( migration_version = None real_migrate_schema = recorder.migration.migrate_schema real_apply_update = recorder.migration._apply_update + real_create_index = recorder.migration._create_index + create_calls = 0 def _create_engine_test(*args, **kwargs): """Test version of create_engine that initializes with old schema. @@ -355,6 +357,17 @@ async def test_schema_migrate( migration_stall.wait() real_apply_update(*args) + def _sometimes_failing_create_index(*args): + """Make the first index create raise a retryable error to ensure we retry.""" + if recorder_db_url.startswith("mysql://"): + nonlocal create_calls + if create_calls < 1: + create_calls += 1 + mysql_exception = OperationalError("statement", {}, []) + mysql_exception.orig = Exception(1205, "retryable") + raise mysql_exception + real_create_index(*args) + with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch( "homeassistant.components.recorder.core.create_engine", new=_create_engine_test, @@ -368,6 +381,11 @@ async def test_schema_migrate( ), patch( "homeassistant.components.recorder.migration._apply_update", wraps=_instrument_apply_update, + ) as apply_update_mock, patch( + "homeassistant.components.recorder.util.time.sleep" + ), patch( + "homeassistant.components.recorder.migration._create_index", + wraps=_sometimes_failing_create_index, ), patch( "homeassistant.components.recorder.Recorder._schedule_compile_missing_statistics", ), patch( @@ -394,12 +412,13 @@ async def test_schema_migrate( assert migration_version == db_schema.SCHEMA_VERSION assert setup_run.called assert recorder.util.async_migration_in_progress(hass) is not True + assert apply_update_mock.called def test_invalid_update(hass: HomeAssistant) -> None: """Test that an invalid new version raises an exception.""" with pytest.raises(ValueError): - migration._apply_update(hass, Mock(), Mock(), -1, 0) + migration._apply_update(Mock(), hass, Mock(), Mock(), -1, 0) @pytest.mark.parametrize( diff --git a/tests/components/recorder/test_purge.py b/tests/components/recorder/test_purge.py index c5ce8d272c7..07c935129e9 100644 --- a/tests/components/recorder/test_purge.py +++ b/tests/components/recorder/test_purge.py @@ -2,7 +2,7 @@ from datetime import datetime, timedelta import json import sqlite3 -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from sqlalchemy.exc import DatabaseError, OperationalError @@ -192,7 +192,7 @@ async def test_purge_old_states_encounters_temporary_mysql_error( await async_wait_recording_done(hass) mysql_exception = OperationalError("statement", {}, []) - mysql_exception.orig = MagicMock(args=(1205, "retryable")) + mysql_exception.orig = Exception(1205, "retryable") with patch( "homeassistant.components.recorder.util.time.sleep" diff --git a/tests/components/recorder/test_statistics.py b/tests/components/recorder/test_statistics.py index 8685985def8..e6ae291264f 100644 --- a/tests/components/recorder/test_statistics.py +++ b/tests/components/recorder/test_statistics.py @@ -8,7 +8,7 @@ import sys from unittest.mock import ANY, DEFAULT, MagicMock, patch, sentinel import pytest -from sqlalchemy import create_engine +from sqlalchemy import create_engine, select from sqlalchemy.exc import OperationalError from sqlalchemy.orm import Session @@ -22,6 +22,10 @@ from homeassistant.components.recorder.models import ( ) from homeassistant.components.recorder.statistics import ( STATISTIC_UNIT_TO_UNIT_CONVERTER, + _generate_get_metadata_stmt, + _generate_max_mean_min_statistic_in_sub_period_stmt, + _generate_statistics_at_time_stmt, + _generate_statistics_during_period_stmt, _statistics_during_period_with_session, _update_or_add_metadata, async_add_external_statistics, @@ -1231,8 +1235,9 @@ def test_delete_duplicates_no_duplicates( """Test removal of duplicated statistics.""" hass = hass_recorder() wait_recording_done(hass) + instance = recorder.get_instance(hass) with session_scope(hass=hass) as session: - delete_statistics_duplicates(hass, session) + delete_statistics_duplicates(instance, hass, session) assert "duplicated statistics rows" not in caplog.text assert "Found non identical" not in caplog.text assert "Found duplicated" not in caplog.text @@ -1798,3 +1803,100 @@ def record_states(hass): states[sns4].append(set_state(sns4, "20", attributes=sns4_attr)) return zero, four, states + + +def test_cache_key_for_generate_statistics_during_period_stmt(): + """Test cache key for _generate_statistics_during_period_stmt.""" + columns = select(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts) + stmt = _generate_statistics_during_period_stmt( + columns, dt_util.utcnow(), dt_util.utcnow(), [0], StatisticsShortTerm, {} + ) + cache_key_1 = stmt._generate_cache_key() + stmt2 = _generate_statistics_during_period_stmt( + columns, dt_util.utcnow(), dt_util.utcnow(), [0], StatisticsShortTerm, {} + ) + cache_key_2 = stmt2._generate_cache_key() + assert cache_key_1 == cache_key_2 + columns2 = select( + StatisticsShortTerm.metadata_id, + StatisticsShortTerm.start_ts, + StatisticsShortTerm.sum, + StatisticsShortTerm.mean, + ) + stmt3 = _generate_statistics_during_period_stmt( + columns2, + dt_util.utcnow(), + dt_util.utcnow(), + [0], + StatisticsShortTerm, + {"max", "mean"}, + ) + cache_key_3 = stmt3._generate_cache_key() + assert cache_key_1 != cache_key_3 + + +def test_cache_key_for_generate_get_metadata_stmt(): + """Test cache key for _generate_get_metadata_stmt.""" + stmt_mean = _generate_get_metadata_stmt([0], "mean") + stmt_mean2 = _generate_get_metadata_stmt([1], "mean") + stmt_sum = _generate_get_metadata_stmt([0], "sum") + stmt_none = _generate_get_metadata_stmt() + assert stmt_mean._generate_cache_key() == stmt_mean2._generate_cache_key() + assert stmt_mean._generate_cache_key() != stmt_sum._generate_cache_key() + assert stmt_mean._generate_cache_key() != stmt_none._generate_cache_key() + + +def test_cache_key_for_generate_max_mean_min_statistic_in_sub_period_stmt(): + """Test cache key for _generate_max_mean_min_statistic_in_sub_period_stmt.""" + columns = select(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts) + stmt = _generate_max_mean_min_statistic_in_sub_period_stmt( + columns, + dt_util.utcnow(), + dt_util.utcnow(), + StatisticsShortTerm, + [0], + ) + cache_key_1 = stmt._generate_cache_key() + stmt2 = _generate_max_mean_min_statistic_in_sub_period_stmt( + columns, + dt_util.utcnow(), + dt_util.utcnow(), + StatisticsShortTerm, + [0], + ) + cache_key_2 = stmt2._generate_cache_key() + assert cache_key_1 == cache_key_2 + columns2 = select( + StatisticsShortTerm.metadata_id, + StatisticsShortTerm.start_ts, + StatisticsShortTerm.sum, + StatisticsShortTerm.mean, + ) + stmt3 = _generate_max_mean_min_statistic_in_sub_period_stmt( + columns2, + dt_util.utcnow(), + dt_util.utcnow(), + StatisticsShortTerm, + [0], + ) + cache_key_3 = stmt3._generate_cache_key() + assert cache_key_1 != cache_key_3 + + +def test_cache_key_for_generate_statistics_at_time_stmt(): + """Test cache key for _generate_statistics_at_time_stmt.""" + columns = select(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start_ts) + stmt = _generate_statistics_at_time_stmt(columns, StatisticsShortTerm, {0}, 0.0) + cache_key_1 = stmt._generate_cache_key() + stmt2 = _generate_statistics_at_time_stmt(columns, StatisticsShortTerm, {0}, 0.0) + cache_key_2 = stmt2._generate_cache_key() + assert cache_key_1 == cache_key_2 + columns2 = select( + StatisticsShortTerm.metadata_id, + StatisticsShortTerm.start_ts, + StatisticsShortTerm.sum, + StatisticsShortTerm.mean, + ) + stmt3 = _generate_statistics_at_time_stmt(columns2, StatisticsShortTerm, {0}, 0.0) + cache_key_3 = stmt3._generate_cache_key() + assert cache_key_1 != cache_key_3 diff --git a/tests/components/thread/test_diagnostics.py b/tests/components/thread/test_diagnostics.py index 1006fa374c3..a551315205b 100644 --- a/tests/components/thread/test_diagnostics.py +++ b/tests/components/thread/test_diagnostics.py @@ -133,9 +133,7 @@ class MockNeighbour: @pytest.fixture def ndb() -> Mock: """Prevent NDB poking the OS route tables.""" - with patch( - "homeassistant.components.thread.diagnostics.NDB" - ) as ndb, ndb() as instance: + with patch("pyroute2.NDB") as ndb, ndb() as instance: instance.neighbours = [] instance.routes = [] yield instance diff --git a/tests/components/tplink_omada/test_config_flow.py b/tests/components/tplink_omada/test_config_flow.py index fd32b357b7c..cf3fddf5943 100644 --- a/tests/components/tplink_omada/test_config_flow.py +++ b/tests/components/tplink_omada/test_config_flow.py @@ -22,14 +22,14 @@ from homeassistant.data_entry_flow import FlowResultType from tests.common import MockConfigEntry MOCK_USER_DATA = { - "host": "1.1.1.1", + "host": "https://fake.omada.host", "verify_ssl": True, "username": "test-username", "password": "test-password", } MOCK_ENTRY_DATA = { - "host": "1.1.1.1", + "host": "https://fake.omada.host", "verify_ssl": True, "site": "SiteId", "username": "test-username", @@ -111,7 +111,7 @@ async def test_form_multiple_sites(hass: HomeAssistant) -> None: assert result3["type"] == FlowResultType.CREATE_ENTRY assert result3["title"] == "OC200 (Site 2)" assert result3["data"] == { - "host": "1.1.1.1", + "host": "https://fake.omada.host", "verify_ssl": True, "site": "second", "username": "test-username", @@ -272,7 +272,7 @@ async def test_async_step_reauth_success(hass: HomeAssistant) -> None: mocked_validate.assert_called_once_with( hass, { - "host": "1.1.1.1", + "host": "https://fake.omada.host", "verify_ssl": True, "site": "SiteId", "username": "new_uname", @@ -353,6 +353,64 @@ async def test_create_omada_client_parses_args(hass: HomeAssistant) -> None: assert result is not None mock_client.assert_called_once_with( - "1.1.1.1", "test-username", "test-password", "ws" + "https://fake.omada.host", "test-username", "test-password", "ws" ) mock_clientsession.assert_called_once_with(hass, verify_ssl=True) + + +async def test_create_omada_client_adds_missing_scheme(hass: HomeAssistant) -> None: + """Test config arguments are passed to Omada client.""" + + with patch( + "homeassistant.components.tplink_omada.config_flow.OmadaClient", autospec=True + ) as mock_client, patch( + "homeassistant.components.tplink_omada.config_flow.async_get_clientsession", + return_value="ws", + ) as mock_clientsession: + result = await create_omada_client( + hass, + { + "host": "fake.omada.host", + "verify_ssl": True, + "username": "test-username", + "password": "test-password", + }, + ) + + assert result is not None + mock_client.assert_called_once_with( + "https://fake.omada.host", "test-username", "test-password", "ws" + ) + mock_clientsession.assert_called_once_with(hass, verify_ssl=True) + + +async def test_create_omada_client_with_ip_creates_clientsession( + hass: HomeAssistant, +) -> None: + """Test config arguments are passed to Omada client.""" + + with patch( + "homeassistant.components.tplink_omada.config_flow.OmadaClient", autospec=True + ) as mock_client, patch( + "homeassistant.components.tplink_omada.config_flow.CookieJar", autospec=True + ) as mock_jar, patch( + "homeassistant.components.tplink_omada.config_flow.async_create_clientsession", + return_value="ws", + ) as mock_create_clientsession: + result = await create_omada_client( + hass, + { + "host": "10.10.10.10", + "verify_ssl": True, # Verify is meaningless for IP + "username": "test-username", + "password": "test-password", + }, + ) + + assert result is not None + mock_client.assert_called_once_with( + "https://10.10.10.10", "test-username", "test-password", "ws" + ) + mock_create_clientsession.assert_called_once_with( + hass, cookie_jar=mock_jar.return_value + ) diff --git a/tests/components/zha/test_init.py b/tests/components/zha/test_init.py index e580242a677..a92631f6da3 100644 --- a/tests/components/zha/test_init.py +++ b/tests/components/zha/test_init.py @@ -1,9 +1,10 @@ """Tests for ZHA integration init.""" -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest from zigpy.config import CONF_DEVICE, CONF_DEVICE_PATH +from homeassistant.components.zha import async_setup_entry from homeassistant.components.zha.core.const import ( CONF_BAUDRATE, CONF_RADIO_TYPE, @@ -108,3 +109,41 @@ async def test_config_depreciation(hass: HomeAssistant, zha_config) -> None: ) as setup_mock: assert await async_setup_component(hass, DOMAIN, {DOMAIN: zha_config}) assert setup_mock.call_count == 1 + + +@pytest.mark.parametrize( + ("path", "cleaned_path"), + [ + ("/dev/path1", "/dev/path1"), + ("/dev/path1 ", "/dev/path1 "), + ("socket://dev/path1 ", "socket://dev/path1"), + ], +) +@patch("homeassistant.components.zha.setup_quirks", Mock(return_value=True)) +@patch("homeassistant.components.zha.api.async_load_api", Mock(return_value=True)) +async def test_setup_with_v3_spaces_in_uri( + hass: HomeAssistant, path: str, cleaned_path: str +) -> None: + """Test migration of config entry from v3 with spaces after `socket://` URI.""" + config_entry_v3 = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_RADIO_TYPE: DATA_RADIO_TYPE, + CONF_DEVICE: {CONF_DEVICE_PATH: path, CONF_BAUDRATE: 115200}, + }, + version=3, + ) + config_entry_v3.add_to_hass(hass) + + with patch( + "homeassistant.components.zha.ZHAGateway", return_value=AsyncMock() + ) as mock_gateway: + mock_gateway.return_value.coordinator_ieee = "mock_ieee" + mock_gateway.return_value.radio_description = "mock_radio" + + assert await async_setup_entry(hass, config_entry_v3) + hass.data[DOMAIN]["zha_gateway"] = mock_gateway.return_value + + assert config_entry_v3.data[CONF_RADIO_TYPE] == DATA_RADIO_TYPE + assert config_entry_v3.data[CONF_DEVICE][CONF_DEVICE_PATH] == cleaned_path + assert config_entry_v3.version == 3