mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 11:17:21 +00:00
Clean tradfri hass data and add tests (#39620)
This commit is contained in:
parent
d128443a2a
commit
bde0bdbf80
@ -1,6 +1,5 @@
|
|||||||
"""Support for IKEA Tradfri."""
|
"""Support for IKEA Tradfri."""
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
|
||||||
|
|
||||||
from pytradfri import Gateway, RequestError
|
from pytradfri import Gateway, RequestError
|
||||||
from pytradfri.api.aiocoap_api import APIFactory
|
from pytradfri.api.aiocoap_api import APIFactory
|
||||||
@ -31,9 +30,8 @@ from .const import (
|
|||||||
PLATFORMS,
|
PLATFORMS,
|
||||||
)
|
)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
FACTORY = "tradfri_factory"
|
FACTORY = "tradfri_factory"
|
||||||
|
LISTENERS = "tradfri_listeners"
|
||||||
|
|
||||||
CONFIG_SCHEMA = vol.Schema(
|
CONFIG_SCHEMA = vol.Schema(
|
||||||
{
|
{
|
||||||
@ -98,6 +96,8 @@ async def async_setup(hass, config):
|
|||||||
async def async_setup_entry(hass, entry):
|
async def async_setup_entry(hass, entry):
|
||||||
"""Create a gateway."""
|
"""Create a gateway."""
|
||||||
# host, identity, key, allow_tradfri_groups
|
# host, identity, key, allow_tradfri_groups
|
||||||
|
tradfri_data = hass.data.setdefault(DOMAIN, {})[entry.entry_id] = {}
|
||||||
|
listeners = tradfri_data[LISTENERS] = []
|
||||||
|
|
||||||
factory = await APIFactory.init(
|
factory = await APIFactory.init(
|
||||||
entry.data[CONF_HOST],
|
entry.data[CONF_HOST],
|
||||||
@ -109,7 +109,7 @@ async def async_setup_entry(hass, entry):
|
|||||||
"""Close connection when hass stops."""
|
"""Close connection when hass stops."""
|
||||||
await factory.shutdown()
|
await factory.shutdown()
|
||||||
|
|
||||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, on_hass_stop)
|
listeners.append(hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, on_hass_stop))
|
||||||
|
|
||||||
api = factory.request
|
api = factory.request
|
||||||
gateway = Gateway()
|
gateway = Gateway()
|
||||||
@ -120,9 +120,8 @@ async def async_setup_entry(hass, entry):
|
|||||||
await factory.shutdown()
|
await factory.shutdown()
|
||||||
raise ConfigEntryNotReady from err
|
raise ConfigEntryNotReady from err
|
||||||
|
|
||||||
hass.data.setdefault(KEY_API, {})[entry.entry_id] = api
|
tradfri_data[KEY_API] = api
|
||||||
hass.data.setdefault(KEY_GATEWAY, {})[entry.entry_id] = gateway
|
tradfri_data[KEY_GATEWAY] = gateway
|
||||||
tradfri_data = hass.data.setdefault(DOMAIN, {})[entry.entry_id] = {}
|
|
||||||
tradfri_data[FACTORY] = factory
|
tradfri_data[FACTORY] = factory
|
||||||
|
|
||||||
dev_reg = await hass.helpers.device_registry.async_get_registry()
|
dev_reg = await hass.helpers.device_registry.async_get_registry()
|
||||||
@ -156,10 +155,11 @@ async def async_unload_entry(hass, entry):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
if unload_ok:
|
if unload_ok:
|
||||||
hass.data[KEY_API].pop(entry.entry_id)
|
|
||||||
hass.data[KEY_GATEWAY].pop(entry.entry_id)
|
|
||||||
tradfri_data = hass.data[DOMAIN].pop(entry.entry_id)
|
tradfri_data = hass.data[DOMAIN].pop(entry.entry_id)
|
||||||
factory = tradfri_data[FACTORY]
|
factory = tradfri_data[FACTORY]
|
||||||
await factory.shutdown()
|
await factory.shutdown()
|
||||||
|
# unsubscribe listeners
|
||||||
|
for listener in tradfri_data[LISTENERS]:
|
||||||
|
listener()
|
||||||
|
|
||||||
return unload_ok
|
return unload_ok
|
||||||
|
@ -3,14 +3,15 @@
|
|||||||
from homeassistant.components.cover import ATTR_POSITION, CoverEntity
|
from homeassistant.components.cover import ATTR_POSITION, CoverEntity
|
||||||
|
|
||||||
from .base_class import TradfriBaseDevice
|
from .base_class import TradfriBaseDevice
|
||||||
from .const import ATTR_MODEL, CONF_GATEWAY_ID, KEY_API, KEY_GATEWAY
|
from .const import ATTR_MODEL, CONF_GATEWAY_ID, DOMAIN, KEY_API, KEY_GATEWAY
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||||
"""Load Tradfri covers based on a config entry."""
|
"""Load Tradfri covers based on a config entry."""
|
||||||
gateway_id = config_entry.data[CONF_GATEWAY_ID]
|
gateway_id = config_entry.data[CONF_GATEWAY_ID]
|
||||||
api = hass.data[KEY_API][config_entry.entry_id]
|
tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
|
||||||
gateway = hass.data[KEY_GATEWAY][config_entry.entry_id]
|
api = tradfri_data[KEY_API]
|
||||||
|
gateway = tradfri_data[KEY_GATEWAY]
|
||||||
|
|
||||||
devices_commands = await api(gateway.get_devices())
|
devices_commands = await api(gateway.get_devices())
|
||||||
devices = await api(devices_commands)
|
devices = await api(devices_commands)
|
||||||
|
@ -21,6 +21,7 @@ from .const import (
|
|||||||
ATTR_TRANSITION_TIME,
|
ATTR_TRANSITION_TIME,
|
||||||
CONF_GATEWAY_ID,
|
CONF_GATEWAY_ID,
|
||||||
CONF_IMPORT_GROUPS,
|
CONF_IMPORT_GROUPS,
|
||||||
|
DOMAIN,
|
||||||
KEY_API,
|
KEY_API,
|
||||||
KEY_GATEWAY,
|
KEY_GATEWAY,
|
||||||
SUPPORTED_GROUP_FEATURES,
|
SUPPORTED_GROUP_FEATURES,
|
||||||
@ -33,8 +34,9 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||||
"""Load Tradfri lights based on a config entry."""
|
"""Load Tradfri lights based on a config entry."""
|
||||||
gateway_id = config_entry.data[CONF_GATEWAY_ID]
|
gateway_id = config_entry.data[CONF_GATEWAY_ID]
|
||||||
api = hass.data[KEY_API][config_entry.entry_id]
|
tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
|
||||||
gateway = hass.data[KEY_GATEWAY][config_entry.entry_id]
|
api = tradfri_data[KEY_API]
|
||||||
|
gateway = tradfri_data[KEY_GATEWAY]
|
||||||
|
|
||||||
devices_commands = await api(gateway.get_devices())
|
devices_commands = await api(gateway.get_devices())
|
||||||
devices = await api(devices_commands)
|
devices = await api(devices_commands)
|
||||||
|
@ -3,14 +3,15 @@
|
|||||||
from homeassistant.const import DEVICE_CLASS_BATTERY, UNIT_PERCENTAGE
|
from homeassistant.const import DEVICE_CLASS_BATTERY, UNIT_PERCENTAGE
|
||||||
|
|
||||||
from .base_class import TradfriBaseDevice
|
from .base_class import TradfriBaseDevice
|
||||||
from .const import CONF_GATEWAY_ID, KEY_API, KEY_GATEWAY
|
from .const import CONF_GATEWAY_ID, DOMAIN, KEY_API, KEY_GATEWAY
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||||
"""Set up a Tradfri config entry."""
|
"""Set up a Tradfri config entry."""
|
||||||
gateway_id = config_entry.data[CONF_GATEWAY_ID]
|
gateway_id = config_entry.data[CONF_GATEWAY_ID]
|
||||||
api = hass.data[KEY_API][config_entry.entry_id]
|
tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
|
||||||
gateway = hass.data[KEY_GATEWAY][config_entry.entry_id]
|
api = tradfri_data[KEY_API]
|
||||||
|
gateway = tradfri_data[KEY_GATEWAY]
|
||||||
|
|
||||||
devices_commands = await api(gateway.get_devices())
|
devices_commands = await api(gateway.get_devices())
|
||||||
all_devices = await api(devices_commands)
|
all_devices = await api(devices_commands)
|
||||||
|
@ -2,14 +2,15 @@
|
|||||||
from homeassistant.components.switch import SwitchEntity
|
from homeassistant.components.switch import SwitchEntity
|
||||||
|
|
||||||
from .base_class import TradfriBaseDevice
|
from .base_class import TradfriBaseDevice
|
||||||
from .const import CONF_GATEWAY_ID, KEY_API, KEY_GATEWAY
|
from .const import CONF_GATEWAY_ID, DOMAIN, KEY_API, KEY_GATEWAY
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||||
"""Load Tradfri switches based on a config entry."""
|
"""Load Tradfri switches based on a config entry."""
|
||||||
gateway_id = config_entry.data[CONF_GATEWAY_ID]
|
gateway_id = config_entry.data[CONF_GATEWAY_ID]
|
||||||
api = hass.data[KEY_API][config_entry.entry_id]
|
tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
|
||||||
gateway = hass.data[KEY_GATEWAY][config_entry.entry_id]
|
api = tradfri_data[KEY_API]
|
||||||
|
gateway = tradfri_data[KEY_GATEWAY]
|
||||||
|
|
||||||
devices_commands = await api(gateway.get_devices())
|
devices_commands = await api(gateway.get_devices())
|
||||||
devices = await api(devices_commands)
|
devices = await api(devices_commands)
|
||||||
|
@ -1 +1,2 @@
|
|||||||
"""Tests for the tradfri component."""
|
"""Tests for the tradfri component."""
|
||||||
|
MOCK_GATEWAY_ID = "mock-gateway-id"
|
||||||
|
@ -1,7 +1,11 @@
|
|||||||
"""Common tradfri test fixtures."""
|
"""Common tradfri test fixtures."""
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests.async_mock import patch
|
from . import MOCK_GATEWAY_ID
|
||||||
|
|
||||||
|
from tests.async_mock import Mock, patch
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -9,8 +13,8 @@ def mock_gateway_info():
|
|||||||
"""Mock get_gateway_info."""
|
"""Mock get_gateway_info."""
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.tradfri.config_flow.get_gateway_info"
|
"homeassistant.components.tradfri.config_flow.get_gateway_info"
|
||||||
) as mock_gateway:
|
) as gateway_info:
|
||||||
yield mock_gateway
|
yield gateway_info
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -19,3 +23,64 @@ def mock_entry_setup():
|
|||||||
with patch("homeassistant.components.tradfri.async_setup_entry") as mock_setup:
|
with patch("homeassistant.components.tradfri.async_setup_entry") as mock_setup:
|
||||||
mock_setup.return_value = True
|
mock_setup.return_value = True
|
||||||
yield mock_setup
|
yield mock_setup
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="gateway_id")
|
||||||
|
def mock_gateway_id_fixture():
|
||||||
|
"""Return mock gateway_id."""
|
||||||
|
return MOCK_GATEWAY_ID
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="mock_gateway")
|
||||||
|
def mock_gateway_fixture(gateway_id):
|
||||||
|
"""Mock a Tradfri gateway."""
|
||||||
|
|
||||||
|
def get_devices():
|
||||||
|
"""Return mock devices."""
|
||||||
|
return gateway.mock_devices
|
||||||
|
|
||||||
|
def get_groups():
|
||||||
|
"""Return mock groups."""
|
||||||
|
return gateway.mock_groups
|
||||||
|
|
||||||
|
gateway_info = Mock(id=gateway_id, firmware_version="1.2.1234")
|
||||||
|
|
||||||
|
def get_gateway_info():
|
||||||
|
"""Return mock gateway info."""
|
||||||
|
return gateway_info
|
||||||
|
|
||||||
|
gateway = Mock(
|
||||||
|
get_devices=get_devices,
|
||||||
|
get_groups=get_groups,
|
||||||
|
get_gateway_info=get_gateway_info,
|
||||||
|
mock_devices=[],
|
||||||
|
mock_groups=[],
|
||||||
|
mock_responses=[],
|
||||||
|
)
|
||||||
|
with patch("homeassistant.components.tradfri.Gateway", return_value=gateway), patch(
|
||||||
|
"homeassistant.components.tradfri.config_flow.Gateway", return_value=gateway
|
||||||
|
):
|
||||||
|
yield gateway
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="mock_api")
|
||||||
|
def mock_api_fixture(mock_gateway):
|
||||||
|
"""Mock api."""
|
||||||
|
|
||||||
|
async def api(command):
|
||||||
|
"""Mock api function."""
|
||||||
|
# Store the data for "real" command objects.
|
||||||
|
if hasattr(command, "_data") and not isinstance(command, Mock):
|
||||||
|
mock_gateway.mock_responses.append(command._data)
|
||||||
|
return command
|
||||||
|
|
||||||
|
return api
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="api_factory")
|
||||||
|
def mock_api_factory_fixture(mock_api):
|
||||||
|
"""Mock pytradfri api factory."""
|
||||||
|
with patch("homeassistant.components.tradfri.APIFactory", autospec=True) as factory:
|
||||||
|
factory.init.return_value = factory.return_value
|
||||||
|
factory.return_value.request = mock_api
|
||||||
|
yield factory.return_value
|
||||||
|
@ -1,4 +1,9 @@
|
|||||||
"""Tests for Tradfri setup."""
|
"""Tests for Tradfri setup."""
|
||||||
|
from homeassistant.components import tradfri
|
||||||
|
from homeassistant.helpers.device_registry import (
|
||||||
|
async_entries_for_config_entry,
|
||||||
|
async_get_registry as async_get_device_registry,
|
||||||
|
)
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.async_mock import patch
|
from tests.async_mock import patch
|
||||||
@ -48,13 +53,15 @@ async def test_config_json_host_not_imported(hass):
|
|||||||
assert len(mock_init.mock_calls) == 0
|
assert len(mock_init.mock_calls) == 0
|
||||||
|
|
||||||
|
|
||||||
async def test_config_json_host_imported(hass, mock_gateway_info, mock_entry_setup):
|
async def test_config_json_host_imported(
|
||||||
|
hass, mock_gateway_info, mock_entry_setup, gateway_id
|
||||||
|
):
|
||||||
"""Test that we import a configured host."""
|
"""Test that we import a configured host."""
|
||||||
mock_gateway_info.side_effect = lambda hass, host, identity, key: {
|
mock_gateway_info.side_effect = lambda hass, host, identity, key: {
|
||||||
"host": host,
|
"host": host,
|
||||||
"identity": identity,
|
"identity": identity,
|
||||||
"key": key,
|
"key": key,
|
||||||
"gateway_id": "mock-gateway",
|
"gateway_id": gateway_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
@ -68,3 +75,45 @@ async def test_config_json_host_imported(hass, mock_gateway_info, mock_entry_set
|
|||||||
assert config_entry.domain == "tradfri"
|
assert config_entry.domain == "tradfri"
|
||||||
assert config_entry.source == "import"
|
assert config_entry.source == "import"
|
||||||
assert config_entry.title == "mock-host"
|
assert config_entry.title == "mock-host"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_entry_setup_unload(hass, api_factory, gateway_id):
|
||||||
|
"""Test config entry setup and unload."""
|
||||||
|
entry = MockConfigEntry(
|
||||||
|
domain=tradfri.DOMAIN,
|
||||||
|
data={
|
||||||
|
tradfri.CONF_HOST: "mock-host",
|
||||||
|
tradfri.CONF_IDENTITY: "mock-identity",
|
||||||
|
tradfri.CONF_KEY: "mock-key",
|
||||||
|
tradfri.CONF_IMPORT_GROUPS: True,
|
||||||
|
tradfri.CONF_GATEWAY_ID: gateway_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
with patch.object(
|
||||||
|
hass.config_entries, "async_forward_entry_setup", return_value=True
|
||||||
|
) as setup:
|
||||||
|
await hass.config_entries.async_setup(entry.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert setup.call_count == len(tradfri.PLATFORMS)
|
||||||
|
|
||||||
|
dev_reg = await async_get_device_registry(hass)
|
||||||
|
dev_entries = async_entries_for_config_entry(dev_reg, entry.entry_id)
|
||||||
|
|
||||||
|
assert dev_entries
|
||||||
|
dev_entry = dev_entries[0]
|
||||||
|
assert dev_entry.identifiers == {
|
||||||
|
(tradfri.DOMAIN, entry.data[tradfri.CONF_GATEWAY_ID])
|
||||||
|
}
|
||||||
|
assert dev_entry.manufacturer == tradfri.ATTR_TRADFRI_MANUFACTURER
|
||||||
|
assert dev_entry.name == tradfri.ATTR_TRADFRI_GATEWAY
|
||||||
|
assert dev_entry.model == tradfri.ATTR_TRADFRI_GATEWAY_MODEL
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
hass.config_entries, "async_forward_entry_unload", return_value=True
|
||||||
|
) as unload:
|
||||||
|
assert await hass.config_entries.async_unload(entry.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert unload.call_count == len(tradfri.PLATFORMS)
|
||||||
|
assert api_factory.shutdown.call_count == 1
|
||||||
|
@ -9,6 +9,8 @@ from pytradfri.device.light_control import LightControl
|
|||||||
|
|
||||||
from homeassistant.components import tradfri
|
from homeassistant.components import tradfri
|
||||||
|
|
||||||
|
from . import MOCK_GATEWAY_ID
|
||||||
|
|
||||||
from tests.async_mock import MagicMock, Mock, PropertyMock, patch
|
from tests.async_mock import MagicMock, Mock, PropertyMock, patch
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
@ -93,42 +95,6 @@ def setup(request):
|
|||||||
request.addfinalizer(teardown)
|
request.addfinalizer(teardown)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_gateway():
|
|
||||||
"""Mock a Tradfri gateway."""
|
|
||||||
|
|
||||||
def get_devices():
|
|
||||||
"""Return mock devices."""
|
|
||||||
return gateway.mock_devices
|
|
||||||
|
|
||||||
def get_groups():
|
|
||||||
"""Return mock groups."""
|
|
||||||
return gateway.mock_groups
|
|
||||||
|
|
||||||
gateway = Mock(
|
|
||||||
get_devices=get_devices,
|
|
||||||
get_groups=get_groups,
|
|
||||||
mock_devices=[],
|
|
||||||
mock_groups=[],
|
|
||||||
mock_responses=[],
|
|
||||||
)
|
|
||||||
return gateway
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_api(mock_gateway):
|
|
||||||
"""Mock api."""
|
|
||||||
|
|
||||||
async def api(command):
|
|
||||||
"""Mock api function."""
|
|
||||||
# Store the data for "real" command objects.
|
|
||||||
if hasattr(command, "_data") and not isinstance(command, Mock):
|
|
||||||
mock_gateway.mock_responses.append(command._data)
|
|
||||||
return command
|
|
||||||
|
|
||||||
return api
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_psk(self, code):
|
async def generate_psk(self, code):
|
||||||
"""Mock psk."""
|
"""Mock psk."""
|
||||||
return "mock"
|
return "mock"
|
||||||
@ -143,11 +109,14 @@ async def setup_gateway(hass, mock_gateway, mock_api):
|
|||||||
"identity": "mock-identity",
|
"identity": "mock-identity",
|
||||||
"key": "mock-key",
|
"key": "mock-key",
|
||||||
"import_groups": True,
|
"import_groups": True,
|
||||||
"gateway_id": "mock-gateway-id",
|
"gateway_id": MOCK_GATEWAY_ID,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
hass.data[tradfri.KEY_GATEWAY] = {entry.entry_id: mock_gateway}
|
tradfri_data = {}
|
||||||
hass.data[tradfri.KEY_API] = {entry.entry_id: mock_api}
|
hass.data[tradfri.DOMAIN] = {entry.entry_id: tradfri_data}
|
||||||
|
tradfri_data[tradfri.KEY_API] = mock_api
|
||||||
|
tradfri_data[tradfri.KEY_GATEWAY] = mock_gateway
|
||||||
|
|
||||||
await hass.config_entries.async_forward_entry_setup(entry, "light")
|
await hass.config_entries.async_forward_entry_setup(entry, "light")
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user