diff --git a/homeassistant/components/frontend/__init__.py b/homeassistant/components/frontend/__init__.py index fdea21fe91e..b26d7a4e168 100644 --- a/homeassistant/components/frontend/__init__.py +++ b/homeassistant/components/frontend/__init__.py @@ -16,6 +16,7 @@ from homeassistant.components.http.view import HomeAssistantView from homeassistant.config import async_hass_config_yaml from homeassistant.const import CONF_NAME, EVENT_THEMES_UPDATED from homeassistant.core import callback +from homeassistant.helpers import service import homeassistant.helpers.config_validation as cv from homeassistant.helpers.translation import async_get_translations from homeassistant.loader import bind_hass @@ -103,19 +104,6 @@ CONFIG_SCHEMA = vol.Schema( SERVICE_SET_THEME = "set_theme" SERVICE_RELOAD_THEMES = "reload_themes" -SERVICE_SET_THEME_SCHEMA = vol.Schema({vol.Required(CONF_NAME): cv.string}) -WS_TYPE_GET_PANELS = "get_panels" -SCHEMA_GET_PANELS = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( - {vol.Required("type"): WS_TYPE_GET_PANELS} -) -WS_TYPE_GET_THEMES = "frontend/get_themes" -SCHEMA_GET_THEMES = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( - {vol.Required("type"): WS_TYPE_GET_THEMES} -) -WS_TYPE_GET_TRANSLATIONS = "frontend/get_translations" -SCHEMA_GET_TRANSLATIONS = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( - {vol.Required("type"): WS_TYPE_GET_TRANSLATIONS, vol.Required("language"): str} -) class Panel: @@ -251,15 +239,9 @@ def _frontend_root(dev_repo_path): async def async_setup(hass, config): """Set up the serving of the frontend.""" await async_setup_frontend_storage(hass) - hass.components.websocket_api.async_register_command( - WS_TYPE_GET_PANELS, websocket_get_panels, SCHEMA_GET_PANELS - ) - hass.components.websocket_api.async_register_command( - WS_TYPE_GET_THEMES, websocket_get_themes, SCHEMA_GET_THEMES - ) - hass.components.websocket_api.async_register_command( - WS_TYPE_GET_TRANSLATIONS, websocket_get_translations, SCHEMA_GET_TRANSLATIONS - ) + hass.components.websocket_api.async_register_command(websocket_get_panels) + hass.components.websocket_api.async_register_command(websocket_get_themes) + hass.components.websocket_api.async_register_command(websocket_get_translations) hass.http.register_view(ManifestJSONView) conf = config.get(DOMAIN, {}) @@ -331,11 +313,7 @@ async def async_setup(hass, config): def _async_setup_themes(hass, themes): """Set up themes data and services.""" hass.data[DATA_DEFAULT_THEME] = DEFAULT_THEME - if themes is None: - hass.data[DATA_THEMES] = {} - return - - hass.data[DATA_THEMES] = themes + hass.data[DATA_THEMES] = themes or {} @callback def update_theme_and_fire_event(): @@ -348,9 +326,7 @@ def _async_setup_themes(hass, themes): "app-header-background-color", themes[name].get(PRIMARY_COLOR, DEFAULT_THEME_COLOR), ) - hass.bus.async_fire( - EVENT_THEMES_UPDATED, {"themes": themes, "default_theme": name} - ) + hass.bus.async_fire(EVENT_THEMES_UPDATED) @callback def set_theme(call): @@ -373,10 +349,17 @@ def _async_setup_themes(hass, themes): hass.data[DATA_DEFAULT_THEME] = DEFAULT_THEME update_theme_and_fire_event() - hass.services.async_register( - DOMAIN, SERVICE_SET_THEME, set_theme, schema=SERVICE_SET_THEME_SCHEMA + service.async_register_admin_service( + hass, + DOMAIN, + SERVICE_SET_THEME, + set_theme, + vol.Schema({vol.Required(CONF_NAME): cv.string}), + ) + + service.async_register_admin_service( + hass, DOMAIN, SERVICE_RELOAD_THEMES, reload_themes ) - hass.services.async_register(DOMAIN, SERVICE_RELOAD_THEMES, reload_themes) class IndexView(web_urldispatcher.AbstractResource): @@ -498,6 +481,7 @@ class ManifestJSONView(HomeAssistantView): @callback +@websocket_api.websocket_command({"type": "get_panels"}) def websocket_get_panels(hass, connection, msg): """Handle get panels command. @@ -514,6 +498,7 @@ def websocket_get_panels(hass, connection, msg): @callback +@websocket_api.websocket_command({"type": "frontend/get_themes"}) def websocket_get_themes(hass, connection, msg): """Handle get themes command. @@ -530,6 +515,9 @@ def websocket_get_themes(hass, connection, msg): ) +@websocket_api.websocket_command( + {"type": "frontend/get_translations", vol.Required("language"): str} +) @websocket_api.async_response async def websocket_get_translations(hass, connection, msg): """Handle get translations command. diff --git a/homeassistant/core.py b/homeassistant/core.py index 3f561cdfab8..e819a32b7c7 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -298,10 +298,10 @@ class HomeAssistant: if asyncio.iscoroutine(check_target): task = self.loop.create_task(target) # type: ignore - elif is_callback(check_target): - self.loop.call_soon(target, *args) elif asyncio.iscoroutinefunction(check_target): task = self.loop.create_task(target(*args)) + elif is_callback(check_target): + self.loop.call_soon(target, *args) else: task = self.loop.run_in_executor( # type: ignore None, target, *args @@ -360,7 +360,11 @@ class HomeAssistant: target: target to call. args: parameters for method to call. """ - if not asyncio.iscoroutine(target) and is_callback(target): + if ( + not asyncio.iscoroutine(target) + and not asyncio.iscoroutinefunction(target) + and is_callback(target) + ): target(*args) else: self.async_add_job(target, *args) @@ -1245,10 +1249,10 @@ class ServiceRegistry: self, handler: Service, service_call: ServiceCall ) -> None: """Execute a service.""" - if handler.is_callback: - handler.func(service_call) - elif handler.is_coroutinefunction: + if handler.is_coroutinefunction: await handler.func(service_call) + elif handler.is_callback: + handler.func(service_call) else: await self._hass.async_add_executor_job(handler.func, service_call) diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 46ebc467c0b..9085c929651 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -461,7 +461,9 @@ def async_register_admin_service( if not user.is_admin: raise Unauthorized(context=call.context) - await hass.async_add_job(service_func, call) + result = hass.async_add_job(service_func, call) + if result is not None: + await result hass.services.async_register(domain, service, admin_handler, schema) diff --git a/tests/test_core.py b/tests/test_core.py index aa0c615ec04..657bbeda7c6 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1180,3 +1180,28 @@ def test_context(): assert c.user_id == 23 assert c.parent_id == 100 assert c.id is not None + + +async def test_async_functions_with_callback(hass): + """Test we deal with async functions accidentally marked as callback.""" + runs = [] + + @ha.callback + async def test(): + runs.append(True) + + await hass.async_add_job(test) + assert len(runs) == 1 + + hass.async_run_job(test) + await hass.async_block_till_done() + assert len(runs) == 2 + + @ha.callback + async def service_handler(call): + runs.append(True) + + hass.services.async_register("test_domain", "test_service", service_handler) + + await hass.services.async_call("test_domain", "test_service", blocking=True) + assert len(runs) == 3