diff --git a/homeassistant/components/bluetooth/websocket_api.py b/homeassistant/components/bluetooth/websocket_api.py index 9022d98bf06..042fe3fe24b 100644 --- a/homeassistant/components/bluetooth/websocket_api.py +++ b/homeassistant/components/bluetooth/websocket_api.py @@ -8,8 +8,10 @@ import time from typing import Any from habluetooth import ( + BaseHaScanner, BluetoothScanningMode, HaBluetoothSlotAllocations, + HaScannerModeChange, HaScannerRegistration, HaScannerRegistrationEvent, ) @@ -27,12 +29,54 @@ from .models import BluetoothChange from .util import InvalidConfigEntryID, InvalidSource, config_entry_id_to_source +@callback +def _async_get_source_from_config_entry( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg_id: int, + config_entry_id: str | None, + validate_source: bool = True, +) -> str | None: + """Get source from config entry id. + + Returns None if no config_entry_id provided or on error (after sending error response). + If validate_source is True, also validates that the scanner exists. + """ + if not config_entry_id: + return None + + if validate_source: + # Use the full validation that checks if scanner exists + try: + return config_entry_id_to_source(hass, config_entry_id) + except InvalidConfigEntryID as err: + connection.send_error(msg_id, "invalid_config_entry_id", str(err)) + return None + except InvalidSource as err: + connection.send_error(msg_id, "invalid_source", str(err)) + return None + + # Just check if config entry exists and belongs to bluetooth + if ( + not (entry := hass.config_entries.async_get_entry(config_entry_id)) + or entry.domain != DOMAIN + ): + connection.send_error( + msg_id, + "invalid_config_entry_id", + f"Config entry {config_entry_id} not found", + ) + return None + return entry.unique_id + + @callback def async_setup(hass: HomeAssistant) -> None: """Set up the bluetooth websocket API.""" websocket_api.async_register_command(hass, ws_subscribe_advertisements) websocket_api.async_register_command(hass, ws_subscribe_connection_allocations) websocket_api.async_register_command(hass, ws_subscribe_scanner_details) + websocket_api.async_register_command(hass, ws_subscribe_scanner_state) @lru_cache(maxsize=1024) @@ -180,16 +224,12 @@ async def ws_subscribe_connection_allocations( ) -> None: """Handle subscribe advertisements websocket command.""" ws_msg_id = msg["id"] - source: str | None = None - if config_entry_id := msg.get("config_entry_id"): - try: - source = config_entry_id_to_source(hass, config_entry_id) - except InvalidConfigEntryID as err: - connection.send_error(ws_msg_id, "invalid_config_entry_id", str(err)) - return - except InvalidSource as err: - connection.send_error(ws_msg_id, "invalid_source", str(err)) - return + config_entry_id = msg.get("config_entry_id") + source = _async_get_source_from_config_entry( + hass, connection, ws_msg_id, config_entry_id + ) + if config_entry_id and source is None: + return # Error already sent by helper def _async_allocations_changed(allocations: HaBluetoothSlotAllocations) -> None: connection.send_message( @@ -220,20 +260,12 @@ async def ws_subscribe_scanner_details( ) -> None: """Handle subscribe scanner details websocket command.""" ws_msg_id = msg["id"] - source: str | None = None - if config_entry_id := msg.get("config_entry_id"): - if ( - not (entry := hass.config_entries.async_get_entry(config_entry_id)) - or entry.domain != DOMAIN - ): - connection.send_error( - ws_msg_id, - "invalid_config_entry_id", - f"Invalid config entry id: {config_entry_id}", - ) - return - source = entry.unique_id - assert source is not None + config_entry_id = msg.get("config_entry_id") + source = _async_get_source_from_config_entry( + hass, connection, ws_msg_id, config_entry_id, validate_source=False + ) + if config_entry_id and source is None: + return # Error already sent by helper def _async_event_message(message: dict[str, Any]) -> None: connection.send_message( @@ -260,3 +292,70 @@ async def ws_subscribe_scanner_details( ] ): _async_event_message({"add": matching_scanners}) + + +@websocket_api.require_admin +@websocket_api.websocket_command( + { + vol.Required("type"): "bluetooth/subscribe_scanner_state", + vol.Optional("config_entry_id"): str, + } +) +@websocket_api.async_response +async def ws_subscribe_scanner_state( + hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any] +) -> None: + """Handle subscribe scanner state websocket command.""" + ws_msg_id = msg["id"] + config_entry_id = msg.get("config_entry_id") + source = _async_get_source_from_config_entry( + hass, connection, ws_msg_id, config_entry_id, validate_source=False + ) + if config_entry_id and source is None: + return # Error already sent by helper + + @callback + def _async_send_scanner_state( + scanner: BaseHaScanner, + current_mode: BluetoothScanningMode | None, + requested_mode: BluetoothScanningMode | None, + ) -> None: + payload = { + "source": scanner.source, + "adapter": scanner.adapter, + "current_mode": current_mode.value if current_mode else None, + "requested_mode": requested_mode.value if requested_mode else None, + } + connection.send_message( + json_bytes( + websocket_api.event_message( + ws_msg_id, + payload, + ) + ) + ) + + @callback + def _async_scanner_state_changed(mode_change: HaScannerModeChange) -> None: + _async_send_scanner_state( + mode_change.scanner, + mode_change.current_mode, + mode_change.requested_mode, + ) + + manager = _get_manager(hass) + connection.subscriptions[ws_msg_id] = ( + manager.async_register_scanner_mode_change_callback( + _async_scanner_state_changed, source + ) + ) + connection.send_message(json_bytes(websocket_api.result_message(ws_msg_id))) + + # Send initial state for all matching scanners + for scanner in manager.async_current_scanners(): + if source is None or scanner.source == source: + _async_send_scanner_state( + scanner, + scanner.current_mode, + scanner.requested_mode, + ) diff --git a/tests/components/bluetooth/test_websocket_api.py b/tests/components/bluetooth/test_websocket_api.py index 19693db4000..1bb76065a5d 100644 --- a/tests/components/bluetooth/test_websocket_api.py +++ b/tests/components/bluetooth/test_websocket_api.py @@ -7,6 +7,7 @@ from unittest.mock import ANY, patch from bleak_retry_connector import Allocations from freezegun import freeze_time +from habluetooth import BluetoothScanningMode import pytest from homeassistant.components.bluetooth import DOMAIN @@ -440,4 +441,126 @@ async def test_subscribe_scanner_details_invalid_config_entry_id( response = await client.receive_json() assert not response["success"] assert response["error"]["code"] == "invalid_config_entry_id" - assert response["error"]["message"] == "Invalid config entry id: non_existent" + assert response["error"]["message"] == "Config entry non_existent not found" + + +@pytest.mark.usefixtures("enable_bluetooth") +async def test_subscribe_scanner_state( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test bluetooth subscribe_scanner_state.""" + client = await hass_ws_client() + await client.send_json( + { + "id": 1, + "type": "bluetooth/subscribe_scanner_state", + } + ) + async with asyncio.timeout(1): + response = await client.receive_json() + assert response["success"] + + # Should receive initial state for existing scanner + async with asyncio.timeout(1): + response = await client.receive_json() + assert response["event"] == { + "source": "00:00:00:00:00:01", + "adapter": "hci0", + "current_mode": "active", + "requested_mode": "active", + } + + # Register a new scanner + manager = _get_manager() + hci3_scanner = FakeScanner("AA:BB:CC:DD:EE:33", "hci3") + cancel_hci3 = manager.async_register_hass_scanner(hci3_scanner) + + # Simulate a mode change + hci3_scanner.current_mode = BluetoothScanningMode.ACTIVE + hci3_scanner.requested_mode = BluetoothScanningMode.ACTIVE + manager.scanner_mode_changed(hci3_scanner) + + async with asyncio.timeout(1): + response = await client.receive_json() + assert response["event"] == { + "source": "AA:BB:CC:DD:EE:33", + "adapter": "hci3", + "current_mode": "active", + "requested_mode": "active", + } + + cancel_hci3() + + +@pytest.mark.usefixtures("enable_bluetooth") +async def test_subscribe_scanner_state_specific_scanner( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test bluetooth subscribe_scanner_state for a specific source address.""" + # Register the scanner first + manager = _get_manager() + hci3_scanner = FakeScanner("AA:BB:CC:DD:EE:33", "hci3") + cancel_hci3 = manager.async_register_hass_scanner(hci3_scanner) + + entry = MockConfigEntry(domain=DOMAIN, unique_id="AA:BB:CC:DD:EE:33") + entry.add_to_hass(hass) + client = await hass_ws_client() + await client.send_json( + { + "id": 1, + "type": "bluetooth/subscribe_scanner_state", + "config_entry_id": entry.entry_id, + } + ) + async with asyncio.timeout(1): + response = await client.receive_json() + assert response["success"] + + # Should receive initial state + async with asyncio.timeout(1): + response = await client.receive_json() + assert response["event"] == { + "source": "AA:BB:CC:DD:EE:33", + "adapter": "hci3", + "current_mode": None, + "requested_mode": None, + } + + # Simulate a mode change + hci3_scanner.current_mode = BluetoothScanningMode.PASSIVE + hci3_scanner.requested_mode = BluetoothScanningMode.ACTIVE + manager.scanner_mode_changed(hci3_scanner) + + async with asyncio.timeout(1): + response = await client.receive_json() + assert response["event"] == { + "source": "AA:BB:CC:DD:EE:33", + "adapter": "hci3", + "current_mode": "passive", + "requested_mode": "active", + } + + cancel_hci3() + + +@pytest.mark.usefixtures("enable_bluetooth") +async def test_subscribe_scanner_state_invalid_config_entry_id( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test bluetooth subscribe_scanner_state for an invalid config entry id.""" + client = await hass_ws_client() + await client.send_json( + { + "id": 1, + "type": "bluetooth/subscribe_scanner_state", + "config_entry_id": "non_existent", + } + ) + async with asyncio.timeout(1): + response = await client.receive_json() + assert not response["success"] + assert response["error"]["code"] == "invalid_config_entry_id" + assert response["error"]["message"] == "Config entry non_existent not found"