Allow overriding blueprints on import (#103340)

Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
Paulus Schoutsen 2023-11-25 05:49:50 -05:00 committed by GitHub
parent af7155df7a
commit 1cfbdd6a5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 159 additions and 23 deletions

View File

@ -1,5 +1,6 @@
"""Helpers for automation integration.""" """Helpers for automation integration."""
from homeassistant.components import blueprint from homeassistant.components import blueprint
from homeassistant.const import SERVICE_RELOAD
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.singleton import singleton from homeassistant.helpers.singleton import singleton
@ -15,8 +16,17 @@ def _blueprint_in_use(hass: HomeAssistant, blueprint_path: str) -> bool:
return len(automations_with_blueprint(hass, blueprint_path)) > 0 return len(automations_with_blueprint(hass, blueprint_path)) > 0
async def _reload_blueprint_automations(
hass: HomeAssistant, blueprint_path: str
) -> None:
"""Reload all automations that rely on a specific blueprint."""
await hass.services.async_call(DOMAIN, SERVICE_RELOAD)
@singleton(DATA_BLUEPRINTS) @singleton(DATA_BLUEPRINTS)
@callback @callback
def async_get_blueprints(hass: HomeAssistant) -> blueprint.DomainBlueprints: def async_get_blueprints(hass: HomeAssistant) -> blueprint.DomainBlueprints:
"""Get automation blueprints.""" """Get automation blueprints."""
return blueprint.DomainBlueprints(hass, DOMAIN, LOGGER, _blueprint_in_use) return blueprint.DomainBlueprints(
hass, DOMAIN, LOGGER, _blueprint_in_use, _reload_blueprint_automations
)

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable from collections.abc import Awaitable, Callable
import logging import logging
import pathlib import pathlib
import shutil import shutil
@ -189,12 +189,14 @@ class DomainBlueprints:
domain: str, domain: str,
logger: logging.Logger, logger: logging.Logger,
blueprint_in_use: Callable[[HomeAssistant, str], bool], blueprint_in_use: Callable[[HomeAssistant, str], bool],
reload_blueprint_consumers: Callable[[HomeAssistant, str], Awaitable[None]],
) -> None: ) -> None:
"""Initialize a domain blueprints instance.""" """Initialize a domain blueprints instance."""
self.hass = hass self.hass = hass
self.domain = domain self.domain = domain
self.logger = logger self.logger = logger
self._blueprint_in_use = blueprint_in_use self._blueprint_in_use = blueprint_in_use
self._reload_blueprint_consumers = reload_blueprint_consumers
self._blueprints: dict[str, Blueprint | None] = {} self._blueprints: dict[str, Blueprint | None] = {}
self._load_lock = asyncio.Lock() self._load_lock = asyncio.Lock()
@ -283,7 +285,7 @@ class DomainBlueprints:
blueprint = await self.hass.async_add_executor_job( blueprint = await self.hass.async_add_executor_job(
self._load_blueprint, blueprint_path self._load_blueprint, blueprint_path
) )
except Exception: except FailedToLoad:
self._blueprints[blueprint_path] = None self._blueprints[blueprint_path] = None
raise raise
@ -315,31 +317,41 @@ class DomainBlueprints:
await self.hass.async_add_executor_job(path.unlink) await self.hass.async_add_executor_job(path.unlink)
self._blueprints[blueprint_path] = None self._blueprints[blueprint_path] = None
def _create_file(self, blueprint: Blueprint, blueprint_path: str) -> None: def _create_file(
"""Create blueprint file.""" self, blueprint: Blueprint, blueprint_path: str, allow_override: bool
) -> bool:
"""Create blueprint file.
Returns true if the action overrides an existing blueprint.
"""
path = pathlib.Path( path = pathlib.Path(
self.hass.config.path(BLUEPRINT_FOLDER, self.domain, blueprint_path) self.hass.config.path(BLUEPRINT_FOLDER, self.domain, blueprint_path)
) )
if path.exists(): exists = path.exists()
if not allow_override and exists:
raise FileAlreadyExists(self.domain, blueprint_path) raise FileAlreadyExists(self.domain, blueprint_path)
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(blueprint.yaml(), encoding="utf-8") path.write_text(blueprint.yaml(), encoding="utf-8")
return exists
async def async_add_blueprint( async def async_add_blueprint(
self, blueprint: Blueprint, blueprint_path: str self, blueprint: Blueprint, blueprint_path: str, allow_override=False
) -> None: ) -> bool:
"""Add a blueprint.""" """Add a blueprint."""
if not blueprint_path.endswith(".yaml"): overrides_existing = await self.hass.async_add_executor_job(
blueprint_path = f"{blueprint_path}.yaml" self._create_file, blueprint, blueprint_path, allow_override
await self.hass.async_add_executor_job(
self._create_file, blueprint, blueprint_path
) )
self._blueprints[blueprint_path] = blueprint self._blueprints[blueprint_path] = blueprint
if overrides_existing:
await self._reload_blueprint_consumers(self.hass, blueprint_path)
return overrides_existing
async def async_populate(self) -> None: async def async_populate(self) -> None:
"""Create folder if it doesn't exist and populate with examples.""" """Create folder if it doesn't exist and populate with examples."""
if self._blueprints: if self._blueprints:

View File

@ -14,7 +14,7 @@ from homeassistant.util import yaml
from . import importer, models from . import importer, models
from .const import DOMAIN from .const import DOMAIN
from .errors import FileAlreadyExists from .errors import FailedToLoad, FileAlreadyExists
@callback @callback
@ -81,6 +81,23 @@ async def ws_import_blueprint(
) )
return return
# Check it exists and if so, which automations are using it
domain = imported_blueprint.blueprint.metadata["domain"]
domain_blueprints: models.DomainBlueprints | None = hass.data.get(DOMAIN, {}).get(
domain
)
if domain_blueprints is None:
connection.send_error(
msg["id"], websocket_api.ERR_INVALID_FORMAT, "Unsupported domain"
)
return
suggested_path = f"{imported_blueprint.suggested_filename}.yaml"
try:
exists = bool(await domain_blueprints.async_get_blueprint(suggested_path))
except FailedToLoad:
exists = False
connection.send_result( connection.send_result(
msg["id"], msg["id"],
{ {
@ -90,6 +107,7 @@ async def ws_import_blueprint(
"metadata": imported_blueprint.blueprint.metadata, "metadata": imported_blueprint.blueprint.metadata,
}, },
"validation_errors": imported_blueprint.blueprint.validate(), "validation_errors": imported_blueprint.blueprint.validate(),
"exists": exists,
}, },
) )
@ -101,6 +119,7 @@ async def ws_import_blueprint(
vol.Required("path"): cv.path, vol.Required("path"): cv.path,
vol.Required("yaml"): cv.string, vol.Required("yaml"): cv.string,
vol.Optional("source_url"): cv.url, vol.Optional("source_url"): cv.url,
vol.Optional("allow_override"): bool,
} }
) )
@websocket_api.async_response @websocket_api.async_response
@ -130,8 +149,13 @@ async def ws_save_blueprint(
connection.send_error(msg["id"], websocket_api.ERR_INVALID_FORMAT, str(err)) connection.send_error(msg["id"], websocket_api.ERR_INVALID_FORMAT, str(err))
return return
if not path.endswith(".yaml"):
path = f"{path}.yaml"
try: try:
await domain_blueprints[domain].async_add_blueprint(blueprint, path) overrides_existing = await domain_blueprints[domain].async_add_blueprint(
blueprint, path, allow_override=msg.get("allow_override", False)
)
except FileAlreadyExists: except FileAlreadyExists:
connection.send_error(msg["id"], "already_exists", "File already exists") connection.send_error(msg["id"], "already_exists", "File already exists")
return return
@ -141,6 +165,9 @@ async def ws_save_blueprint(
connection.send_result( connection.send_result(
msg["id"], msg["id"],
{
"overrides_existing": overrides_existing,
},
) )

View File

@ -1,5 +1,6 @@
"""Helpers for automation integration.""" """Helpers for automation integration."""
from homeassistant.components.blueprint import DomainBlueprints from homeassistant.components.blueprint import DomainBlueprints
from homeassistant.const import SERVICE_RELOAD
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.singleton import singleton from homeassistant.helpers.singleton import singleton
@ -15,8 +16,15 @@ def _blueprint_in_use(hass: HomeAssistant, blueprint_path: str) -> bool:
return len(scripts_with_blueprint(hass, blueprint_path)) > 0 return len(scripts_with_blueprint(hass, blueprint_path)) > 0
async def _reload_blueprint_scripts(hass: HomeAssistant, blueprint_path: str) -> None:
"""Reload all script that rely on a specific blueprint."""
await hass.services.async_call(DOMAIN, SERVICE_RELOAD)
@singleton(DATA_BLUEPRINTS) @singleton(DATA_BLUEPRINTS)
@callback @callback
def async_get_blueprints(hass: HomeAssistant) -> DomainBlueprints: def async_get_blueprints(hass: HomeAssistant) -> DomainBlueprints:
"""Get script blueprints.""" """Get script blueprints."""
return DomainBlueprints(hass, DOMAIN, LOGGER, _blueprint_in_use) return DomainBlueprints(
hass, DOMAIN, LOGGER, _blueprint_in_use, _reload_blueprint_scripts
)

View File

@ -1,6 +1,6 @@
"""Test blueprint models.""" """Test blueprint models."""
import logging import logging
from unittest.mock import patch from unittest.mock import AsyncMock, patch
import pytest import pytest
@ -49,7 +49,7 @@ def blueprint_2():
def domain_bps(hass): def domain_bps(hass):
"""Domain blueprints fixture.""" """Domain blueprints fixture."""
return models.DomainBlueprints( return models.DomainBlueprints(
hass, "automation", logging.getLogger(__name__), None hass, "automation", logging.getLogger(__name__), None, AsyncMock()
) )
@ -257,13 +257,9 @@ async def test_domain_blueprints_inputs_from_config(domain_bps, blueprint_1) ->
async def test_domain_blueprints_add_blueprint(domain_bps, blueprint_1) -> None: async def test_domain_blueprints_add_blueprint(domain_bps, blueprint_1) -> None:
"""Test DomainBlueprints.async_add_blueprint.""" """Test DomainBlueprints.async_add_blueprint."""
with patch.object(domain_bps, "_create_file") as create_file_mock: with patch.object(domain_bps, "_create_file") as create_file_mock:
# Should add extension when not present. await domain_bps.async_add_blueprint(blueprint_1, "something.yaml")
await domain_bps.async_add_blueprint(blueprint_1, "something")
assert create_file_mock.call_args[0][1] == "something.yaml" assert create_file_mock.call_args[0][1] == "something.yaml"
await domain_bps.async_add_blueprint(blueprint_1, "something2.yaml")
assert create_file_mock.call_args[0][1] == "something2.yaml"
# Should be in cache. # Should be in cache.
with patch.object(domain_bps, "_load_blueprint") as mock_load: with patch.object(domain_bps, "_load_blueprint") as mock_load:
assert await domain_bps.async_get_blueprint("something.yaml") == blueprint_1 assert await domain_bps.async_get_blueprint("something.yaml") == blueprint_1

View File

@ -3,6 +3,7 @@ from pathlib import Path
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
import yaml
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -129,6 +130,52 @@ async def test_import_blueprint(
}, },
}, },
"validation_errors": None, "validation_errors": None,
"exists": False,
}
async def test_import_blueprint_update(
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_ws_client: WebSocketGenerator,
setup_bp,
) -> None:
"""Test importing blueprints."""
raw_data = Path(
hass.config.path("blueprints/automation/in_folder/in_folder_blueprint.yaml")
).read_text()
aioclient_mock.get(
"https://raw.githubusercontent.com/in_folder/home-assistant-config/main/blueprints/automation/in_folder_blueprint.yaml",
text=raw_data,
)
client = await hass_ws_client(hass)
await client.send_json(
{
"id": 5,
"type": "blueprint/import",
"url": "https://github.com/in_folder/home-assistant-config/blob/main/blueprints/automation/in_folder_blueprint.yaml",
}
)
msg = await client.receive_json()
assert msg["id"] == 5
assert msg["success"]
assert msg["result"] == {
"suggested_filename": "in_folder/in_folder_blueprint",
"raw_data": raw_data,
"blueprint": {
"metadata": {
"domain": "automation",
"input": {"action": None, "trigger": None},
"name": "In Folder Blueprint",
"source_url": "https://github.com/in_folder/home-assistant-config/blob/main/blueprints/automation/in_folder_blueprint.yaml",
}
},
"validation_errors": None,
"exists": True,
} }
@ -212,6 +259,42 @@ async def test_save_existing_file(
assert msg["error"] == {"code": "already_exists", "message": "File already exists"} assert msg["error"] == {"code": "already_exists", "message": "File already exists"}
async def test_save_existing_file_override(
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test saving blueprints."""
client = await hass_ws_client(hass)
with patch("pathlib.Path.write_text") as write_mock:
await client.send_json(
{
"id": 7,
"type": "blueprint/save",
"path": "test_event_service",
"yaml": 'blueprint: {name: "name", domain: "automation"}',
"domain": "automation",
"source_url": "https://github.com/balloob/home-assistant-config/blob/main/blueprints/automation/test_event_service.yaml",
"allow_override": True,
}
)
msg = await client.receive_json()
assert msg["id"] == 7
assert msg["success"]
assert msg["result"] == {"overrides_existing": True}
assert yaml.safe_load(write_mock.mock_calls[0][1][0]) == {
"blueprint": {
"name": "name",
"domain": "automation",
"source_url": "https://github.com/balloob/home-assistant-config/blob/main/blueprints/automation/test_event_service.yaml",
"input": {},
}
}
async def test_save_file_error( async def test_save_file_error(
hass: HomeAssistant, hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker, aioclient_mock: AiohttpClientMocker,