diff --git a/homeassistant/components/light/__init__.py b/homeassistant/components/light/__init__.py index 1dbd07f9439..5f3caff511a 100644 --- a/homeassistant/components/light/__init__.py +++ b/homeassistant/components/light/__init__.py @@ -13,6 +13,7 @@ import csv import voluptuous as vol from homeassistant.core import callback +from homeassistant.loader import bind_hass from homeassistant.components import group from homeassistant.config import load_yaml_config_file from homeassistant.const import ( @@ -165,6 +166,7 @@ def turn_on(hass, entity_id=None, transition=None, brightness=None, @callback +@bind_hass def async_turn_on(hass, entity_id=None, transition=None, brightness=None, brightness_pct=None, rgb_color=None, xy_color=None, color_temp=None, kelvin=None, white_value=None, diff --git a/homeassistant/core.py b/homeassistant/core.py index d1779fe420d..496bb018fbd 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -30,6 +30,7 @@ from homeassistant.const import ( EVENT_SERVICE_EXECUTED, EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL, EVENT_HOMEASSISTANT_CLOSE, EVENT_SERVICE_REMOVED, __version__) +from homeassistant.loader import Components from homeassistant.exceptions import ( HomeAssistantError, InvalidEntityFormatError) from homeassistant.util.async import ( @@ -128,6 +129,7 @@ class HomeAssistant(object): self.services = ServiceRegistry(self) self.states = StateMachine(self.bus, self.loop) self.config = Config() # type: Config + self.components = Components(self) # This is a dictionary that any component can store any data on. self.data = {} self.state = CoreState.not_running diff --git a/homeassistant/loader.py b/homeassistant/loader.py index 586988a3436..566cdd4fb15 100644 --- a/homeassistant/loader.py +++ b/homeassistant/loader.py @@ -10,6 +10,7 @@ call get_component('switch.your_platform'). In both cases the config directory is checked to see if it contains a user provided version. If not available it will check the built-in components and platforms. """ +import functools as ft import importlib import logging import os @@ -170,6 +171,49 @@ def get_component(comp_name) -> Optional[ModuleType]: return None +class Components: + """Helper to load components.""" + + def __init__(self, hass): + """Initialize the Components class.""" + self._hass = hass + + def __getattr__(self, comp_name): + """Fetch a component.""" + component = get_component(comp_name) + if component is None: + raise ImportError('Unable to load {}'.format(comp_name)) + wrapped = ComponentWrapper(self._hass, component) + setattr(self, comp_name, wrapped) + return wrapped + + +class ComponentWrapper: + """Class to wrap a component and auto fill in hass argument.""" + + def __init__(self, hass, component): + """Initialize the component wrapper.""" + self._hass = hass + self._component = component + + def __getattr__(self, attr): + """Fetch an attribute.""" + value = getattr(self._component, attr) + + if hasattr(value, '__bind_hass'): + value = ft.partial(value, self._hass) + + setattr(self, attr, value) + return value + + +def bind_hass(func): + """Decorator to indicate that first argument is hass.""" + # pylint: disable=protected-access + func.__bind_hass = True + return func + + def load_order_component(comp_name: str) -> OrderedSet: """Return an OrderedSet of components in the correct order of loading. diff --git a/tests/test_loader.py b/tests/test_loader.py index 0b3f9653faa..6081b061ed2 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -1,11 +1,15 @@ """Test to verify that we can load components.""" # pylint: disable=protected-access +import asyncio import unittest +import pytest + import homeassistant.loader as loader import homeassistant.components.http as http -from tests.common import get_test_home_assistant, MockModule +from tests.common import ( + get_test_home_assistant, MockModule, async_mock_service) class TestLoader(unittest.TestCase): @@ -54,3 +58,29 @@ class TestLoader(unittest.TestCase): # Try to get load order for non-existing component self.assertEqual([], loader.load_order_component('mod1')) + + +def test_component_loader(hass): + """Test loading components.""" + components = loader.Components(hass) + assert components.http.CONFIG_SCHEMA is http.CONFIG_SCHEMA + assert hass.components.http.CONFIG_SCHEMA is http.CONFIG_SCHEMA + + +def test_component_loader_non_existing(hass): + """Test loading components.""" + components = loader.Components(hass) + with pytest.raises(ImportError): + components.non_existing + + +@asyncio.coroutine +def test_component_wrapper(hass): + """Test component wrapper.""" + calls = async_mock_service(hass, 'light', 'turn_on') + + components = loader.Components(hass) + components.light.async_turn_on('light.test') + yield from hass.async_block_till_done() + + assert len(calls) == 1