diff --git a/homeassistant/components/diagnostics/__init__.py b/homeassistant/components/diagnostics/__init__.py index 966845442e4..5c719d72b09 100644 --- a/homeassistant/components/diagnostics/__init__.py +++ b/homeassistant/components/diagnostics/__init__.py @@ -1,6 +1,7 @@ """The Diagnostics integration.""" from __future__ import annotations +from http import HTTPStatus import json import logging from typing import Protocol @@ -12,6 +13,7 @@ from homeassistant.components import http, websocket_api from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import integration_platform +from homeassistant.helpers.device_registry import DeviceEntry, async_get from homeassistant.helpers.json import ExtendedJSONEncoder from homeassistant.helpers.typing import ConfigType from homeassistant.util.json import ( @@ -19,7 +21,7 @@ from homeassistant.util.json import ( format_unserializable_data, ) -from .const import DOMAIN, REDACTED +from .const import DOMAIN, REDACTED, DiagnosticsSubType, DiagnosticsType __all__ = ["REDACTED"] @@ -35,6 +37,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: ) websocket_api.async_register_command(hass, handle_info) + websocket_api.async_register_command(hass, handle_get) hass.http.register_view(DownloadDiagnosticsView) return True @@ -48,13 +51,23 @@ class DiagnosticsProtocol(Protocol): ) -> dict: """Return diagnostics for a config entry.""" + async def async_get_device_diagnostics( + self, hass: HomeAssistant, config_entry: ConfigEntry, device: DeviceEntry + ) -> dict: + """Return diagnostics for a device.""" + async def _register_diagnostics_platform( hass: HomeAssistant, integration_domain: str, platform: DiagnosticsProtocol ): """Register a diagnostics platform.""" hass.data[DOMAIN][integration_domain] = { - "config_entry": getattr(platform, "async_get_config_entry_diagnostics", None) + DiagnosticsType.CONFIG_ENTRY.value: getattr( + platform, "async_get_config_entry_diagnostics", None + ), + DiagnosticsSubType.DEVICE.value: getattr( + platform, "async_get_device_diagnostics", None + ), } @@ -77,50 +90,123 @@ def handle_info( ) +@websocket_api.require_admin +@websocket_api.websocket_command( + { + vol.Required("type"): "diagnostics/get", + vol.Required("domain"): str, + } +) +@callback +def handle_get( + hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict +): + """List all possible diagnostic handlers.""" + domain = msg["domain"] + info = hass.data[DOMAIN].get(domain) + + if info is None: + connection.send_error( + msg["id"], websocket_api.ERR_NOT_FOUND, "Domain not supported" + ) + return + + connection.send_result( + msg["id"], + { + "domain": domain, + "handlers": {key: val is not None for key, val in info.items()}, + }, + ) + + +def _get_json_file_response( + data: dict | list, + filename: str, + d_type: DiagnosticsType, + d_id: str, + sub_type: DiagnosticsSubType | None = None, + sub_id: str | None = None, +) -> web.Response: + """Return JSON file from dictionary.""" + try: + json_data = json.dumps(data, indent=2, cls=ExtendedJSONEncoder) + except TypeError: + _LOGGER.error( + "Failed to serialize to JSON: %s/%s%s. Bad data at %s", + d_type.value, + d_id, + f"/{sub_type.value}/{sub_id}" if sub_type is not None else "", + format_unserializable_data(find_paths_unserializable_data(data)), + ) + return web.Response(status=HTTPStatus.INTERNAL_SERVER_ERROR) + + return web.Response( + body=json_data, + content_type="application/json", + headers={"Content-Disposition": f'attachment; filename="{filename}.json"'}, + ) + + class DownloadDiagnosticsView(http.HomeAssistantView): """Download diagnostics view.""" url = "/api/diagnostics/{d_type}/{d_id}" + extra_urls = ["/api/diagnostics/{d_type}/{d_id}/{sub_type}/{sub_id}"] name = "api:diagnostics" async def get( # pylint: disable=no-self-use - self, request: web.Request, d_type: str, d_id: str + self, + request: web.Request, + d_type: str, + d_id: str, + sub_type: str | None = None, + sub_id: str | None = None, ) -> web.Response: """Download diagnostics.""" - if d_type != "config_entry": - return web.Response(status=404) + # t_type handling + try: + d_type = DiagnosticsType(d_type) + except ValueError: + return web.Response(status=HTTPStatus.BAD_REQUEST) hass = request.app["hass"] config_entry = hass.config_entries.async_get_entry(d_id) if config_entry is None: - return web.Response(status=404) + return web.Response(status=HTTPStatus.NOT_FOUND) info = hass.data[DOMAIN].get(config_entry.domain) if info is None: - return web.Response(status=404) + return web.Response(status=HTTPStatus.NOT_FOUND) - if info["config_entry"] is None: - return web.Response(status=404) + filename = f"{config_entry.domain}-{config_entry.entry_id}" - data = await info["config_entry"](hass, config_entry) + if sub_type is None: + if info[d_type.value] is None: + return web.Response(status=HTTPStatus.NOT_FOUND) + data = await info[d_type.value](hass, config_entry) + filename = f"{d_type}-{filename}" + return _get_json_file_response(data, filename, d_type.value, d_id) + # sub_type handling try: - json_data = json.dumps(data, indent=4, cls=ExtendedJSONEncoder) - except TypeError: - _LOGGER.error( - "Failed to serialize to JSON: %s/%s. Bad data at %s", - d_type, - d_id, - format_unserializable_data(find_paths_unserializable_data(data)), - ) - return web.Response(status=500) + sub_type = DiagnosticsSubType(sub_type) + except ValueError: + return web.Response(status=HTTPStatus.BAD_REQUEST) - return web.Response( - body=json_data, - content_type="application/json", - headers={ - "Content-Disposition": f'attachment; filename="{config_entry.domain}-{config_entry.entry_id}.json"' - }, - ) + dev_reg = async_get(hass) + assert sub_id + device = dev_reg.async_get(sub_id) + + if device is None: + return web.Response(status=HTTPStatus.NOT_FOUND) + + filename += f"-{device.name}-{device.id}" + + if info[sub_type.value] is None: + return web.Response(status=HTTPStatus.NOT_FOUND) + + data = await info[sub_type.value](hass, config_entry, sub_id) + return _get_json_file_response(data, filename, d_type, d_id, sub_type, sub_id) diff --git a/homeassistant/components/diagnostics/const.py b/homeassistant/components/diagnostics/const.py index c5e0f315497..0d07abde2bd 100644 --- a/homeassistant/components/diagnostics/const.py +++ b/homeassistant/components/diagnostics/const.py @@ -1,5 +1,18 @@ """Constants for the Diagnostics integration.""" +from homeassistant.backports.enum import StrEnum DOMAIN = "diagnostics" REDACTED = "**REDACTED**" + + +class DiagnosticsType(StrEnum): + """Diagnostics types.""" + + CONFIG_ENTRY = "config_entry" + + +class DiagnosticsSubType(StrEnum): + """Diagnostics sub types.""" + + DEVICE = "device" diff --git a/tests/components/diagnostics/__init__.py b/tests/components/diagnostics/__init__.py index 5cf56913f60..63961c8013a 100644 --- a/tests/components/diagnostics/__init__.py +++ b/tests/components/diagnostics/__init__.py @@ -3,17 +3,9 @@ from http import HTTPStatus from homeassistant.setup import async_setup_component -from tests.common import MockConfigEntry - -async def get_diagnostics_for_config_entry(hass, hass_client, domain_or_config_entry): +async def get_diagnostics_for_config_entry(hass, hass_client, config_entry): """Return the diagnostics config entry for the specified domain.""" - if isinstance(domain_or_config_entry, str): - config_entry = MockConfigEntry(domain=domain_or_config_entry) - config_entry.add_to_hass(hass) - else: - config_entry = domain_or_config_entry - assert await async_setup_component(hass, "diagnostics", {}) client = await hass_client() @@ -22,3 +14,13 @@ async def get_diagnostics_for_config_entry(hass, hass_client, domain_or_config_e ) assert response.status == HTTPStatus.OK return await response.json() + + +async def get_diagnostics_for_device(hass, hass_client, config_entry, device): + """Return the diagnostics for the specified device.""" + client = await hass_client() + response = await client.get( + f"/api/diagnostics/config_entry/{config_entry.entry_id}/device/{device.id}" + ) + assert response.status == HTTPStatus.OK + return await response.json() diff --git a/tests/components/diagnostics/test_init.py b/tests/components/diagnostics/test_init.py index 60bd28a862a..77b8d5fdebe 100644 --- a/tests/components/diagnostics/test_init.py +++ b/tests/components/diagnostics/test_init.py @@ -1,14 +1,16 @@ """Test the Diagnostics integration.""" +from http import HTTPStatus from unittest.mock import AsyncMock, Mock import pytest from homeassistant.components.websocket_api.const import TYPE_RESULT +from homeassistant.helpers.device_registry import async_get from homeassistant.setup import async_setup_component -from . import get_diagnostics_for_config_entry +from . import get_diagnostics_for_config_entry, get_diagnostics_for_device -from tests.common import mock_platform +from tests.common import MockConfigEntry, mock_platform @pytest.fixture(autouse=True) @@ -21,16 +23,26 @@ async def mock_diagnostics_integration(hass): Mock( async_get_config_entry_diagnostics=AsyncMock( return_value={ - "hello": "info", + "config_entry": "info", + } + ), + async_get_device_diagnostics=AsyncMock( + return_value={ + "device": "info", } ), ), ) + mock_platform( + hass, + "integration_without_diagnostics.diagnostics", + Mock(), + ) assert await async_setup_component(hass, "diagnostics", {}) -async def test_websocket_info(hass, hass_ws_client): - """Test camera_thumbnail websocket command.""" +async def test_websocket(hass, hass_ws_client): + """Test websocket command.""" client = await hass_ws_client(hass) await client.send_json({"id": 5, "type": "diagnostics/list"}) @@ -40,12 +52,78 @@ async def test_websocket_info(hass, hass_ws_client): assert msg["type"] == TYPE_RESULT assert msg["success"] assert msg["result"] == [ - {"domain": "fake_integration", "handlers": {"config_entry": True}} + { + "domain": "fake_integration", + "handlers": {"config_entry": True, "device": True}, + } ] + await client.send_json( + {"id": 6, "type": "diagnostics/get", "domain": "fake_integration"} + ) + + msg = await client.receive_json() + + assert msg["id"] == 6 + assert msg["type"] == TYPE_RESULT + assert msg["success"] + assert msg["result"] == { + "domain": "fake_integration", + "handlers": {"config_entry": True, "device": True}, + } + async def test_download_diagnostics(hass, hass_client): - """Test record service.""" - assert await get_diagnostics_for_config_entry( - hass, hass_client, "fake_integration" - ) == {"hello": "info"} + """Test download diagnostics.""" + config_entry = MockConfigEntry(domain="fake_integration") + config_entry.add_to_hass(hass) + + assert await get_diagnostics_for_config_entry(hass, hass_client, config_entry) == { + "config_entry": "info" + } + + dev_reg = async_get(hass) + device = dev_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, identifiers={("test", "test")} + ) + + assert await get_diagnostics_for_device( + hass, hass_client, config_entry, device + ) == {"device": "info"} + + +async def test_failure_scenarios(hass, hass_client): + """Test failure scenarios.""" + client = await hass_client() + + # test wrong d_type + response = await client.get("/api/diagnostics/wrong_type/fake_id") + assert response.status == HTTPStatus.BAD_REQUEST + + # test wrong d_id + response = await client.get("/api/diagnostics/config_entry/fake_id") + assert response.status == HTTPStatus.NOT_FOUND + + config_entry = MockConfigEntry(domain="integration_without_diagnostics") + config_entry.add_to_hass(hass) + + # test valid d_type and d_id but no config entry diagnostics + response = await client.get( + f"/api/diagnostics/config_entry/{config_entry.entry_id}" + ) + assert response.status == HTTPStatus.NOT_FOUND + + config_entry = MockConfigEntry(domain="fake_integration") + config_entry.add_to_hass(hass) + + # test invalid sub_type + response = await client.get( + f"/api/diagnostics/config_entry/{config_entry.entry_id}/wrong_type/id" + ) + assert response.status == HTTPStatus.BAD_REQUEST + + # test invalid sub_id + response = await client.get( + f"/api/diagnostics/config_entry/{config_entry.entry_id}/device/fake_id" + ) + assert response.status == HTTPStatus.NOT_FOUND