Allow overriding blueprints on import (#103340)

Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
Paulus Schoutsen 2023-11-25 05:49:50 -05:00 committed by GitHub
parent af7155df7a
commit 1cfbdd6a5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 159 additions and 23 deletions

View File

@ -1,5 +1,6 @@
"""Helpers for automation integration."""
from homeassistant.components import blueprint
from homeassistant.const import SERVICE_RELOAD
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.singleton import singleton
@ -15,8 +16,17 @@ def _blueprint_in_use(hass: HomeAssistant, blueprint_path: str) -> bool:
return len(automations_with_blueprint(hass, blueprint_path)) > 0
async def _reload_blueprint_automations(
hass: HomeAssistant, blueprint_path: str
) -> None:
"""Reload all automations that rely on a specific blueprint."""
await hass.services.async_call(DOMAIN, SERVICE_RELOAD)
@singleton(DATA_BLUEPRINTS)
@callback
def async_get_blueprints(hass: HomeAssistant) -> blueprint.DomainBlueprints:
"""Get automation blueprints."""
return blueprint.DomainBlueprints(hass, DOMAIN, LOGGER, _blueprint_in_use)
return blueprint.DomainBlueprints(
hass, DOMAIN, LOGGER, _blueprint_in_use, _reload_blueprint_automations
)

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
from collections.abc import Awaitable, Callable
import logging
import pathlib
import shutil
@ -189,12 +189,14 @@ class DomainBlueprints:
domain: str,
logger: logging.Logger,
blueprint_in_use: Callable[[HomeAssistant, str], bool],
reload_blueprint_consumers: Callable[[HomeAssistant, str], Awaitable[None]],
) -> None:
"""Initialize a domain blueprints instance."""
self.hass = hass
self.domain = domain
self.logger = logger
self._blueprint_in_use = blueprint_in_use
self._reload_blueprint_consumers = reload_blueprint_consumers
self._blueprints: dict[str, Blueprint | None] = {}
self._load_lock = asyncio.Lock()
@ -283,7 +285,7 @@ class DomainBlueprints:
blueprint = await self.hass.async_add_executor_job(
self._load_blueprint, blueprint_path
)
except Exception:
except FailedToLoad:
self._blueprints[blueprint_path] = None
raise
@ -315,31 +317,41 @@ class DomainBlueprints:
await self.hass.async_add_executor_job(path.unlink)
self._blueprints[blueprint_path] = None
def _create_file(self, blueprint: Blueprint, blueprint_path: str) -> None:
"""Create blueprint file."""
def _create_file(
self, blueprint: Blueprint, blueprint_path: str, allow_override: bool
) -> bool:
"""Create blueprint file.
Returns true if the action overrides an existing blueprint.
"""
path = pathlib.Path(
self.hass.config.path(BLUEPRINT_FOLDER, self.domain, blueprint_path)
)
if path.exists():
exists = path.exists()
if not allow_override and exists:
raise FileAlreadyExists(self.domain, blueprint_path)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(blueprint.yaml(), encoding="utf-8")
return exists
async def async_add_blueprint(
self, blueprint: Blueprint, blueprint_path: str
) -> None:
self, blueprint: Blueprint, blueprint_path: str, allow_override=False
) -> bool:
"""Add a blueprint."""
if not blueprint_path.endswith(".yaml"):
blueprint_path = f"{blueprint_path}.yaml"
await self.hass.async_add_executor_job(
self._create_file, blueprint, blueprint_path
overrides_existing = await self.hass.async_add_executor_job(
self._create_file, blueprint, blueprint_path, allow_override
)
self._blueprints[blueprint_path] = blueprint
if overrides_existing:
await self._reload_blueprint_consumers(self.hass, blueprint_path)
return overrides_existing
async def async_populate(self) -> None:
"""Create folder if it doesn't exist and populate with examples."""
if self._blueprints:

View File

@ -14,7 +14,7 @@ from homeassistant.util import yaml
from . import importer, models
from .const import DOMAIN
from .errors import FileAlreadyExists
from .errors import FailedToLoad, FileAlreadyExists
@callback
@ -81,6 +81,23 @@ async def ws_import_blueprint(
)
return
# Check it exists and if so, which automations are using it
domain = imported_blueprint.blueprint.metadata["domain"]
domain_blueprints: models.DomainBlueprints | None = hass.data.get(DOMAIN, {}).get(
domain
)
if domain_blueprints is None:
connection.send_error(
msg["id"], websocket_api.ERR_INVALID_FORMAT, "Unsupported domain"
)
return
suggested_path = f"{imported_blueprint.suggested_filename}.yaml"
try:
exists = bool(await domain_blueprints.async_get_blueprint(suggested_path))
except FailedToLoad:
exists = False
connection.send_result(
msg["id"],
{
@ -90,6 +107,7 @@ async def ws_import_blueprint(
"metadata": imported_blueprint.blueprint.metadata,
},
"validation_errors": imported_blueprint.blueprint.validate(),
"exists": exists,
},
)
@ -101,6 +119,7 @@ async def ws_import_blueprint(
vol.Required("path"): cv.path,
vol.Required("yaml"): cv.string,
vol.Optional("source_url"): cv.url,
vol.Optional("allow_override"): bool,
}
)
@websocket_api.async_response
@ -130,8 +149,13 @@ async def ws_save_blueprint(
connection.send_error(msg["id"], websocket_api.ERR_INVALID_FORMAT, str(err))
return
if not path.endswith(".yaml"):
path = f"{path}.yaml"
try:
await domain_blueprints[domain].async_add_blueprint(blueprint, path)
overrides_existing = await domain_blueprints[domain].async_add_blueprint(
blueprint, path, allow_override=msg.get("allow_override", False)
)
except FileAlreadyExists:
connection.send_error(msg["id"], "already_exists", "File already exists")
return
@ -141,6 +165,9 @@ async def ws_save_blueprint(
connection.send_result(
msg["id"],
{
"overrides_existing": overrides_existing,
},
)

View File

@ -1,5 +1,6 @@
"""Helpers for automation integration."""
from homeassistant.components.blueprint import DomainBlueprints
from homeassistant.const import SERVICE_RELOAD
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.singleton import singleton
@ -15,8 +16,15 @@ def _blueprint_in_use(hass: HomeAssistant, blueprint_path: str) -> bool:
return len(scripts_with_blueprint(hass, blueprint_path)) > 0
async def _reload_blueprint_scripts(hass: HomeAssistant, blueprint_path: str) -> None:
"""Reload all script that rely on a specific blueprint."""
await hass.services.async_call(DOMAIN, SERVICE_RELOAD)
@singleton(DATA_BLUEPRINTS)
@callback
def async_get_blueprints(hass: HomeAssistant) -> DomainBlueprints:
"""Get script blueprints."""
return DomainBlueprints(hass, DOMAIN, LOGGER, _blueprint_in_use)
return DomainBlueprints(
hass, DOMAIN, LOGGER, _blueprint_in_use, _reload_blueprint_scripts
)

View File

@ -1,6 +1,6 @@
"""Test blueprint models."""
import logging
from unittest.mock import patch
from unittest.mock import AsyncMock, patch
import pytest
@ -49,7 +49,7 @@ def blueprint_2():
def domain_bps(hass):
"""Domain blueprints fixture."""
return models.DomainBlueprints(
hass, "automation", logging.getLogger(__name__), None
hass, "automation", logging.getLogger(__name__), None, AsyncMock()
)
@ -257,13 +257,9 @@ async def test_domain_blueprints_inputs_from_config(domain_bps, blueprint_1) ->
async def test_domain_blueprints_add_blueprint(domain_bps, blueprint_1) -> None:
"""Test DomainBlueprints.async_add_blueprint."""
with patch.object(domain_bps, "_create_file") as create_file_mock:
# Should add extension when not present.
await domain_bps.async_add_blueprint(blueprint_1, "something")
await domain_bps.async_add_blueprint(blueprint_1, "something.yaml")
assert create_file_mock.call_args[0][1] == "something.yaml"
await domain_bps.async_add_blueprint(blueprint_1, "something2.yaml")
assert create_file_mock.call_args[0][1] == "something2.yaml"
# Should be in cache.
with patch.object(domain_bps, "_load_blueprint") as mock_load:
assert await domain_bps.async_get_blueprint("something.yaml") == blueprint_1

View File

@ -3,6 +3,7 @@ from pathlib import Path
from unittest.mock import Mock, patch
import pytest
import yaml
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
@ -129,6 +130,52 @@ async def test_import_blueprint(
},
},
"validation_errors": None,
"exists": False,
}
async def test_import_blueprint_update(
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_ws_client: WebSocketGenerator,
setup_bp,
) -> None:
"""Test importing blueprints."""
raw_data = Path(
hass.config.path("blueprints/automation/in_folder/in_folder_blueprint.yaml")
).read_text()
aioclient_mock.get(
"https://raw.githubusercontent.com/in_folder/home-assistant-config/main/blueprints/automation/in_folder_blueprint.yaml",
text=raw_data,
)
client = await hass_ws_client(hass)
await client.send_json(
{
"id": 5,
"type": "blueprint/import",
"url": "https://github.com/in_folder/home-assistant-config/blob/main/blueprints/automation/in_folder_blueprint.yaml",
}
)
msg = await client.receive_json()
assert msg["id"] == 5
assert msg["success"]
assert msg["result"] == {
"suggested_filename": "in_folder/in_folder_blueprint",
"raw_data": raw_data,
"blueprint": {
"metadata": {
"domain": "automation",
"input": {"action": None, "trigger": None},
"name": "In Folder Blueprint",
"source_url": "https://github.com/in_folder/home-assistant-config/blob/main/blueprints/automation/in_folder_blueprint.yaml",
}
},
"validation_errors": None,
"exists": True,
}
@ -212,6 +259,42 @@ async def test_save_existing_file(
assert msg["error"] == {"code": "already_exists", "message": "File already exists"}
async def test_save_existing_file_override(
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test saving blueprints."""
client = await hass_ws_client(hass)
with patch("pathlib.Path.write_text") as write_mock:
await client.send_json(
{
"id": 7,
"type": "blueprint/save",
"path": "test_event_service",
"yaml": 'blueprint: {name: "name", domain: "automation"}',
"domain": "automation",
"source_url": "https://github.com/balloob/home-assistant-config/blob/main/blueprints/automation/test_event_service.yaml",
"allow_override": True,
}
)
msg = await client.receive_json()
assert msg["id"] == 7
assert msg["success"]
assert msg["result"] == {"overrides_existing": True}
assert yaml.safe_load(write_mock.mock_calls[0][1][0]) == {
"blueprint": {
"name": "name",
"domain": "automation",
"source_url": "https://github.com/balloob/home-assistant-config/blob/main/blueprints/automation/test_event_service.yaml",
"input": {},
}
}
async def test_save_file_error(
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,