From 1cfbdd6a5d4efe97ad163181bfb499c593f4da30 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sat, 25 Nov 2023 05:49:50 -0500 Subject: [PATCH] Allow overriding blueprints on import (#103340) Co-authored-by: Franck Nijhof --- .../components/automation/helpers.py | 12 ++- homeassistant/components/blueprint/models.py | 36 +++++--- .../components/blueprint/websocket_api.py | 31 ++++++- homeassistant/components/script/helpers.py | 10 ++- tests/components/blueprint/test_models.py | 10 +-- .../blueprint/test_websocket_api.py | 83 +++++++++++++++++++ 6 files changed, 159 insertions(+), 23 deletions(-) diff --git a/homeassistant/components/automation/helpers.py b/homeassistant/components/automation/helpers.py index 7c2efc17bf4..a7c329a544a 100644 --- a/homeassistant/components/automation/helpers.py +++ b/homeassistant/components/automation/helpers.py @@ -1,5 +1,6 @@ """Helpers for automation integration.""" from homeassistant.components import blueprint +from homeassistant.const import SERVICE_RELOAD from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.singleton import singleton @@ -15,8 +16,17 @@ def _blueprint_in_use(hass: HomeAssistant, blueprint_path: str) -> bool: return len(automations_with_blueprint(hass, blueprint_path)) > 0 +async def _reload_blueprint_automations( + hass: HomeAssistant, blueprint_path: str +) -> None: + """Reload all automations that rely on a specific blueprint.""" + await hass.services.async_call(DOMAIN, SERVICE_RELOAD) + + @singleton(DATA_BLUEPRINTS) @callback def async_get_blueprints(hass: HomeAssistant) -> blueprint.DomainBlueprints: """Get automation blueprints.""" - return blueprint.DomainBlueprints(hass, DOMAIN, LOGGER, _blueprint_in_use) + return blueprint.DomainBlueprints( + hass, DOMAIN, LOGGER, _blueprint_in_use, _reload_blueprint_automations + ) diff --git a/homeassistant/components/blueprint/models.py b/homeassistant/components/blueprint/models.py index 6f48080a451..ddf57aa6eee 100644 --- a/homeassistant/components/blueprint/models.py +++ b/homeassistant/components/blueprint/models.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio -from collections.abc import Callable +from collections.abc import Awaitable, Callable import logging import pathlib import shutil @@ -189,12 +189,14 @@ class DomainBlueprints: domain: str, logger: logging.Logger, blueprint_in_use: Callable[[HomeAssistant, str], bool], + reload_blueprint_consumers: Callable[[HomeAssistant, str], Awaitable[None]], ) -> None: """Initialize a domain blueprints instance.""" self.hass = hass self.domain = domain self.logger = logger self._blueprint_in_use = blueprint_in_use + self._reload_blueprint_consumers = reload_blueprint_consumers self._blueprints: dict[str, Blueprint | None] = {} self._load_lock = asyncio.Lock() @@ -283,7 +285,7 @@ class DomainBlueprints: blueprint = await self.hass.async_add_executor_job( self._load_blueprint, blueprint_path ) - except Exception: + except FailedToLoad: self._blueprints[blueprint_path] = None raise @@ -315,31 +317,41 @@ class DomainBlueprints: await self.hass.async_add_executor_job(path.unlink) self._blueprints[blueprint_path] = None - def _create_file(self, blueprint: Blueprint, blueprint_path: str) -> None: - """Create blueprint file.""" + def _create_file( + self, blueprint: Blueprint, blueprint_path: str, allow_override: bool + ) -> bool: + """Create blueprint file. + + Returns true if the action overrides an existing blueprint. + """ path = pathlib.Path( self.hass.config.path(BLUEPRINT_FOLDER, self.domain, blueprint_path) ) - if path.exists(): + exists = path.exists() + + if not allow_override and exists: raise FileAlreadyExists(self.domain, blueprint_path) path.parent.mkdir(parents=True, exist_ok=True) path.write_text(blueprint.yaml(), encoding="utf-8") + return exists async def async_add_blueprint( - self, blueprint: Blueprint, blueprint_path: str - ) -> None: + self, blueprint: Blueprint, blueprint_path: str, allow_override=False + ) -> bool: """Add a blueprint.""" - if not blueprint_path.endswith(".yaml"): - blueprint_path = f"{blueprint_path}.yaml" - - await self.hass.async_add_executor_job( - self._create_file, blueprint, blueprint_path + overrides_existing = await self.hass.async_add_executor_job( + self._create_file, blueprint, blueprint_path, allow_override ) self._blueprints[blueprint_path] = blueprint + if overrides_existing: + await self._reload_blueprint_consumers(self.hass, blueprint_path) + + return overrides_existing + async def async_populate(self) -> None: """Create folder if it doesn't exist and populate with examples.""" if self._blueprints: diff --git a/homeassistant/components/blueprint/websocket_api.py b/homeassistant/components/blueprint/websocket_api.py index 1732320c1e9..3c7cc3769c8 100644 --- a/homeassistant/components/blueprint/websocket_api.py +++ b/homeassistant/components/blueprint/websocket_api.py @@ -14,7 +14,7 @@ from homeassistant.util import yaml from . import importer, models from .const import DOMAIN -from .errors import FileAlreadyExists +from .errors import FailedToLoad, FileAlreadyExists @callback @@ -81,6 +81,23 @@ async def ws_import_blueprint( ) return + # Check it exists and if so, which automations are using it + domain = imported_blueprint.blueprint.metadata["domain"] + domain_blueprints: models.DomainBlueprints | None = hass.data.get(DOMAIN, {}).get( + domain + ) + if domain_blueprints is None: + connection.send_error( + msg["id"], websocket_api.ERR_INVALID_FORMAT, "Unsupported domain" + ) + return + + suggested_path = f"{imported_blueprint.suggested_filename}.yaml" + try: + exists = bool(await domain_blueprints.async_get_blueprint(suggested_path)) + except FailedToLoad: + exists = False + connection.send_result( msg["id"], { @@ -90,6 +107,7 @@ async def ws_import_blueprint( "metadata": imported_blueprint.blueprint.metadata, }, "validation_errors": imported_blueprint.blueprint.validate(), + "exists": exists, }, ) @@ -101,6 +119,7 @@ async def ws_import_blueprint( vol.Required("path"): cv.path, vol.Required("yaml"): cv.string, vol.Optional("source_url"): cv.url, + vol.Optional("allow_override"): bool, } ) @websocket_api.async_response @@ -130,8 +149,13 @@ async def ws_save_blueprint( connection.send_error(msg["id"], websocket_api.ERR_INVALID_FORMAT, str(err)) return + if not path.endswith(".yaml"): + path = f"{path}.yaml" + try: - await domain_blueprints[domain].async_add_blueprint(blueprint, path) + overrides_existing = await domain_blueprints[domain].async_add_blueprint( + blueprint, path, allow_override=msg.get("allow_override", False) + ) except FileAlreadyExists: connection.send_error(msg["id"], "already_exists", "File already exists") return @@ -141,6 +165,9 @@ async def ws_save_blueprint( connection.send_result( msg["id"], + { + "overrides_existing": overrides_existing, + }, ) diff --git a/homeassistant/components/script/helpers.py b/homeassistant/components/script/helpers.py index 9f0d4399d3d..4504869e270 100644 --- a/homeassistant/components/script/helpers.py +++ b/homeassistant/components/script/helpers.py @@ -1,5 +1,6 @@ """Helpers for automation integration.""" from homeassistant.components.blueprint import DomainBlueprints +from homeassistant.const import SERVICE_RELOAD from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.singleton import singleton @@ -15,8 +16,15 @@ def _blueprint_in_use(hass: HomeAssistant, blueprint_path: str) -> bool: return len(scripts_with_blueprint(hass, blueprint_path)) > 0 +async def _reload_blueprint_scripts(hass: HomeAssistant, blueprint_path: str) -> None: + """Reload all script that rely on a specific blueprint.""" + await hass.services.async_call(DOMAIN, SERVICE_RELOAD) + + @singleton(DATA_BLUEPRINTS) @callback def async_get_blueprints(hass: HomeAssistant) -> DomainBlueprints: """Get script blueprints.""" - return DomainBlueprints(hass, DOMAIN, LOGGER, _blueprint_in_use) + return DomainBlueprints( + hass, DOMAIN, LOGGER, _blueprint_in_use, _reload_blueprint_scripts + ) diff --git a/tests/components/blueprint/test_models.py b/tests/components/blueprint/test_models.py index b2d3ce517d8..c11a467de9b 100644 --- a/tests/components/blueprint/test_models.py +++ b/tests/components/blueprint/test_models.py @@ -1,6 +1,6 @@ """Test blueprint models.""" import logging -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import pytest @@ -49,7 +49,7 @@ def blueprint_2(): def domain_bps(hass): """Domain blueprints fixture.""" return models.DomainBlueprints( - hass, "automation", logging.getLogger(__name__), None + hass, "automation", logging.getLogger(__name__), None, AsyncMock() ) @@ -257,13 +257,9 @@ async def test_domain_blueprints_inputs_from_config(domain_bps, blueprint_1) -> async def test_domain_blueprints_add_blueprint(domain_bps, blueprint_1) -> None: """Test DomainBlueprints.async_add_blueprint.""" with patch.object(domain_bps, "_create_file") as create_file_mock: - # Should add extension when not present. - await domain_bps.async_add_blueprint(blueprint_1, "something") + await domain_bps.async_add_blueprint(blueprint_1, "something.yaml") assert create_file_mock.call_args[0][1] == "something.yaml" - await domain_bps.async_add_blueprint(blueprint_1, "something2.yaml") - assert create_file_mock.call_args[0][1] == "something2.yaml" - # Should be in cache. with patch.object(domain_bps, "_load_blueprint") as mock_load: assert await domain_bps.async_get_blueprint("something.yaml") == blueprint_1 diff --git a/tests/components/blueprint/test_websocket_api.py b/tests/components/blueprint/test_websocket_api.py index f831445b60c..213dff89597 100644 --- a/tests/components/blueprint/test_websocket_api.py +++ b/tests/components/blueprint/test_websocket_api.py @@ -3,6 +3,7 @@ from pathlib import Path from unittest.mock import Mock, patch import pytest +import yaml from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component @@ -129,6 +130,52 @@ async def test_import_blueprint( }, }, "validation_errors": None, + "exists": False, + } + + +async def test_import_blueprint_update( + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + hass_ws_client: WebSocketGenerator, + setup_bp, +) -> None: + """Test importing blueprints.""" + raw_data = Path( + hass.config.path("blueprints/automation/in_folder/in_folder_blueprint.yaml") + ).read_text() + + aioclient_mock.get( + "https://raw.githubusercontent.com/in_folder/home-assistant-config/main/blueprints/automation/in_folder_blueprint.yaml", + text=raw_data, + ) + + client = await hass_ws_client(hass) + await client.send_json( + { + "id": 5, + "type": "blueprint/import", + "url": "https://github.com/in_folder/home-assistant-config/blob/main/blueprints/automation/in_folder_blueprint.yaml", + } + ) + + msg = await client.receive_json() + + assert msg["id"] == 5 + assert msg["success"] + assert msg["result"] == { + "suggested_filename": "in_folder/in_folder_blueprint", + "raw_data": raw_data, + "blueprint": { + "metadata": { + "domain": "automation", + "input": {"action": None, "trigger": None}, + "name": "In Folder Blueprint", + "source_url": "https://github.com/in_folder/home-assistant-config/blob/main/blueprints/automation/in_folder_blueprint.yaml", + } + }, + "validation_errors": None, + "exists": True, } @@ -212,6 +259,42 @@ async def test_save_existing_file( assert msg["error"] == {"code": "already_exists", "message": "File already exists"} +async def test_save_existing_file_override( + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test saving blueprints.""" + + client = await hass_ws_client(hass) + with patch("pathlib.Path.write_text") as write_mock: + await client.send_json( + { + "id": 7, + "type": "blueprint/save", + "path": "test_event_service", + "yaml": 'blueprint: {name: "name", domain: "automation"}', + "domain": "automation", + "source_url": "https://github.com/balloob/home-assistant-config/blob/main/blueprints/automation/test_event_service.yaml", + "allow_override": True, + } + ) + + msg = await client.receive_json() + + assert msg["id"] == 7 + assert msg["success"] + assert msg["result"] == {"overrides_existing": True} + assert yaml.safe_load(write_mock.mock_calls[0][1][0]) == { + "blueprint": { + "name": "name", + "domain": "automation", + "source_url": "https://github.com/balloob/home-assistant-config/blob/main/blueprints/automation/test_event_service.yaml", + "input": {}, + } + } + + async def test_save_file_error( hass: HomeAssistant, aioclient_mock: AiohttpClientMocker,