Refactor smlight event_function to common function (#126260)

refactor event_function
This commit is contained in:
TimL 2024-09-20 20:40:50 +10:00 committed by GitHub
parent f93bcbaa84
commit 76967e848d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 28 additions and 33 deletions

View File

@ -1 +1,21 @@
"""Tests for the SMLIGHT Zigbee adapter integration.""" """Tests for the SMLIGHT Zigbee adapter integration."""
from collections.abc import Callable
from unittest.mock import MagicMock
from pysmlight.const import Events as SmEvents
from pysmlight.sse import MessageEvent
def get_mock_event_function(
mock: MagicMock, event: SmEvents
) -> Callable[[MessageEvent], None]:
"""Extract event function from mock call_args."""
return next(
(
call_args[0][1]
for call_args in mock.sse.register_callback.call_args_list
if call_args[0][0] == event
),
None,
)

View File

@ -1,6 +1,5 @@
"""Tests for the SMLIGHT binary sensor platform.""" """Tests for the SMLIGHT binary sensor platform."""
from collections.abc import Callable
from unittest.mock import MagicMock from unittest.mock import MagicMock
from freezegun.api import FrozenDateTimeFactory from freezegun.api import FrozenDateTimeFactory
@ -14,6 +13,7 @@ from homeassistant.const import STATE_ON, STATE_UNKNOWN, Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
from . import get_mock_event_function
from .conftest import setup_integration from .conftest import setup_integration
from tests.common import MockConfigEntry, async_fire_time_changed, snapshot_platform from tests.common import MockConfigEntry, async_fire_time_changed, snapshot_platform
@ -95,13 +95,8 @@ async def test_internet_sensor_event(
assert len(mock_smlight_client.get_param.mock_calls) == 2 assert len(mock_smlight_client.get_param.mock_calls) == 2
mock_smlight_client.get_param.assert_called_with("inetState") mock_smlight_client.get_param.assert_called_with("inetState")
event_function: Callable[[MessageEvent], None] = next( event_function = get_mock_event_function(
( mock_smlight_client, Events.EVENT_INET_STATE
call_args[0][1]
for call_args in mock_smlight_client.sse.register_callback.call_args_list
if call_args[0][0] == Events.EVENT_INET_STATE
),
None,
) )
event_function(MOCK_INET_STATE) event_function(MOCK_INET_STATE)

View File

@ -1,6 +1,5 @@
"""Tests for the SMLIGHT update platform.""" """Tests for the SMLIGHT update platform."""
from collections.abc import Callable
from unittest.mock import MagicMock from unittest.mock import MagicMock
from freezegun.api import FrozenDateTimeFactory from freezegun.api import FrozenDateTimeFactory
@ -23,6 +22,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
from . import get_mock_event_function
from .conftest import setup_integration from .conftest import setup_integration
from tests.common import MockConfigEntry, async_fire_time_changed, snapshot_platform from tests.common import MockConfigEntry, async_fire_time_changed, snapshot_platform
@ -67,18 +67,6 @@ MOCK_FIRMWARE_NOTES = [
] ]
def get_callback_function(mock: MagicMock, trigger: SmEvents):
"""Extract the callback function for a given trigger."""
return next(
(
call_args[0][1]
for call_args in mock.sse.register_callback.call_args_list
if trigger == call_args[0][0]
),
None,
)
@pytest.fixture @pytest.fixture
def platforms() -> list[Platform]: def platforms() -> list[Platform]:
"""Platforms, which should be loaded during the test.""" """Platforms, which should be loaded during the test."""
@ -122,17 +110,13 @@ async def test_update_firmware(
assert len(mock_smlight_client.fw_update.mock_calls) == 1 assert len(mock_smlight_client.fw_update.mock_calls) == 1
event_function: Callable[[MessageEvent], None] = get_callback_function( event_function = get_mock_event_function(mock_smlight_client, SmEvents.ZB_FW_prgs)
mock_smlight_client, SmEvents.ZB_FW_prgs
)
event_function(MOCK_FIRMWARE_PROGRESS) event_function(MOCK_FIRMWARE_PROGRESS)
state = hass.states.get(entity_id) state = hass.states.get(entity_id)
assert state.attributes[ATTR_IN_PROGRESS] == 50 assert state.attributes[ATTR_IN_PROGRESS] == 50
event_function: Callable[[MessageEvent], None] = get_callback_function( event_function = get_mock_event_function(mock_smlight_client, SmEvents.FW_UPD_done)
mock_smlight_client, SmEvents.FW_UPD_done
)
event_function(MOCK_FIRMWARE_DONE) event_function(MOCK_FIRMWARE_DONE)
@ -178,9 +162,7 @@ async def test_update_legacy_firmware_v2(
assert len(mock_smlight_client.fw_update.mock_calls) == 1 assert len(mock_smlight_client.fw_update.mock_calls) == 1
event_function: Callable[[MessageEvent], None] = get_callback_function( event_function = get_mock_event_function(mock_smlight_client, SmEvents.ESP_UPD_done)
mock_smlight_client, SmEvents.ESP_UPD_done
)
event_function(MOCK_FIRMWARE_DONE) event_function(MOCK_FIRMWARE_DONE)
@ -220,9 +202,7 @@ async def test_update_firmware_failed(
assert len(mock_smlight_client.fw_update.mock_calls) == 1 assert len(mock_smlight_client.fw_update.mock_calls) == 1
event_function: Callable[[MessageEvent], None] = get_callback_function( event_function = get_mock_event_function(mock_smlight_client, SmEvents.ZB_FW_err)
mock_smlight_client, SmEvents.ZB_FW_err
)
async def _call_event_function(event: MessageEvent): async def _call_event_function(event: MessageEvent):
event_function(event) event_function(event)