diff --git a/tests/components/lutron_caseta/__init__.py b/tests/components/lutron_caseta/__init__.py index b27d30ac31f..bb7131819aa 100644 --- a/tests/components/lutron_caseta/__init__.py +++ b/tests/components/lutron_caseta/__init__.py @@ -1,5 +1,6 @@ """Tests for the Lutron Caseta integration.""" +from collections import defaultdict from unittest.mock import patch from homeassistant.components.lutron_caseta import DOMAIN @@ -84,7 +85,9 @@ _LEAP_DEVICE_TYPES = { } -async def async_setup_integration(hass: HomeAssistant, mock_bridge) -> MockConfigEntry: +async def async_setup_integration( + hass: HomeAssistant, mock_bridge +) -> tuple[MockConfigEntry, "MockBridge"]: """Set up a mock bridge.""" mock_entry = MockConfigEntry(domain=DOMAIN, data=ENTRY_MOCK_DATA) mock_entry.add_to_hass(hass) @@ -92,10 +95,12 @@ async def async_setup_integration(hass: HomeAssistant, mock_bridge) -> MockConfi with patch( "homeassistant.components.lutron_caseta.Smartbridge.create_tls" ) as create_tls: - create_tls.return_value = mock_bridge(can_connect=True) + mocked_bridge = mock_bridge(can_connect=True) + create_tls.return_value = mocked_bridge await hass.config_entries.async_setup(mock_entry.entry_id) await hass.async_block_till_done() - return mock_entry + + return mock_entry, mocked_bridge class MockBridge: @@ -110,6 +115,7 @@ class MockBridge: self.scenes = self.get_scenes() self.devices = self.load_devices() self.buttons = self.load_buttons() + self.button_subscribers: defaultdict[str, list] = defaultdict(list) async def connect(self): """Connect the mock bridge.""" @@ -121,6 +127,7 @@ class MockBridge: def add_button_subscriber(self, button_id: str, callback_): """Mock a listener for button presses.""" + self.button_subscribers[button_id].append(callback_) def is_connected(self): """Return whether the mock bridge is connected.""" diff --git a/tests/components/lutron_caseta/test_device_trigger.py b/tests/components/lutron_caseta/test_device_trigger.py index 5b1dc7ae381..04eac003603 100644 --- a/tests/components/lutron_caseta/test_device_trigger.py +++ b/tests/components/lutron_caseta/test_device_trigger.py @@ -157,6 +157,7 @@ async def test_get_triggers(hass: HomeAssistant) -> None: triggers = await async_get_device_automations( hass, DeviceAutomationType.TRIGGER, device_id ) + triggers = [trigger for trigger in triggers if trigger[CONF_DOMAIN] == DOMAIN] assert triggers == unordered(expected_triggers)