Refactor smlight event_function to common function (#126260)

refactor event_function
This commit is contained in:
TimL 2024-09-20 20:40:50 +10:00 committed by GitHub
parent f93bcbaa84
commit 76967e848d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 28 additions and 33 deletions

View File

@ -1 +1,21 @@
"""Tests for the SMLIGHT Zigbee adapter integration."""
from collections.abc import Callable
from unittest.mock import MagicMock
from pysmlight.const import Events as SmEvents
from pysmlight.sse import MessageEvent
def get_mock_event_function(
mock: MagicMock, event: SmEvents
) -> Callable[[MessageEvent], None]:
"""Extract event function from mock call_args."""
return next(
(
call_args[0][1]
for call_args in mock.sse.register_callback.call_args_list
if call_args[0][0] == event
),
None,
)

View File

@ -1,6 +1,5 @@
"""Tests for the SMLIGHT binary sensor platform."""
from collections.abc import Callable
from unittest.mock import MagicMock
from freezegun.api import FrozenDateTimeFactory
@ -14,6 +13,7 @@ from homeassistant.const import STATE_ON, STATE_UNKNOWN, Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from . import get_mock_event_function
from .conftest import setup_integration
from tests.common import MockConfigEntry, async_fire_time_changed, snapshot_platform
@ -95,13 +95,8 @@ async def test_internet_sensor_event(
assert len(mock_smlight_client.get_param.mock_calls) == 2
mock_smlight_client.get_param.assert_called_with("inetState")
event_function: Callable[[MessageEvent], None] = next(
(
call_args[0][1]
for call_args in mock_smlight_client.sse.register_callback.call_args_list
if call_args[0][0] == Events.EVENT_INET_STATE
),
None,
event_function = get_mock_event_function(
mock_smlight_client, Events.EVENT_INET_STATE
)
event_function(MOCK_INET_STATE)

View File

@ -1,6 +1,5 @@
"""Tests for the SMLIGHT update platform."""
from collections.abc import Callable
from unittest.mock import MagicMock
from freezegun.api import FrozenDateTimeFactory
@ -23,6 +22,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry as er
from . import get_mock_event_function
from .conftest import setup_integration
from tests.common import MockConfigEntry, async_fire_time_changed, snapshot_platform
@ -67,18 +67,6 @@ MOCK_FIRMWARE_NOTES = [
]
def get_callback_function(mock: MagicMock, trigger: SmEvents):
"""Extract the callback function for a given trigger."""
return next(
(
call_args[0][1]
for call_args in mock.sse.register_callback.call_args_list
if trigger == call_args[0][0]
),
None,
)
@pytest.fixture
def platforms() -> list[Platform]:
"""Platforms, which should be loaded during the test."""
@ -122,17 +110,13 @@ async def test_update_firmware(
assert len(mock_smlight_client.fw_update.mock_calls) == 1
event_function: Callable[[MessageEvent], None] = get_callback_function(
mock_smlight_client, SmEvents.ZB_FW_prgs
)
event_function = get_mock_event_function(mock_smlight_client, SmEvents.ZB_FW_prgs)
event_function(MOCK_FIRMWARE_PROGRESS)
state = hass.states.get(entity_id)
assert state.attributes[ATTR_IN_PROGRESS] == 50
event_function: Callable[[MessageEvent], None] = get_callback_function(
mock_smlight_client, SmEvents.FW_UPD_done
)
event_function = get_mock_event_function(mock_smlight_client, SmEvents.FW_UPD_done)
event_function(MOCK_FIRMWARE_DONE)
@ -178,9 +162,7 @@ async def test_update_legacy_firmware_v2(
assert len(mock_smlight_client.fw_update.mock_calls) == 1
event_function: Callable[[MessageEvent], None] = get_callback_function(
mock_smlight_client, SmEvents.ESP_UPD_done
)
event_function = get_mock_event_function(mock_smlight_client, SmEvents.ESP_UPD_done)
event_function(MOCK_FIRMWARE_DONE)
@ -220,9 +202,7 @@ async def test_update_firmware_failed(
assert len(mock_smlight_client.fw_update.mock_calls) == 1
event_function: Callable[[MessageEvent], None] = get_callback_function(
mock_smlight_client, SmEvents.ZB_FW_err
)
event_function = get_mock_event_function(mock_smlight_client, SmEvents.ZB_FW_err)
async def _call_event_function(event: MessageEvent):
event_function(event)