mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Allow overriding blueprints on import (#103340)
Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
parent
af7155df7a
commit
1cfbdd6a5d
@ -1,5 +1,6 @@
|
|||||||
"""Helpers for automation integration."""
|
"""Helpers for automation integration."""
|
||||||
from homeassistant.components import blueprint
|
from homeassistant.components import blueprint
|
||||||
|
from homeassistant.const import SERVICE_RELOAD
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers.singleton import singleton
|
from homeassistant.helpers.singleton import singleton
|
||||||
|
|
||||||
@ -15,8 +16,17 @@ def _blueprint_in_use(hass: HomeAssistant, blueprint_path: str) -> bool:
|
|||||||
return len(automations_with_blueprint(hass, blueprint_path)) > 0
|
return len(automations_with_blueprint(hass, blueprint_path)) > 0
|
||||||
|
|
||||||
|
|
||||||
|
async def _reload_blueprint_automations(
|
||||||
|
hass: HomeAssistant, blueprint_path: str
|
||||||
|
) -> None:
|
||||||
|
"""Reload all automations that rely on a specific blueprint."""
|
||||||
|
await hass.services.async_call(DOMAIN, SERVICE_RELOAD)
|
||||||
|
|
||||||
|
|
||||||
@singleton(DATA_BLUEPRINTS)
|
@singleton(DATA_BLUEPRINTS)
|
||||||
@callback
|
@callback
|
||||||
def async_get_blueprints(hass: HomeAssistant) -> blueprint.DomainBlueprints:
|
def async_get_blueprints(hass: HomeAssistant) -> blueprint.DomainBlueprints:
|
||||||
"""Get automation blueprints."""
|
"""Get automation blueprints."""
|
||||||
return blueprint.DomainBlueprints(hass, DOMAIN, LOGGER, _blueprint_in_use)
|
return blueprint.DomainBlueprints(
|
||||||
|
hass, DOMAIN, LOGGER, _blueprint_in_use, _reload_blueprint_automations
|
||||||
|
)
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable
|
from collections.abc import Awaitable, Callable
|
||||||
import logging
|
import logging
|
||||||
import pathlib
|
import pathlib
|
||||||
import shutil
|
import shutil
|
||||||
@ -189,12 +189,14 @@ class DomainBlueprints:
|
|||||||
domain: str,
|
domain: str,
|
||||||
logger: logging.Logger,
|
logger: logging.Logger,
|
||||||
blueprint_in_use: Callable[[HomeAssistant, str], bool],
|
blueprint_in_use: Callable[[HomeAssistant, str], bool],
|
||||||
|
reload_blueprint_consumers: Callable[[HomeAssistant, str], Awaitable[None]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a domain blueprints instance."""
|
"""Initialize a domain blueprints instance."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.domain = domain
|
self.domain = domain
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self._blueprint_in_use = blueprint_in_use
|
self._blueprint_in_use = blueprint_in_use
|
||||||
|
self._reload_blueprint_consumers = reload_blueprint_consumers
|
||||||
self._blueprints: dict[str, Blueprint | None] = {}
|
self._blueprints: dict[str, Blueprint | None] = {}
|
||||||
self._load_lock = asyncio.Lock()
|
self._load_lock = asyncio.Lock()
|
||||||
|
|
||||||
@ -283,7 +285,7 @@ class DomainBlueprints:
|
|||||||
blueprint = await self.hass.async_add_executor_job(
|
blueprint = await self.hass.async_add_executor_job(
|
||||||
self._load_blueprint, blueprint_path
|
self._load_blueprint, blueprint_path
|
||||||
)
|
)
|
||||||
except Exception:
|
except FailedToLoad:
|
||||||
self._blueprints[blueprint_path] = None
|
self._blueprints[blueprint_path] = None
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -315,31 +317,41 @@ class DomainBlueprints:
|
|||||||
await self.hass.async_add_executor_job(path.unlink)
|
await self.hass.async_add_executor_job(path.unlink)
|
||||||
self._blueprints[blueprint_path] = None
|
self._blueprints[blueprint_path] = None
|
||||||
|
|
||||||
def _create_file(self, blueprint: Blueprint, blueprint_path: str) -> None:
|
def _create_file(
|
||||||
"""Create blueprint file."""
|
self, blueprint: Blueprint, blueprint_path: str, allow_override: bool
|
||||||
|
) -> bool:
|
||||||
|
"""Create blueprint file.
|
||||||
|
|
||||||
|
Returns true if the action overrides an existing blueprint.
|
||||||
|
"""
|
||||||
|
|
||||||
path = pathlib.Path(
|
path = pathlib.Path(
|
||||||
self.hass.config.path(BLUEPRINT_FOLDER, self.domain, blueprint_path)
|
self.hass.config.path(BLUEPRINT_FOLDER, self.domain, blueprint_path)
|
||||||
)
|
)
|
||||||
if path.exists():
|
exists = path.exists()
|
||||||
|
|
||||||
|
if not allow_override and exists:
|
||||||
raise FileAlreadyExists(self.domain, blueprint_path)
|
raise FileAlreadyExists(self.domain, blueprint_path)
|
||||||
|
|
||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
path.write_text(blueprint.yaml(), encoding="utf-8")
|
path.write_text(blueprint.yaml(), encoding="utf-8")
|
||||||
|
return exists
|
||||||
|
|
||||||
async def async_add_blueprint(
|
async def async_add_blueprint(
|
||||||
self, blueprint: Blueprint, blueprint_path: str
|
self, blueprint: Blueprint, blueprint_path: str, allow_override=False
|
||||||
) -> None:
|
) -> bool:
|
||||||
"""Add a blueprint."""
|
"""Add a blueprint."""
|
||||||
if not blueprint_path.endswith(".yaml"):
|
overrides_existing = await self.hass.async_add_executor_job(
|
||||||
blueprint_path = f"{blueprint_path}.yaml"
|
self._create_file, blueprint, blueprint_path, allow_override
|
||||||
|
|
||||||
await self.hass.async_add_executor_job(
|
|
||||||
self._create_file, blueprint, blueprint_path
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._blueprints[blueprint_path] = blueprint
|
self._blueprints[blueprint_path] = blueprint
|
||||||
|
|
||||||
|
if overrides_existing:
|
||||||
|
await self._reload_blueprint_consumers(self.hass, blueprint_path)
|
||||||
|
|
||||||
|
return overrides_existing
|
||||||
|
|
||||||
async def async_populate(self) -> None:
|
async def async_populate(self) -> None:
|
||||||
"""Create folder if it doesn't exist and populate with examples."""
|
"""Create folder if it doesn't exist and populate with examples."""
|
||||||
if self._blueprints:
|
if self._blueprints:
|
||||||
|
@ -14,7 +14,7 @@ from homeassistant.util import yaml
|
|||||||
|
|
||||||
from . import importer, models
|
from . import importer, models
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
from .errors import FileAlreadyExists
|
from .errors import FailedToLoad, FileAlreadyExists
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -81,6 +81,23 @@ async def ws_import_blueprint(
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Check it exists and if so, which automations are using it
|
||||||
|
domain = imported_blueprint.blueprint.metadata["domain"]
|
||||||
|
domain_blueprints: models.DomainBlueprints | None = hass.data.get(DOMAIN, {}).get(
|
||||||
|
domain
|
||||||
|
)
|
||||||
|
if domain_blueprints is None:
|
||||||
|
connection.send_error(
|
||||||
|
msg["id"], websocket_api.ERR_INVALID_FORMAT, "Unsupported domain"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
suggested_path = f"{imported_blueprint.suggested_filename}.yaml"
|
||||||
|
try:
|
||||||
|
exists = bool(await domain_blueprints.async_get_blueprint(suggested_path))
|
||||||
|
except FailedToLoad:
|
||||||
|
exists = False
|
||||||
|
|
||||||
connection.send_result(
|
connection.send_result(
|
||||||
msg["id"],
|
msg["id"],
|
||||||
{
|
{
|
||||||
@ -90,6 +107,7 @@ async def ws_import_blueprint(
|
|||||||
"metadata": imported_blueprint.blueprint.metadata,
|
"metadata": imported_blueprint.blueprint.metadata,
|
||||||
},
|
},
|
||||||
"validation_errors": imported_blueprint.blueprint.validate(),
|
"validation_errors": imported_blueprint.blueprint.validate(),
|
||||||
|
"exists": exists,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -101,6 +119,7 @@ async def ws_import_blueprint(
|
|||||||
vol.Required("path"): cv.path,
|
vol.Required("path"): cv.path,
|
||||||
vol.Required("yaml"): cv.string,
|
vol.Required("yaml"): cv.string,
|
||||||
vol.Optional("source_url"): cv.url,
|
vol.Optional("source_url"): cv.url,
|
||||||
|
vol.Optional("allow_override"): bool,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@websocket_api.async_response
|
@websocket_api.async_response
|
||||||
@ -130,8 +149,13 @@ async def ws_save_blueprint(
|
|||||||
connection.send_error(msg["id"], websocket_api.ERR_INVALID_FORMAT, str(err))
|
connection.send_error(msg["id"], websocket_api.ERR_INVALID_FORMAT, str(err))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if not path.endswith(".yaml"):
|
||||||
|
path = f"{path}.yaml"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await domain_blueprints[domain].async_add_blueprint(blueprint, path)
|
overrides_existing = await domain_blueprints[domain].async_add_blueprint(
|
||||||
|
blueprint, path, allow_override=msg.get("allow_override", False)
|
||||||
|
)
|
||||||
except FileAlreadyExists:
|
except FileAlreadyExists:
|
||||||
connection.send_error(msg["id"], "already_exists", "File already exists")
|
connection.send_error(msg["id"], "already_exists", "File already exists")
|
||||||
return
|
return
|
||||||
@ -141,6 +165,9 @@ async def ws_save_blueprint(
|
|||||||
|
|
||||||
connection.send_result(
|
connection.send_result(
|
||||||
msg["id"],
|
msg["id"],
|
||||||
|
{
|
||||||
|
"overrides_existing": overrides_existing,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Helpers for automation integration."""
|
"""Helpers for automation integration."""
|
||||||
from homeassistant.components.blueprint import DomainBlueprints
|
from homeassistant.components.blueprint import DomainBlueprints
|
||||||
|
from homeassistant.const import SERVICE_RELOAD
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers.singleton import singleton
|
from homeassistant.helpers.singleton import singleton
|
||||||
|
|
||||||
@ -15,8 +16,15 @@ def _blueprint_in_use(hass: HomeAssistant, blueprint_path: str) -> bool:
|
|||||||
return len(scripts_with_blueprint(hass, blueprint_path)) > 0
|
return len(scripts_with_blueprint(hass, blueprint_path)) > 0
|
||||||
|
|
||||||
|
|
||||||
|
async def _reload_blueprint_scripts(hass: HomeAssistant, blueprint_path: str) -> None:
|
||||||
|
"""Reload all script that rely on a specific blueprint."""
|
||||||
|
await hass.services.async_call(DOMAIN, SERVICE_RELOAD)
|
||||||
|
|
||||||
|
|
||||||
@singleton(DATA_BLUEPRINTS)
|
@singleton(DATA_BLUEPRINTS)
|
||||||
@callback
|
@callback
|
||||||
def async_get_blueprints(hass: HomeAssistant) -> DomainBlueprints:
|
def async_get_blueprints(hass: HomeAssistant) -> DomainBlueprints:
|
||||||
"""Get script blueprints."""
|
"""Get script blueprints."""
|
||||||
return DomainBlueprints(hass, DOMAIN, LOGGER, _blueprint_in_use)
|
return DomainBlueprints(
|
||||||
|
hass, DOMAIN, LOGGER, _blueprint_in_use, _reload_blueprint_scripts
|
||||||
|
)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""Test blueprint models."""
|
"""Test blueprint models."""
|
||||||
import logging
|
import logging
|
||||||
from unittest.mock import patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -49,7 +49,7 @@ def blueprint_2():
|
|||||||
def domain_bps(hass):
|
def domain_bps(hass):
|
||||||
"""Domain blueprints fixture."""
|
"""Domain blueprints fixture."""
|
||||||
return models.DomainBlueprints(
|
return models.DomainBlueprints(
|
||||||
hass, "automation", logging.getLogger(__name__), None
|
hass, "automation", logging.getLogger(__name__), None, AsyncMock()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -257,13 +257,9 @@ async def test_domain_blueprints_inputs_from_config(domain_bps, blueprint_1) ->
|
|||||||
async def test_domain_blueprints_add_blueprint(domain_bps, blueprint_1) -> None:
|
async def test_domain_blueprints_add_blueprint(domain_bps, blueprint_1) -> None:
|
||||||
"""Test DomainBlueprints.async_add_blueprint."""
|
"""Test DomainBlueprints.async_add_blueprint."""
|
||||||
with patch.object(domain_bps, "_create_file") as create_file_mock:
|
with patch.object(domain_bps, "_create_file") as create_file_mock:
|
||||||
# Should add extension when not present.
|
await domain_bps.async_add_blueprint(blueprint_1, "something.yaml")
|
||||||
await domain_bps.async_add_blueprint(blueprint_1, "something")
|
|
||||||
assert create_file_mock.call_args[0][1] == "something.yaml"
|
assert create_file_mock.call_args[0][1] == "something.yaml"
|
||||||
|
|
||||||
await domain_bps.async_add_blueprint(blueprint_1, "something2.yaml")
|
|
||||||
assert create_file_mock.call_args[0][1] == "something2.yaml"
|
|
||||||
|
|
||||||
# Should be in cache.
|
# Should be in cache.
|
||||||
with patch.object(domain_bps, "_load_blueprint") as mock_load:
|
with patch.object(domain_bps, "_load_blueprint") as mock_load:
|
||||||
assert await domain_bps.async_get_blueprint("something.yaml") == blueprint_1
|
assert await domain_bps.async_get_blueprint("something.yaml") == blueprint_1
|
||||||
|
@ -3,6 +3,7 @@ from pathlib import Path
|
|||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import yaml
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
@ -129,6 +130,52 @@ async def test_import_blueprint(
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
"validation_errors": None,
|
"validation_errors": None,
|
||||||
|
"exists": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_import_blueprint_update(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
aioclient_mock: AiohttpClientMocker,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
setup_bp,
|
||||||
|
) -> None:
|
||||||
|
"""Test importing blueprints."""
|
||||||
|
raw_data = Path(
|
||||||
|
hass.config.path("blueprints/automation/in_folder/in_folder_blueprint.yaml")
|
||||||
|
).read_text()
|
||||||
|
|
||||||
|
aioclient_mock.get(
|
||||||
|
"https://raw.githubusercontent.com/in_folder/home-assistant-config/main/blueprints/automation/in_folder_blueprint.yaml",
|
||||||
|
text=raw_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
await client.send_json(
|
||||||
|
{
|
||||||
|
"id": 5,
|
||||||
|
"type": "blueprint/import",
|
||||||
|
"url": "https://github.com/in_folder/home-assistant-config/blob/main/blueprints/automation/in_folder_blueprint.yaml",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
|
||||||
|
assert msg["id"] == 5
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {
|
||||||
|
"suggested_filename": "in_folder/in_folder_blueprint",
|
||||||
|
"raw_data": raw_data,
|
||||||
|
"blueprint": {
|
||||||
|
"metadata": {
|
||||||
|
"domain": "automation",
|
||||||
|
"input": {"action": None, "trigger": None},
|
||||||
|
"name": "In Folder Blueprint",
|
||||||
|
"source_url": "https://github.com/in_folder/home-assistant-config/blob/main/blueprints/automation/in_folder_blueprint.yaml",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"validation_errors": None,
|
||||||
|
"exists": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -212,6 +259,42 @@ async def test_save_existing_file(
|
|||||||
assert msg["error"] == {"code": "already_exists", "message": "File already exists"}
|
assert msg["error"] == {"code": "already_exists", "message": "File already exists"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_save_existing_file_override(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
aioclient_mock: AiohttpClientMocker,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test saving blueprints."""
|
||||||
|
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
with patch("pathlib.Path.write_text") as write_mock:
|
||||||
|
await client.send_json(
|
||||||
|
{
|
||||||
|
"id": 7,
|
||||||
|
"type": "blueprint/save",
|
||||||
|
"path": "test_event_service",
|
||||||
|
"yaml": 'blueprint: {name: "name", domain: "automation"}',
|
||||||
|
"domain": "automation",
|
||||||
|
"source_url": "https://github.com/balloob/home-assistant-config/blob/main/blueprints/automation/test_event_service.yaml",
|
||||||
|
"allow_override": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
|
||||||
|
assert msg["id"] == 7
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {"overrides_existing": True}
|
||||||
|
assert yaml.safe_load(write_mock.mock_calls[0][1][0]) == {
|
||||||
|
"blueprint": {
|
||||||
|
"name": "name",
|
||||||
|
"domain": "automation",
|
||||||
|
"source_url": "https://github.com/balloob/home-assistant-config/blob/main/blueprints/automation/test_event_service.yaml",
|
||||||
|
"input": {},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def test_save_file_error(
|
async def test_save_file_error(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
aioclient_mock: AiohttpClientMocker,
|
aioclient_mock: AiohttpClientMocker,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user