diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 61c62a1eaa4..4014c2162dd 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -1272,13 +1272,23 @@ async def websocket_remove_device(hass, connection, msg): dev_registry = await get_dev_reg(hass) device = dev_registry.async_get(device_id) + if not device: + connection.send_error( + msg["id"], websocket_api.const.ERR_NOT_FOUND, "Device not found" + ) + return + for config_entry in device.config_entries: config_entry = hass.config_entries.async_get_entry(config_entry) # Only delete the device if it belongs to an MQTT device entry if config_entry.domain == DOMAIN: dev_registry.async_remove_device(device_id) connection.send_message(websocket_api.result_message(msg["id"])) - break + return + + connection.send_error( + msg["id"], websocket_api.const.ERR_NOT_FOUND, "Non MQTT device" + ) @websocket_api.async_response diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 5dc05a95a55..7d06c62b915 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -7,7 +7,7 @@ from unittest import mock import pytest import voluptuous as vol -from homeassistant.components import mqtt +from homeassistant.components import mqtt, websocket_api from homeassistant.components.mqtt.discovery import async_start from homeassistant.const import ( ATTR_DOMAIN, @@ -17,6 +17,7 @@ from homeassistant.const import ( ) from homeassistant.core import callback from homeassistant.exceptions import ConfigEntryNotReady +from homeassistant.helpers import device_registry from homeassistant.setup import async_setup_component from homeassistant.util.dt import utcnow @@ -905,8 +906,31 @@ async def test_mqtt_ws_remove_discovered_device_twice( response = await client.receive_json() assert response["success"] + await client.send_json( + {"id": 6, "type": "mqtt/device/remove", "device_id": device_entry.id} + ) + response = await client.receive_json() + assert not response["success"] + assert response["error"]["code"] == websocket_api.const.ERR_NOT_FOUND + + +async def test_mqtt_ws_remove_non_mqtt_device( + hass, device_reg, hass_ws_client, mqtt_mock +): + """Test MQTT websocket device removal of device belonging to other domain.""" + config_entry = MockConfigEntry(domain="test") + config_entry.add_to_hass(hass) + + device_entry = device_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + assert device_entry is not None + + client = await hass_ws_client(hass) await client.send_json( {"id": 5, "type": "mqtt/device/remove", "device_id": device_entry.id} ) response = await client.receive_json() assert not response["success"] + assert response["error"]["code"] == websocket_api.const.ERR_NOT_FOUND