diff --git a/homeassistant/components/tradfri/__init__.py b/homeassistant/components/tradfri/__init__.py index af63fe192cd..d3caaf54762 100644 --- a/homeassistant/components/tradfri/__init__.py +++ b/homeassistant/components/tradfri/__init__.py @@ -1,6 +1,5 @@ """Support for IKEA Tradfri.""" import asyncio -import logging from pytradfri import Gateway, RequestError from pytradfri.api.aiocoap_api import APIFactory @@ -31,9 +30,8 @@ from .const import ( PLATFORMS, ) -_LOGGER = logging.getLogger(__name__) - FACTORY = "tradfri_factory" +LISTENERS = "tradfri_listeners" CONFIG_SCHEMA = vol.Schema( { @@ -98,6 +96,8 @@ async def async_setup(hass, config): async def async_setup_entry(hass, entry): """Create a gateway.""" # host, identity, key, allow_tradfri_groups + tradfri_data = hass.data.setdefault(DOMAIN, {})[entry.entry_id] = {} + listeners = tradfri_data[LISTENERS] = [] factory = await APIFactory.init( entry.data[CONF_HOST], @@ -109,7 +109,7 @@ async def async_setup_entry(hass, entry): """Close connection when hass stops.""" await factory.shutdown() - hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, on_hass_stop) + listeners.append(hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, on_hass_stop)) api = factory.request gateway = Gateway() @@ -120,9 +120,8 @@ async def async_setup_entry(hass, entry): await factory.shutdown() raise ConfigEntryNotReady from err - hass.data.setdefault(KEY_API, {})[entry.entry_id] = api - hass.data.setdefault(KEY_GATEWAY, {})[entry.entry_id] = gateway - tradfri_data = hass.data.setdefault(DOMAIN, {})[entry.entry_id] = {} + tradfri_data[KEY_API] = api + tradfri_data[KEY_GATEWAY] = gateway tradfri_data[FACTORY] = factory dev_reg = await hass.helpers.device_registry.async_get_registry() @@ -156,10 +155,11 @@ async def async_unload_entry(hass, entry): ) ) if unload_ok: - hass.data[KEY_API].pop(entry.entry_id) - hass.data[KEY_GATEWAY].pop(entry.entry_id) tradfri_data = hass.data[DOMAIN].pop(entry.entry_id) factory = tradfri_data[FACTORY] await factory.shutdown() + # unsubscribe listeners + for listener in tradfri_data[LISTENERS]: + listener() return unload_ok diff --git a/homeassistant/components/tradfri/cover.py b/homeassistant/components/tradfri/cover.py index 6d8669eea91..cab7b6bbab7 100644 --- a/homeassistant/components/tradfri/cover.py +++ b/homeassistant/components/tradfri/cover.py @@ -3,14 +3,15 @@ from homeassistant.components.cover import ATTR_POSITION, CoverEntity from .base_class import TradfriBaseDevice -from .const import ATTR_MODEL, CONF_GATEWAY_ID, KEY_API, KEY_GATEWAY +from .const import ATTR_MODEL, CONF_GATEWAY_ID, DOMAIN, KEY_API, KEY_GATEWAY async def async_setup_entry(hass, config_entry, async_add_entities): """Load Tradfri covers based on a config entry.""" gateway_id = config_entry.data[CONF_GATEWAY_ID] - api = hass.data[KEY_API][config_entry.entry_id] - gateway = hass.data[KEY_GATEWAY][config_entry.entry_id] + tradfri_data = hass.data[DOMAIN][config_entry.entry_id] + api = tradfri_data[KEY_API] + gateway = tradfri_data[KEY_GATEWAY] devices_commands = await api(gateway.get_devices()) devices = await api(devices_commands) diff --git a/homeassistant/components/tradfri/light.py b/homeassistant/components/tradfri/light.py index 4e44b452d33..29e096b2c49 100644 --- a/homeassistant/components/tradfri/light.py +++ b/homeassistant/components/tradfri/light.py @@ -21,6 +21,7 @@ from .const import ( ATTR_TRANSITION_TIME, CONF_GATEWAY_ID, CONF_IMPORT_GROUPS, + DOMAIN, KEY_API, KEY_GATEWAY, SUPPORTED_GROUP_FEATURES, @@ -33,8 +34,9 @@ _LOGGER = logging.getLogger(__name__) async def async_setup_entry(hass, config_entry, async_add_entities): """Load Tradfri lights based on a config entry.""" gateway_id = config_entry.data[CONF_GATEWAY_ID] - api = hass.data[KEY_API][config_entry.entry_id] - gateway = hass.data[KEY_GATEWAY][config_entry.entry_id] + tradfri_data = hass.data[DOMAIN][config_entry.entry_id] + api = tradfri_data[KEY_API] + gateway = tradfri_data[KEY_GATEWAY] devices_commands = await api(gateway.get_devices()) devices = await api(devices_commands) diff --git a/homeassistant/components/tradfri/sensor.py b/homeassistant/components/tradfri/sensor.py index db12ab0a5cb..0cdd9152b4f 100644 --- a/homeassistant/components/tradfri/sensor.py +++ b/homeassistant/components/tradfri/sensor.py @@ -3,14 +3,15 @@ from homeassistant.const import DEVICE_CLASS_BATTERY, UNIT_PERCENTAGE from .base_class import TradfriBaseDevice -from .const import CONF_GATEWAY_ID, KEY_API, KEY_GATEWAY +from .const import CONF_GATEWAY_ID, DOMAIN, KEY_API, KEY_GATEWAY async def async_setup_entry(hass, config_entry, async_add_entities): """Set up a Tradfri config entry.""" gateway_id = config_entry.data[CONF_GATEWAY_ID] - api = hass.data[KEY_API][config_entry.entry_id] - gateway = hass.data[KEY_GATEWAY][config_entry.entry_id] + tradfri_data = hass.data[DOMAIN][config_entry.entry_id] + api = tradfri_data[KEY_API] + gateway = tradfri_data[KEY_GATEWAY] devices_commands = await api(gateway.get_devices()) all_devices = await api(devices_commands) diff --git a/homeassistant/components/tradfri/switch.py b/homeassistant/components/tradfri/switch.py index cf23ffeb445..5bc5e6ab8e8 100644 --- a/homeassistant/components/tradfri/switch.py +++ b/homeassistant/components/tradfri/switch.py @@ -2,14 +2,15 @@ from homeassistant.components.switch import SwitchEntity from .base_class import TradfriBaseDevice -from .const import CONF_GATEWAY_ID, KEY_API, KEY_GATEWAY +from .const import CONF_GATEWAY_ID, DOMAIN, KEY_API, KEY_GATEWAY async def async_setup_entry(hass, config_entry, async_add_entities): """Load Tradfri switches based on a config entry.""" gateway_id = config_entry.data[CONF_GATEWAY_ID] - api = hass.data[KEY_API][config_entry.entry_id] - gateway = hass.data[KEY_GATEWAY][config_entry.entry_id] + tradfri_data = hass.data[DOMAIN][config_entry.entry_id] + api = tradfri_data[KEY_API] + gateway = tradfri_data[KEY_GATEWAY] devices_commands = await api(gateway.get_devices()) devices = await api(devices_commands) diff --git a/tests/components/tradfri/__init__.py b/tests/components/tradfri/__init__.py index 4d1b505abc9..e7a6fcb9138 100644 --- a/tests/components/tradfri/__init__.py +++ b/tests/components/tradfri/__init__.py @@ -1 +1,2 @@ """Tests for the tradfri component.""" +MOCK_GATEWAY_ID = "mock-gateway-id" diff --git a/tests/components/tradfri/conftest.py b/tests/components/tradfri/conftest.py index 891dd1377fe..a944f836dea 100644 --- a/tests/components/tradfri/conftest.py +++ b/tests/components/tradfri/conftest.py @@ -1,7 +1,11 @@ """Common tradfri test fixtures.""" import pytest -from tests.async_mock import patch +from . import MOCK_GATEWAY_ID + +from tests.async_mock import Mock, patch + +# pylint: disable=protected-access @pytest.fixture @@ -9,8 +13,8 @@ def mock_gateway_info(): """Mock get_gateway_info.""" with patch( "homeassistant.components.tradfri.config_flow.get_gateway_info" - ) as mock_gateway: - yield mock_gateway + ) as gateway_info: + yield gateway_info @pytest.fixture @@ -19,3 +23,64 @@ def mock_entry_setup(): with patch("homeassistant.components.tradfri.async_setup_entry") as mock_setup: mock_setup.return_value = True yield mock_setup + + +@pytest.fixture(name="gateway_id") +def mock_gateway_id_fixture(): + """Return mock gateway_id.""" + return MOCK_GATEWAY_ID + + +@pytest.fixture(name="mock_gateway") +def mock_gateway_fixture(gateway_id): + """Mock a Tradfri gateway.""" + + def get_devices(): + """Return mock devices.""" + return gateway.mock_devices + + def get_groups(): + """Return mock groups.""" + return gateway.mock_groups + + gateway_info = Mock(id=gateway_id, firmware_version="1.2.1234") + + def get_gateway_info(): + """Return mock gateway info.""" + return gateway_info + + gateway = Mock( + get_devices=get_devices, + get_groups=get_groups, + get_gateway_info=get_gateway_info, + mock_devices=[], + mock_groups=[], + mock_responses=[], + ) + with patch("homeassistant.components.tradfri.Gateway", return_value=gateway), patch( + "homeassistant.components.tradfri.config_flow.Gateway", return_value=gateway + ): + yield gateway + + +@pytest.fixture(name="mock_api") +def mock_api_fixture(mock_gateway): + """Mock api.""" + + async def api(command): + """Mock api function.""" + # Store the data for "real" command objects. + if hasattr(command, "_data") and not isinstance(command, Mock): + mock_gateway.mock_responses.append(command._data) + return command + + return api + + +@pytest.fixture(name="api_factory") +def mock_api_factory_fixture(mock_api): + """Mock pytradfri api factory.""" + with patch("homeassistant.components.tradfri.APIFactory", autospec=True) as factory: + factory.init.return_value = factory.return_value + factory.return_value.request = mock_api + yield factory.return_value diff --git a/tests/components/tradfri/test_init.py b/tests/components/tradfri/test_init.py index 2845137244b..34cc6d38091 100644 --- a/tests/components/tradfri/test_init.py +++ b/tests/components/tradfri/test_init.py @@ -1,4 +1,9 @@ """Tests for Tradfri setup.""" +from homeassistant.components import tradfri +from homeassistant.helpers.device_registry import ( + async_entries_for_config_entry, + async_get_registry as async_get_device_registry, +) from homeassistant.setup import async_setup_component from tests.async_mock import patch @@ -48,13 +53,15 @@ async def test_config_json_host_not_imported(hass): assert len(mock_init.mock_calls) == 0 -async def test_config_json_host_imported(hass, mock_gateway_info, mock_entry_setup): +async def test_config_json_host_imported( + hass, mock_gateway_info, mock_entry_setup, gateway_id +): """Test that we import a configured host.""" mock_gateway_info.side_effect = lambda hass, host, identity, key: { "host": host, "identity": identity, "key": key, - "gateway_id": "mock-gateway", + "gateway_id": gateway_id, } with patch( @@ -68,3 +75,45 @@ async def test_config_json_host_imported(hass, mock_gateway_info, mock_entry_set assert config_entry.domain == "tradfri" assert config_entry.source == "import" assert config_entry.title == "mock-host" + + +async def test_entry_setup_unload(hass, api_factory, gateway_id): + """Test config entry setup and unload.""" + entry = MockConfigEntry( + domain=tradfri.DOMAIN, + data={ + tradfri.CONF_HOST: "mock-host", + tradfri.CONF_IDENTITY: "mock-identity", + tradfri.CONF_KEY: "mock-key", + tradfri.CONF_IMPORT_GROUPS: True, + tradfri.CONF_GATEWAY_ID: gateway_id, + }, + ) + + entry.add_to_hass(hass) + with patch.object( + hass.config_entries, "async_forward_entry_setup", return_value=True + ) as setup: + await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + assert setup.call_count == len(tradfri.PLATFORMS) + + dev_reg = await async_get_device_registry(hass) + dev_entries = async_entries_for_config_entry(dev_reg, entry.entry_id) + + assert dev_entries + dev_entry = dev_entries[0] + assert dev_entry.identifiers == { + (tradfri.DOMAIN, entry.data[tradfri.CONF_GATEWAY_ID]) + } + assert dev_entry.manufacturer == tradfri.ATTR_TRADFRI_MANUFACTURER + assert dev_entry.name == tradfri.ATTR_TRADFRI_GATEWAY + assert dev_entry.model == tradfri.ATTR_TRADFRI_GATEWAY_MODEL + + with patch.object( + hass.config_entries, "async_forward_entry_unload", return_value=True + ) as unload: + assert await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done() + assert unload.call_count == len(tradfri.PLATFORMS) + assert api_factory.shutdown.call_count == 1 diff --git a/tests/components/tradfri/test_light.py b/tests/components/tradfri/test_light.py index a5a5823fbf4..b4c209c1493 100644 --- a/tests/components/tradfri/test_light.py +++ b/tests/components/tradfri/test_light.py @@ -9,6 +9,8 @@ from pytradfri.device.light_control import LightControl from homeassistant.components import tradfri +from . import MOCK_GATEWAY_ID + from tests.async_mock import MagicMock, Mock, PropertyMock, patch from tests.common import MockConfigEntry @@ -93,42 +95,6 @@ def setup(request): request.addfinalizer(teardown) -@pytest.fixture -def mock_gateway(): - """Mock a Tradfri gateway.""" - - def get_devices(): - """Return mock devices.""" - return gateway.mock_devices - - def get_groups(): - """Return mock groups.""" - return gateway.mock_groups - - gateway = Mock( - get_devices=get_devices, - get_groups=get_groups, - mock_devices=[], - mock_groups=[], - mock_responses=[], - ) - return gateway - - -@pytest.fixture -def mock_api(mock_gateway): - """Mock api.""" - - async def api(command): - """Mock api function.""" - # Store the data for "real" command objects. - if hasattr(command, "_data") and not isinstance(command, Mock): - mock_gateway.mock_responses.append(command._data) - return command - - return api - - async def generate_psk(self, code): """Mock psk.""" return "mock" @@ -143,11 +109,14 @@ async def setup_gateway(hass, mock_gateway, mock_api): "identity": "mock-identity", "key": "mock-key", "import_groups": True, - "gateway_id": "mock-gateway-id", + "gateway_id": MOCK_GATEWAY_ID, }, ) - hass.data[tradfri.KEY_GATEWAY] = {entry.entry_id: mock_gateway} - hass.data[tradfri.KEY_API] = {entry.entry_id: mock_api} + tradfri_data = {} + hass.data[tradfri.DOMAIN] = {entry.entry_id: tradfri_data} + tradfri_data[tradfri.KEY_API] = mock_api + tradfri_data[tradfri.KEY_GATEWAY] = mock_gateway + await hass.config_entries.async_forward_entry_setup(entry, "light") await hass.async_block_till_done()