From 76967e848db22d8134b41be65e51bb856b0d87c5 Mon Sep 17 00:00:00 2001 From: TimL Date: Fri, 20 Sep 2024 20:40:50 +1000 Subject: [PATCH] Refactor smlight event_function to common function (#126260) refactor event_function --- tests/components/smlight/__init__.py | 20 +++++++++++++ .../components/smlight/test_binary_sensor.py | 11 ++----- tests/components/smlight/test_update.py | 30 ++++--------------- 3 files changed, 28 insertions(+), 33 deletions(-) diff --git a/tests/components/smlight/__init__.py b/tests/components/smlight/__init__.py index 37184226507..e518e0573ba 100644 --- a/tests/components/smlight/__init__.py +++ b/tests/components/smlight/__init__.py @@ -1 +1,21 @@ """Tests for the SMLIGHT Zigbee adapter integration.""" + +from collections.abc import Callable +from unittest.mock import MagicMock + +from pysmlight.const import Events as SmEvents +from pysmlight.sse import MessageEvent + + +def get_mock_event_function( + mock: MagicMock, event: SmEvents +) -> Callable[[MessageEvent], None]: + """Extract event function from mock call_args.""" + return next( + ( + call_args[0][1] + for call_args in mock.sse.register_callback.call_args_list + if call_args[0][0] == event + ), + None, + ) diff --git a/tests/components/smlight/test_binary_sensor.py b/tests/components/smlight/test_binary_sensor.py index 1b1c0358c37..b1d72b66dcf 100644 --- a/tests/components/smlight/test_binary_sensor.py +++ b/tests/components/smlight/test_binary_sensor.py @@ -1,6 +1,5 @@ """Tests for the SMLIGHT binary sensor platform.""" -from collections.abc import Callable from unittest.mock import MagicMock from freezegun.api import FrozenDateTimeFactory @@ -14,6 +13,7 @@ from homeassistant.const import STATE_ON, STATE_UNKNOWN, Platform from homeassistant.core import HomeAssistant from homeassistant.helpers import entity_registry as er +from . import get_mock_event_function from .conftest import setup_integration from tests.common import MockConfigEntry, async_fire_time_changed, snapshot_platform @@ -95,13 +95,8 @@ async def test_internet_sensor_event( assert len(mock_smlight_client.get_param.mock_calls) == 2 mock_smlight_client.get_param.assert_called_with("inetState") - event_function: Callable[[MessageEvent], None] = next( - ( - call_args[0][1] - for call_args in mock_smlight_client.sse.register_callback.call_args_list - if call_args[0][0] == Events.EVENT_INET_STATE - ), - None, + event_function = get_mock_event_function( + mock_smlight_client, Events.EVENT_INET_STATE ) event_function(MOCK_INET_STATE) diff --git a/tests/components/smlight/test_update.py b/tests/components/smlight/test_update.py index b0b8910ef9b..7bff12bb027 100644 --- a/tests/components/smlight/test_update.py +++ b/tests/components/smlight/test_update.py @@ -1,6 +1,5 @@ """Tests for the SMLIGHT update platform.""" -from collections.abc import Callable from unittest.mock import MagicMock from freezegun.api import FrozenDateTimeFactory @@ -23,6 +22,7 @@ from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import entity_registry as er +from . import get_mock_event_function from .conftest import setup_integration from tests.common import MockConfigEntry, async_fire_time_changed, snapshot_platform @@ -67,18 +67,6 @@ MOCK_FIRMWARE_NOTES = [ ] -def get_callback_function(mock: MagicMock, trigger: SmEvents): - """Extract the callback function for a given trigger.""" - return next( - ( - call_args[0][1] - for call_args in mock.sse.register_callback.call_args_list - if trigger == call_args[0][0] - ), - None, - ) - - @pytest.fixture def platforms() -> list[Platform]: """Platforms, which should be loaded during the test.""" @@ -122,17 +110,13 @@ async def test_update_firmware( assert len(mock_smlight_client.fw_update.mock_calls) == 1 - event_function: Callable[[MessageEvent], None] = get_callback_function( - mock_smlight_client, SmEvents.ZB_FW_prgs - ) + event_function = get_mock_event_function(mock_smlight_client, SmEvents.ZB_FW_prgs) event_function(MOCK_FIRMWARE_PROGRESS) state = hass.states.get(entity_id) assert state.attributes[ATTR_IN_PROGRESS] == 50 - event_function: Callable[[MessageEvent], None] = get_callback_function( - mock_smlight_client, SmEvents.FW_UPD_done - ) + event_function = get_mock_event_function(mock_smlight_client, SmEvents.FW_UPD_done) event_function(MOCK_FIRMWARE_DONE) @@ -178,9 +162,7 @@ async def test_update_legacy_firmware_v2( assert len(mock_smlight_client.fw_update.mock_calls) == 1 - event_function: Callable[[MessageEvent], None] = get_callback_function( - mock_smlight_client, SmEvents.ESP_UPD_done - ) + event_function = get_mock_event_function(mock_smlight_client, SmEvents.ESP_UPD_done) event_function(MOCK_FIRMWARE_DONE) @@ -220,9 +202,7 @@ async def test_update_firmware_failed( assert len(mock_smlight_client.fw_update.mock_calls) == 1 - event_function: Callable[[MessageEvent], None] = get_callback_function( - mock_smlight_client, SmEvents.ZB_FW_err - ) + event_function = get_mock_event_function(mock_smlight_client, SmEvents.ZB_FW_err) async def _call_event_function(event: MessageEvent): event_function(event)