Clean tradfri hass data and add tests (#39620)

This commit is contained in:
Martin Hjelmare 2020-09-03 18:39:24 +02:00 committed by GitHub
parent d128443a2a
commit bde0bdbf80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 153 additions and 64 deletions

View File

@ -1,6 +1,5 @@
"""Support for IKEA Tradfri.""" """Support for IKEA Tradfri."""
import asyncio import asyncio
import logging
from pytradfri import Gateway, RequestError from pytradfri import Gateway, RequestError
from pytradfri.api.aiocoap_api import APIFactory from pytradfri.api.aiocoap_api import APIFactory
@ -31,9 +30,8 @@ from .const import (
PLATFORMS, PLATFORMS,
) )
_LOGGER = logging.getLogger(__name__)
FACTORY = "tradfri_factory" FACTORY = "tradfri_factory"
LISTENERS = "tradfri_listeners"
CONFIG_SCHEMA = vol.Schema( CONFIG_SCHEMA = vol.Schema(
{ {
@ -98,6 +96,8 @@ async def async_setup(hass, config):
async def async_setup_entry(hass, entry): async def async_setup_entry(hass, entry):
"""Create a gateway.""" """Create a gateway."""
# host, identity, key, allow_tradfri_groups # host, identity, key, allow_tradfri_groups
tradfri_data = hass.data.setdefault(DOMAIN, {})[entry.entry_id] = {}
listeners = tradfri_data[LISTENERS] = []
factory = await APIFactory.init( factory = await APIFactory.init(
entry.data[CONF_HOST], entry.data[CONF_HOST],
@ -109,7 +109,7 @@ async def async_setup_entry(hass, entry):
"""Close connection when hass stops.""" """Close connection when hass stops."""
await factory.shutdown() await factory.shutdown()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, on_hass_stop) listeners.append(hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, on_hass_stop))
api = factory.request api = factory.request
gateway = Gateway() gateway = Gateway()
@ -120,9 +120,8 @@ async def async_setup_entry(hass, entry):
await factory.shutdown() await factory.shutdown()
raise ConfigEntryNotReady from err raise ConfigEntryNotReady from err
hass.data.setdefault(KEY_API, {})[entry.entry_id] = api tradfri_data[KEY_API] = api
hass.data.setdefault(KEY_GATEWAY, {})[entry.entry_id] = gateway tradfri_data[KEY_GATEWAY] = gateway
tradfri_data = hass.data.setdefault(DOMAIN, {})[entry.entry_id] = {}
tradfri_data[FACTORY] = factory tradfri_data[FACTORY] = factory
dev_reg = await hass.helpers.device_registry.async_get_registry() dev_reg = await hass.helpers.device_registry.async_get_registry()
@ -156,10 +155,11 @@ async def async_unload_entry(hass, entry):
) )
) )
if unload_ok: if unload_ok:
hass.data[KEY_API].pop(entry.entry_id)
hass.data[KEY_GATEWAY].pop(entry.entry_id)
tradfri_data = hass.data[DOMAIN].pop(entry.entry_id) tradfri_data = hass.data[DOMAIN].pop(entry.entry_id)
factory = tradfri_data[FACTORY] factory = tradfri_data[FACTORY]
await factory.shutdown() await factory.shutdown()
# unsubscribe listeners
for listener in tradfri_data[LISTENERS]:
listener()
return unload_ok return unload_ok

View File

@ -3,14 +3,15 @@
from homeassistant.components.cover import ATTR_POSITION, CoverEntity from homeassistant.components.cover import ATTR_POSITION, CoverEntity
from .base_class import TradfriBaseDevice from .base_class import TradfriBaseDevice
from .const import ATTR_MODEL, CONF_GATEWAY_ID, KEY_API, KEY_GATEWAY from .const import ATTR_MODEL, CONF_GATEWAY_ID, DOMAIN, KEY_API, KEY_GATEWAY
async def async_setup_entry(hass, config_entry, async_add_entities): async def async_setup_entry(hass, config_entry, async_add_entities):
"""Load Tradfri covers based on a config entry.""" """Load Tradfri covers based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID] gateway_id = config_entry.data[CONF_GATEWAY_ID]
api = hass.data[KEY_API][config_entry.entry_id] tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
gateway = hass.data[KEY_GATEWAY][config_entry.entry_id] api = tradfri_data[KEY_API]
gateway = tradfri_data[KEY_GATEWAY]
devices_commands = await api(gateway.get_devices()) devices_commands = await api(gateway.get_devices())
devices = await api(devices_commands) devices = await api(devices_commands)

View File

@ -21,6 +21,7 @@ from .const import (
ATTR_TRANSITION_TIME, ATTR_TRANSITION_TIME,
CONF_GATEWAY_ID, CONF_GATEWAY_ID,
CONF_IMPORT_GROUPS, CONF_IMPORT_GROUPS,
DOMAIN,
KEY_API, KEY_API,
KEY_GATEWAY, KEY_GATEWAY,
SUPPORTED_GROUP_FEATURES, SUPPORTED_GROUP_FEATURES,
@ -33,8 +34,9 @@ _LOGGER = logging.getLogger(__name__)
async def async_setup_entry(hass, config_entry, async_add_entities): async def async_setup_entry(hass, config_entry, async_add_entities):
"""Load Tradfri lights based on a config entry.""" """Load Tradfri lights based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID] gateway_id = config_entry.data[CONF_GATEWAY_ID]
api = hass.data[KEY_API][config_entry.entry_id] tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
gateway = hass.data[KEY_GATEWAY][config_entry.entry_id] api = tradfri_data[KEY_API]
gateway = tradfri_data[KEY_GATEWAY]
devices_commands = await api(gateway.get_devices()) devices_commands = await api(gateway.get_devices())
devices = await api(devices_commands) devices = await api(devices_commands)

View File

@ -3,14 +3,15 @@
from homeassistant.const import DEVICE_CLASS_BATTERY, UNIT_PERCENTAGE from homeassistant.const import DEVICE_CLASS_BATTERY, UNIT_PERCENTAGE
from .base_class import TradfriBaseDevice from .base_class import TradfriBaseDevice
from .const import CONF_GATEWAY_ID, KEY_API, KEY_GATEWAY from .const import CONF_GATEWAY_ID, DOMAIN, KEY_API, KEY_GATEWAY
async def async_setup_entry(hass, config_entry, async_add_entities): async def async_setup_entry(hass, config_entry, async_add_entities):
"""Set up a Tradfri config entry.""" """Set up a Tradfri config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID] gateway_id = config_entry.data[CONF_GATEWAY_ID]
api = hass.data[KEY_API][config_entry.entry_id] tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
gateway = hass.data[KEY_GATEWAY][config_entry.entry_id] api = tradfri_data[KEY_API]
gateway = tradfri_data[KEY_GATEWAY]
devices_commands = await api(gateway.get_devices()) devices_commands = await api(gateway.get_devices())
all_devices = await api(devices_commands) all_devices = await api(devices_commands)

View File

@ -2,14 +2,15 @@
from homeassistant.components.switch import SwitchEntity from homeassistant.components.switch import SwitchEntity
from .base_class import TradfriBaseDevice from .base_class import TradfriBaseDevice
from .const import CONF_GATEWAY_ID, KEY_API, KEY_GATEWAY from .const import CONF_GATEWAY_ID, DOMAIN, KEY_API, KEY_GATEWAY
async def async_setup_entry(hass, config_entry, async_add_entities): async def async_setup_entry(hass, config_entry, async_add_entities):
"""Load Tradfri switches based on a config entry.""" """Load Tradfri switches based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID] gateway_id = config_entry.data[CONF_GATEWAY_ID]
api = hass.data[KEY_API][config_entry.entry_id] tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
gateway = hass.data[KEY_GATEWAY][config_entry.entry_id] api = tradfri_data[KEY_API]
gateway = tradfri_data[KEY_GATEWAY]
devices_commands = await api(gateway.get_devices()) devices_commands = await api(gateway.get_devices())
devices = await api(devices_commands) devices = await api(devices_commands)

View File

@ -1 +1,2 @@
"""Tests for the tradfri component.""" """Tests for the tradfri component."""
MOCK_GATEWAY_ID = "mock-gateway-id"

View File

@ -1,7 +1,11 @@
"""Common tradfri test fixtures.""" """Common tradfri test fixtures."""
import pytest import pytest
from tests.async_mock import patch from . import MOCK_GATEWAY_ID
from tests.async_mock import Mock, patch
# pylint: disable=protected-access
@pytest.fixture @pytest.fixture
@ -9,8 +13,8 @@ def mock_gateway_info():
"""Mock get_gateway_info.""" """Mock get_gateway_info."""
with patch( with patch(
"homeassistant.components.tradfri.config_flow.get_gateway_info" "homeassistant.components.tradfri.config_flow.get_gateway_info"
) as mock_gateway: ) as gateway_info:
yield mock_gateway yield gateway_info
@pytest.fixture @pytest.fixture
@ -19,3 +23,64 @@ def mock_entry_setup():
with patch("homeassistant.components.tradfri.async_setup_entry") as mock_setup: with patch("homeassistant.components.tradfri.async_setup_entry") as mock_setup:
mock_setup.return_value = True mock_setup.return_value = True
yield mock_setup yield mock_setup
@pytest.fixture(name="gateway_id")
def mock_gateway_id_fixture():
"""Return mock gateway_id."""
return MOCK_GATEWAY_ID
@pytest.fixture(name="mock_gateway")
def mock_gateway_fixture(gateway_id):
"""Mock a Tradfri gateway."""
def get_devices():
"""Return mock devices."""
return gateway.mock_devices
def get_groups():
"""Return mock groups."""
return gateway.mock_groups
gateway_info = Mock(id=gateway_id, firmware_version="1.2.1234")
def get_gateway_info():
"""Return mock gateway info."""
return gateway_info
gateway = Mock(
get_devices=get_devices,
get_groups=get_groups,
get_gateway_info=get_gateway_info,
mock_devices=[],
mock_groups=[],
mock_responses=[],
)
with patch("homeassistant.components.tradfri.Gateway", return_value=gateway), patch(
"homeassistant.components.tradfri.config_flow.Gateway", return_value=gateway
):
yield gateway
@pytest.fixture(name="mock_api")
def mock_api_fixture(mock_gateway):
"""Mock api."""
async def api(command):
"""Mock api function."""
# Store the data for "real" command objects.
if hasattr(command, "_data") and not isinstance(command, Mock):
mock_gateway.mock_responses.append(command._data)
return command
return api
@pytest.fixture(name="api_factory")
def mock_api_factory_fixture(mock_api):
"""Mock pytradfri api factory."""
with patch("homeassistant.components.tradfri.APIFactory", autospec=True) as factory:
factory.init.return_value = factory.return_value
factory.return_value.request = mock_api
yield factory.return_value

View File

@ -1,4 +1,9 @@
"""Tests for Tradfri setup.""" """Tests for Tradfri setup."""
from homeassistant.components import tradfri
from homeassistant.helpers.device_registry import (
async_entries_for_config_entry,
async_get_registry as async_get_device_registry,
)
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.async_mock import patch from tests.async_mock import patch
@ -48,13 +53,15 @@ async def test_config_json_host_not_imported(hass):
assert len(mock_init.mock_calls) == 0 assert len(mock_init.mock_calls) == 0
async def test_config_json_host_imported(hass, mock_gateway_info, mock_entry_setup): async def test_config_json_host_imported(
hass, mock_gateway_info, mock_entry_setup, gateway_id
):
"""Test that we import a configured host.""" """Test that we import a configured host."""
mock_gateway_info.side_effect = lambda hass, host, identity, key: { mock_gateway_info.side_effect = lambda hass, host, identity, key: {
"host": host, "host": host,
"identity": identity, "identity": identity,
"key": key, "key": key,
"gateway_id": "mock-gateway", "gateway_id": gateway_id,
} }
with patch( with patch(
@ -68,3 +75,45 @@ async def test_config_json_host_imported(hass, mock_gateway_info, mock_entry_set
assert config_entry.domain == "tradfri" assert config_entry.domain == "tradfri"
assert config_entry.source == "import" assert config_entry.source == "import"
assert config_entry.title == "mock-host" assert config_entry.title == "mock-host"
async def test_entry_setup_unload(hass, api_factory, gateway_id):
"""Test config entry setup and unload."""
entry = MockConfigEntry(
domain=tradfri.DOMAIN,
data={
tradfri.CONF_HOST: "mock-host",
tradfri.CONF_IDENTITY: "mock-identity",
tradfri.CONF_KEY: "mock-key",
tradfri.CONF_IMPORT_GROUPS: True,
tradfri.CONF_GATEWAY_ID: gateway_id,
},
)
entry.add_to_hass(hass)
with patch.object(
hass.config_entries, "async_forward_entry_setup", return_value=True
) as setup:
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
assert setup.call_count == len(tradfri.PLATFORMS)
dev_reg = await async_get_device_registry(hass)
dev_entries = async_entries_for_config_entry(dev_reg, entry.entry_id)
assert dev_entries
dev_entry = dev_entries[0]
assert dev_entry.identifiers == {
(tradfri.DOMAIN, entry.data[tradfri.CONF_GATEWAY_ID])
}
assert dev_entry.manufacturer == tradfri.ATTR_TRADFRI_MANUFACTURER
assert dev_entry.name == tradfri.ATTR_TRADFRI_GATEWAY
assert dev_entry.model == tradfri.ATTR_TRADFRI_GATEWAY_MODEL
with patch.object(
hass.config_entries, "async_forward_entry_unload", return_value=True
) as unload:
assert await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
assert unload.call_count == len(tradfri.PLATFORMS)
assert api_factory.shutdown.call_count == 1

View File

@ -9,6 +9,8 @@ from pytradfri.device.light_control import LightControl
from homeassistant.components import tradfri from homeassistant.components import tradfri
from . import MOCK_GATEWAY_ID
from tests.async_mock import MagicMock, Mock, PropertyMock, patch from tests.async_mock import MagicMock, Mock, PropertyMock, patch
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -93,42 +95,6 @@ def setup(request):
request.addfinalizer(teardown) request.addfinalizer(teardown)
@pytest.fixture
def mock_gateway():
"""Mock a Tradfri gateway."""
def get_devices():
"""Return mock devices."""
return gateway.mock_devices
def get_groups():
"""Return mock groups."""
return gateway.mock_groups
gateway = Mock(
get_devices=get_devices,
get_groups=get_groups,
mock_devices=[],
mock_groups=[],
mock_responses=[],
)
return gateway
@pytest.fixture
def mock_api(mock_gateway):
"""Mock api."""
async def api(command):
"""Mock api function."""
# Store the data for "real" command objects.
if hasattr(command, "_data") and not isinstance(command, Mock):
mock_gateway.mock_responses.append(command._data)
return command
return api
async def generate_psk(self, code): async def generate_psk(self, code):
"""Mock psk.""" """Mock psk."""
return "mock" return "mock"
@ -143,11 +109,14 @@ async def setup_gateway(hass, mock_gateway, mock_api):
"identity": "mock-identity", "identity": "mock-identity",
"key": "mock-key", "key": "mock-key",
"import_groups": True, "import_groups": True,
"gateway_id": "mock-gateway-id", "gateway_id": MOCK_GATEWAY_ID,
}, },
) )
hass.data[tradfri.KEY_GATEWAY] = {entry.entry_id: mock_gateway} tradfri_data = {}
hass.data[tradfri.KEY_API] = {entry.entry_id: mock_api} hass.data[tradfri.DOMAIN] = {entry.entry_id: tradfri_data}
tradfri_data[tradfri.KEY_API] = mock_api
tradfri_data[tradfri.KEY_GATEWAY] = mock_gateway
await hass.config_entries.async_forward_entry_setup(entry, "light") await hass.config_entries.async_forward_entry_setup(entry, "light")
await hass.async_block_till_done() await hass.async_block_till_done()