Refactor tradfri switch tests (#86816)

Co-authored-by: Patrik Lindgren <21142447+ggravlingen@users.noreply.github.com>
This commit is contained in:
Martin Hjelmare 2023-02-05 21:02:17 +01:00 committed by GitHub
parent a2530e7f19
commit 0aa489e3f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 144 additions and 128 deletions

View File

@ -1,5 +1,12 @@
"""Common tools used for the Tradfri test suite.""" """Common tools used for the Tradfri test suite."""
from copy import deepcopy
from typing import Any
from unittest.mock import Mock
from pytradfri.device import Device
from homeassistant.components import tradfri from homeassistant.components import tradfri
from homeassistant.core import HomeAssistant
from . import GATEWAY_ID from . import GATEWAY_ID
@ -23,3 +30,47 @@ async def setup_integration(hass):
await hass.async_block_till_done() await hass.async_block_till_done()
return entry return entry
def modify_state(
state: dict[str, Any], partial_state: dict[str, Any]
) -> dict[str, Any]:
"""Modify a state with a partial state."""
for key, value in partial_state.items():
if isinstance(value, list):
for index, item in enumerate(value):
state[key][index] = modify_state(state[key][index], item)
elif isinstance(value, dict):
state[key] = modify_state(state[key], value)
else:
state[key] = value
return state
async def trigger_observe_callback(
hass: HomeAssistant,
mock_gateway: Mock,
device: Device,
new_device_state: dict[str, Any] | None = None,
) -> None:
"""Trigger the observe callback."""
observe_command = next(
(
command
for command in mock_gateway.mock_commands
if command.path == device.path and command.observe
),
None,
)
assert observe_command
if new_device_state is not None:
mock_gateway.mock_responses.append(new_device_state)
device_state = deepcopy(device.raw)
new_state = mock_gateway.mock_responses[-1]
device_state = modify_state(device_state, new_state)
observe_command.process_result(device_state)
await hass.async_block_till_done()

View File

@ -1,5 +1,8 @@
"""Common tradfri test fixtures.""" """Common tradfri test fixtures."""
from unittest.mock import Mock, PropertyMock, patch from __future__ import annotations
from collections.abc import Generator
from unittest.mock import MagicMock, Mock, PropertyMock, patch
import pytest import pytest
@ -43,6 +46,7 @@ def mock_gateway_fixture():
get_devices=get_devices, get_devices=get_devices,
get_groups=get_groups, get_groups=get_groups,
get_gateway_info=get_gateway_info, get_gateway_info=get_gateway_info,
mock_commands=[],
mock_devices=[], mock_devices=[],
mock_groups=[], mock_groups=[],
mock_responses=[], mock_responses=[],
@ -62,13 +66,14 @@ def mock_api_fixture(mock_gateway):
# Store the data for "real" command objects. # Store the data for "real" command objects.
if hasattr(command, "_data") and not isinstance(command, Mock): if hasattr(command, "_data") and not isinstance(command, Mock):
mock_gateway.mock_responses.append(command._data) mock_gateway.mock_responses.append(command._data)
mock_gateway.mock_commands.append(command)
return command return command
return api return api
@pytest.fixture @pytest.fixture
def mock_api_factory(mock_api): def mock_api_factory(mock_api) -> Generator[MagicMock, None, None]:
"""Mock pytradfri api factory.""" """Mock pytradfri api factory."""
with patch(f"{TRADFRI_PATH}.APIFactory", autospec=True) as factory: with patch(f"{TRADFRI_PATH}.APIFactory", autospec=True) as factory:
factory.init.return_value = factory.return_value factory.init.return_value = factory.return_value

View File

@ -0,0 +1,18 @@
{
"9001": "Test",
"9002": 1536968250,
"9020": 1536968280,
"9003": 65548,
"9054": 0,
"5750": 3,
"9019": 1,
"9084": " 43 86 6e b5 6a df dc da d6 ce 9c 5a b4 63 a4 2a",
"3": {
"0": "IKEA of Sweden",
"1": "TRADFRI control outlet",
"3": "1.4.020",
"2": "",
"6": 1
},
"3312": [{ "9003": 0, "5850": 0 }]
}

View File

@ -1,160 +1,102 @@
"""Tradfri switch (recognised as sockets in the IKEA ecosystem) platform tests.""" """Tradfri switch (recognised as sockets in the IKEA ecosystem) platform tests."""
from __future__ import annotations
from unittest.mock import MagicMock, Mock, PropertyMock, patch import json
from typing import Any
from unittest.mock import MagicMock, Mock
import pytest import pytest
from pytradfri.const import ATTR_REACHABLE_STATE
from pytradfri.device import Device from pytradfri.device import Device
from pytradfri.device.socket import Socket from pytradfri.device.socket import Socket
from pytradfri.device.socket_control import SocketControl
from .common import setup_integration from homeassistant.components.switch import DOMAIN as SWITCH_DOMAIN
from homeassistant.components.tradfri.const import DOMAIN
from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNAVAILABLE
from homeassistant.core import HomeAssistant
from .common import setup_integration, trigger_observe_callback
from tests.common import load_fixture
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(scope="module")
def setup(request): def outlet() -> dict[str, Any]:
"""Set up patches for pytradfri methods.""" """Return an outlet response."""
with patch( return json.loads(load_fixture("outlet.json", DOMAIN))
"pytradfri.device.SocketControl.raw",
new_callable=PropertyMock,
return_value=[{"mock": "mock"}],
), patch(
"pytradfri.device.SocketControl.sockets",
):
yield
def mock_switch(test_features=None, test_state=None, device_number=0): @pytest.fixture
"""Mock a tradfri switch/socket.""" def socket(outlet: dict[str, Any]) -> Socket:
if test_features is None: """Return socket."""
test_features = {} device = Device(outlet)
if test_state is None: socket_control = device.socket_control
test_state = {} assert socket_control
mock_switch_data = Mock(**test_state) return socket_control.sockets[0]
dev_info_mock = MagicMock()
dev_info_mock.manufacturer = "manufacturer"
dev_info_mock.model_number = "model"
dev_info_mock.firmware_version = "1.2.3"
_mock_switch = Mock(
id=f"mock-switch-id-{device_number}",
reachable=True,
observe=Mock(),
device_info=dev_info_mock,
has_light_control=False,
has_socket_control=True,
has_blind_control=False,
has_signal_repeater_control=False,
has_air_purifier_control=False,
)
_mock_switch.name = f"tradfri_switch_{device_number}"
socket_control = SocketControl(_mock_switch)
# Store the initial state.
setattr(socket_control, "sockets", [mock_switch_data])
_mock_switch.socket_control = socket_control
return _mock_switch
async def test_switch(hass, mock_gateway, mock_api_factory): async def test_switch_available(
"""Test that switches are correctly added.""" hass: HomeAssistant,
state = { mock_gateway: Mock,
"state": True, mock_api_factory: MagicMock,
} socket: Socket,
) -> None:
mock_gateway.mock_devices.append(mock_switch(test_state=state))
await setup_integration(hass)
switch_1 = hass.states.get("switch.tradfri_switch_0")
assert switch_1 is not None
assert switch_1.state == "on"
async def test_switch_observed(hass, mock_gateway, mock_api_factory):
"""Test that switches are correctly observed."""
state = {
"state": True,
}
switch = mock_switch(test_state=state)
mock_gateway.mock_devices.append(switch)
await setup_integration(hass)
assert len(switch.observe.mock_calls) > 0
async def test_switch_available(hass, mock_gateway, mock_api_factory):
"""Test switch available property.""" """Test switch available property."""
entity_id = "switch.test"
switch = mock_switch(test_state={"state": True}, device_number=1) device = socket.device
switch.reachable = True mock_gateway.mock_devices.append(device)
switch2 = mock_switch(test_state={"state": True}, device_number=2)
switch2.reachable = False
mock_gateway.mock_devices.append(switch)
mock_gateway.mock_devices.append(switch2)
await setup_integration(hass) await setup_integration(hass)
assert hass.states.get("switch.tradfri_switch_1").state == "on" state = hass.states.get(entity_id)
assert hass.states.get("switch.tradfri_switch_2").state == "unavailable" assert state
assert state.state == STATE_OFF
await trigger_observe_callback(
hass, mock_gateway, device, {ATTR_REACHABLE_STATE: 0}
)
state = hass.states.get(entity_id)
assert state
assert state.state == STATE_UNAVAILABLE
@pytest.mark.parametrize( @pytest.mark.parametrize(
"test_data, expected_result", "service, expected_state",
[ [
( ("turn_on", STATE_ON),
"turn_on", ("turn_off", STATE_OFF),
"on",
),
("turn_off", "off"),
], ],
) )
async def test_turn_on_off( async def test_turn_on_off(
hass, hass: HomeAssistant,
mock_gateway, mock_gateway: Mock,
mock_api_factory, mock_api_factory: MagicMock,
test_data, socket: Socket,
expected_result, service: str,
): expected_state: str,
) -> None:
"""Test turning switch on/off.""" """Test turning switch on/off."""
# Note pytradfri style, not hass. Values not really important. entity_id = "switch.test"
initial_state = { device = socket.device
"state": True, mock_gateway.mock_devices.append(device)
}
# Setup the gateway with a mock switch.
switch = mock_switch(test_state=initial_state, device_number=0)
mock_gateway.mock_devices.append(switch)
await setup_integration(hass) await setup_integration(hass)
# Use the turn_on/turn_off service call to change the switch state. state = hass.states.get(entity_id)
assert state
assert state.state == STATE_OFF
await hass.services.async_call( await hass.services.async_call(
"switch", SWITCH_DOMAIN,
test_data, service,
{ {
"entity_id": "switch.tradfri_switch_0", "entity_id": entity_id,
}, },
blocking=True, blocking=True,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
# Check that the switch is observed. await trigger_observe_callback(hass, mock_gateway, device)
mock_func = switch.observe
assert len(mock_func.mock_calls) > 0
_, callkwargs = mock_func.call_args
assert "callback" in callkwargs
# Callback function to refresh switch state.
callback = callkwargs["callback"]
responses = mock_gateway.mock_responses state = hass.states.get(entity_id)
mock_gateway_response = responses[0] assert state
assert state.state == expected_state
# Use the callback function to update the switch state.
dev = Device(mock_gateway_response)
switch_data = Socket(dev, 0)
switch.socket_control.sockets[0] = switch_data
callback(switch)
await hass.async_block_till_done()
# Check that the state is correct.
state = hass.states.get("switch.tradfri_switch_0")
assert state.state == expected_result