diff --git a/homeassistant/components/unifi/config_flow.py b/homeassistant/components/unifi/config_flow.py index e40fb30a62c..f29c3951869 100644 --- a/homeassistant/components/unifi/config_flow.py +++ b/homeassistant/components/unifi/config_flow.py @@ -20,6 +20,7 @@ from .const import ( CONF_BLOCK_CLIENT, CONF_CONTROLLER, CONF_DETECTION_TIME, + CONF_DPI_RESTRICTIONS, CONF_IGNORE_WIRED_BUG, CONF_POE_CLIENTS, CONF_SITE_ID, @@ -28,6 +29,7 @@ from .const import ( CONF_TRACK_DEVICES, CONF_TRACK_WIRED_CLIENTS, CONTROLLER_ID, + DEFAULT_DPI_RESTRICTIONS, DEFAULT_POE_CLIENTS, DOMAIN as UNIFI_DOMAIN, LOGGER, @@ -295,6 +297,12 @@ class UnifiOptionsFlowHandler(config_entries.OptionsFlow): CONF_POE_CLIENTS, default=self.options.get(CONF_POE_CLIENTS, DEFAULT_POE_CLIENTS), ): bool, + vol.Optional( + CONF_DPI_RESTRICTIONS, + default=self.options.get( + CONF_DPI_RESTRICTIONS, DEFAULT_DPI_RESTRICTIONS + ), + ): bool, } ), errors=errors, diff --git a/homeassistant/components/unifi/const.py b/homeassistant/components/unifi/const.py index 42d160f2dea..ba16612a903 100644 --- a/homeassistant/components/unifi/const.py +++ b/homeassistant/components/unifi/const.py @@ -15,6 +15,7 @@ CONF_ALLOW_BANDWIDTH_SENSORS = "allow_bandwidth_sensors" CONF_ALLOW_UPTIME_SENSORS = "allow_uptime_sensors" CONF_BLOCK_CLIENT = "block_client" CONF_DETECTION_TIME = "detection_time" +CONF_DPI_RESTRICTIONS = "dpi_restrictions" CONF_IGNORE_WIRED_BUG = "ignore_wired_bug" CONF_POE_CLIENTS = "poe_clients" CONF_TRACK_CLIENTS = "track_clients" @@ -24,6 +25,7 @@ CONF_SSID_FILTER = "ssid_filter" DEFAULT_ALLOW_BANDWIDTH_SENSORS = False DEFAULT_ALLOW_UPTIME_SENSORS = False +DEFAULT_DPI_RESTRICTIONS = True DEFAULT_IGNORE_WIRED_BUG = False DEFAULT_POE_CLIENTS = True DEFAULT_TRACK_CLIENTS = True diff --git a/homeassistant/components/unifi/controller.py b/homeassistant/components/unifi/controller.py index 6fc5b3d9ed7..4d5bfa20215 100644 --- a/homeassistant/components/unifi/controller.py +++ b/homeassistant/components/unifi/controller.py @@ -7,6 +7,8 @@ from aiohttp import CookieJar import aiounifi from aiounifi.controller import ( DATA_CLIENT_REMOVED, + DATA_DPI_GROUP, + DATA_DPI_GROUP_REMOVED, DATA_EVENT, SIGNAL_CONNECTION_STATE, SIGNAL_DATA, @@ -37,6 +39,7 @@ from .const import ( CONF_BLOCK_CLIENT, CONF_CONTROLLER, CONF_DETECTION_TIME, + CONF_DPI_RESTRICTIONS, CONF_IGNORE_WIRED_BUG, CONF_POE_CLIENTS, CONF_SITE_ID, @@ -48,6 +51,7 @@ from .const import ( DEFAULT_ALLOW_BANDWIDTH_SENSORS, DEFAULT_ALLOW_UPTIME_SENSORS, DEFAULT_DETECTION_TIME, + DEFAULT_DPI_RESTRICTIONS, DEFAULT_IGNORE_WIRED_BUG, DEFAULT_POE_CLIENTS, DEFAULT_TRACK_CLIENTS, @@ -177,6 +181,13 @@ class UniFiController: """Config entry option with list of clients to control network access.""" return self.config_entry.options.get(CONF_BLOCK_CLIENT, []) + @property + def option_dpi_restrictions(self): + """Config entry option to control DPI restriction groups.""" + return self.config_entry.options.get( + CONF_DPI_RESTRICTIONS, DEFAULT_DPI_RESTRICTIONS + ) + # Statistics sensor options @property @@ -248,6 +259,18 @@ class UniFiController: self.hass, self.signal_remove, data[DATA_CLIENT_REMOVED] ) + elif DATA_DPI_GROUP in data: + for key in data[DATA_DPI_GROUP]: + if self.api.dpi_groups[key].dpiapp_ids: + async_dispatcher_send(self.hass, self.signal_update) + else: + async_dispatcher_send(self.hass, self.signal_remove, {key}) + + elif DATA_DPI_GROUP_REMOVED in data: + async_dispatcher_send( + self.hass, self.signal_remove, data[DATA_DPI_GROUP_REMOVED] + ) + @property def signal_reachable(self) -> str: """Integration specific event to signal a change in connection status.""" diff --git a/homeassistant/components/unifi/manifest.json b/homeassistant/components/unifi/manifest.json index 9c1896f0c48..48c080f82f7 100644 --- a/homeassistant/components/unifi/manifest.json +++ b/homeassistant/components/unifi/manifest.json @@ -3,7 +3,7 @@ "name": "Ubiquiti UniFi", "config_flow": true, "documentation": "https://www.home-assistant.io/integrations/unifi", - "requirements": ["aiounifi==23"], + "requirements": ["aiounifi==25"], "codeowners": ["@Kane610"], "quality_scale": "platinum" } diff --git a/homeassistant/components/unifi/strings.json b/homeassistant/components/unifi/strings.json index 9deb68f4e3b..75e23ae2ed1 100644 --- a/homeassistant/components/unifi/strings.json +++ b/homeassistant/components/unifi/strings.json @@ -39,7 +39,8 @@ "client_control": { "data": { "block_client": "Network access controlled clients", - "poe_clients": "Allow POE control of clients" + "poe_clients": "Allow POE control of clients", + "dpi_restrictions": "Allow control of DPI restriction groups" }, "description": "Configure client controls\n\nCreate switches for serial numbers you want to control network access for.", "title": "UniFi options 2/3" diff --git a/homeassistant/components/unifi/switch.py b/homeassistant/components/unifi/switch.py index 9bdb35baf4d..6aa42b0d291 100644 --- a/homeassistant/components/unifi/switch.py +++ b/homeassistant/components/unifi/switch.py @@ -1,5 +1,6 @@ """Support for devices connected to UniFi POE.""" import logging +from typing import Any from aiounifi.api import SOURCE_EVENT from aiounifi.events import ( @@ -14,12 +15,14 @@ from homeassistant.core import callback from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.restore_state import RestoreEntity -from .const import DOMAIN as UNIFI_DOMAIN +from .const import ATTR_MANUFACTURER, DOMAIN as UNIFI_DOMAIN from .unifi_client import UniFiClient +from .unifi_entity_base import UniFiBase _LOGGER = logging.getLogger(__name__) BLOCK_SWITCH = "block" +DPI_SWITCH = "dpi" POE_SWITCH = "poe" CLIENT_BLOCKED = (WIRED_CLIENT_BLOCKED, WIRELESS_CLIENT_BLOCKED) @@ -32,7 +35,11 @@ async def async_setup_entry(hass, config_entry, async_add_entities): Switches are controlling network access and switch ports with POE. """ controller = hass.data[UNIFI_DOMAIN][config_entry.entry_id] - controller.entities[DOMAIN] = {BLOCK_SWITCH: set(), POE_SWITCH: set()} + controller.entities[DOMAIN] = { + BLOCK_SWITCH: set(), + POE_SWITCH: set(), + DPI_SWITCH: set(), + } if controller.site_role != "admin": return @@ -59,7 +66,9 @@ async def async_setup_entry(hass, config_entry, async_add_entities): @callback def items_added( - clients: set = controller.api.clients, devices: set = controller.api.devices + clients: set = controller.api.clients, + devices: set = controller.api.devices, + dpi_groups: set = controller.api.dpi_groups, ) -> None: """Update the values of the controller.""" if controller.option_block_clients: @@ -70,6 +79,9 @@ async def async_setup_entry(hass, config_entry, async_add_entities): controller, async_add_entities, clients, previously_known_poe_clients ) + if controller.option_dpi_restrictions: + add_dpi_entities(controller, async_add_entities, dpi_groups) + for signal in (controller.signal_update, controller.signal_options_update): controller.listeners.append(async_dispatcher_connect(hass, signal, items_added)) @@ -143,6 +155,24 @@ def add_poe_entities( async_add_entities(switches) +@callback +def add_dpi_entities(controller, async_add_entities, dpi_groups): + """Add new switch entities from the controller.""" + switches = [] + + for group in dpi_groups: + if ( + group in controller.entities[DOMAIN][DPI_SWITCH] + or not dpi_groups[group].dpiapp_ids + ): + continue + + switches.append(UniFiDPIRestrictionSwitch(dpi_groups[group], controller)) + + if switches: + async_add_entities(switches) + + class UniFiPOEClientSwitch(UniFiClient, SwitchEntity, RestoreEntity): """Representation of a client that uses POE.""" @@ -284,3 +314,61 @@ class UniFiBlockClientSwitch(UniFiClient, SwitchEntity): """Config entry options are updated, remove entity if option is disabled.""" if self.client.mac not in self.controller.option_block_clients: await self.remove_item({self.client.mac}) + + +class UniFiDPIRestrictionSwitch(UniFiBase, SwitchEntity): + """Representation of a DPI restriction group.""" + + DOMAIN = DOMAIN + TYPE = DPI_SWITCH + + @property + def key(self) -> Any: + """Return item key.""" + return self._item.id + + @property + def unique_id(self): + """Return a unique identifier for this switch.""" + return self._item.id + + @property + def name(self) -> str: + """Return the name of the client.""" + return self._item.name + + @property + def icon(self): + """Return the icon to use in the frontend.""" + if self._item.enabled: + return "mdi:network" + return "mdi:network-off" + + @property + def is_on(self): + """Return true if client is allowed to connect.""" + return self._item.enabled + + async def async_turn_on(self, **kwargs): + """Turn on connectivity for client.""" + await self.controller.api.dpi_groups.async_enable(self._item) + + async def async_turn_off(self, **kwargs): + """Turn off connectivity for client.""" + await self.controller.api.dpi_groups.async_disable(self._item) + + async def options_updated(self) -> None: + """Config entry options are updated, remove entity if option is disabled.""" + if not self.controller.option_dpi_restrictions: + await self.remove_item({self.key}) + + @property + def device_info(self) -> dict: + """Return a service description for device registry.""" + return { + "identifiers": {(DOMAIN, f"unifi_controller_{self._item.site_id}")}, + "name": "UniFi Controller", + "manufacturer": ATTR_MANUFACTURER, + "model": "UniFi Controller", + "entry_type": "service", + } diff --git a/homeassistant/components/unifi/translations/en.json b/homeassistant/components/unifi/translations/en.json index ed3a26b335a..968d90e377c 100644 --- a/homeassistant/components/unifi/translations/en.json +++ b/homeassistant/components/unifi/translations/en.json @@ -27,6 +27,7 @@ "client_control": { "data": { "block_client": "Network access controlled clients", + "dpi_restrictions": "Allow control of DPI restriction groups", "poe_clients": "Allow POE control of clients" }, "description": "Configure client controls\n\nCreate switches for serial numbers you want to control network access for.", diff --git a/homeassistant/components/unifi/unifi_entity_base.py b/homeassistant/components/unifi/unifi_entity_base.py index 42820cf69b0..a730c134603 100644 --- a/homeassistant/components/unifi/unifi_entity_base.py +++ b/homeassistant/components/unifi/unifi_entity_base.py @@ -1,5 +1,6 @@ """Base class for UniFi entities.""" import logging +from typing import Any from homeassistant.core import callback from homeassistant.helpers.dispatcher import async_dispatcher_connect @@ -22,12 +23,20 @@ class UniFiBase(Entity): """ self._item = item self.controller = controller - self.controller.entities[self.DOMAIN][self.TYPE].add(item.mac) + self.controller.entities[self.DOMAIN][self.TYPE].add(self.key) + + @property + def key(self) -> Any: + """Return item key.""" + return self._item.mac async def async_added_to_hass(self) -> None: """Entity created.""" _LOGGER.debug( - "New %s entity %s (%s)", self.TYPE, self.entity_id, self._item.mac + "New %s entity %s (%s)", + self.TYPE, + self.entity_id, + self.key, ) for signal, method in ( (self.controller.signal_reachable, self.async_update_callback), @@ -40,16 +49,22 @@ class UniFiBase(Entity): async def async_will_remove_from_hass(self) -> None: """Disconnect object when removed.""" _LOGGER.debug( - "Removing %s entity %s (%s)", self.TYPE, self.entity_id, self._item.mac + "Removing %s entity %s (%s)", + self.TYPE, + self.entity_id, + self.key, ) self._item.remove_callback(self.async_update_callback) - self.controller.entities[self.DOMAIN][self.TYPE].remove(self._item.mac) + self.controller.entities[self.DOMAIN][self.TYPE].remove(self.key) @callback def async_update_callback(self) -> None: """Update the entity's state.""" _LOGGER.debug( - "Updating %s entity %s (%s)", self.TYPE, self.entity_id, self._item.mac + "Updating %s entity %s (%s)", + self.TYPE, + self.entity_id, + self.key, ) self.async_write_ha_state() @@ -57,15 +72,15 @@ class UniFiBase(Entity): """Config entry options are updated, remove entity if option is disabled.""" raise NotImplementedError - async def remove_item(self, mac_addresses: set) -> None: - """Remove entity if MAC is part of set. + async def remove_item(self, keys: set) -> None: + """Remove entity if key is part of set. Remove entity if no entry in entity registry exist. Remove entity registry entry if no entry in device registry exist. Remove device registry entry if there is only one linked entity (this entity). Remove entity registry entry if there are more than one entity linked to the device registry entry. """ - if self._item.mac not in mac_addresses: + if self.key not in keys: return entity_registry = await self.hass.helpers.entity_registry.async_get_registry() diff --git a/requirements_all.txt b/requirements_all.txt index 3e6262ccfd8..922a93e8442 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -227,7 +227,7 @@ aioshelly==0.5.0 aioswitcher==1.2.1 # homeassistant.components.unifi -aiounifi==23 +aiounifi==25 # homeassistant.components.yandex_transport aioymaps==1.1.0 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 3b38380de6a..50b9c00fab4 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -143,7 +143,7 @@ aioshelly==0.5.0 aioswitcher==1.2.1 # homeassistant.components.unifi -aiounifi==23 +aiounifi==25 # homeassistant.components.yandex_transport aioymaps==1.1.0 diff --git a/tests/components/unifi/test_config_flow.py b/tests/components/unifi/test_config_flow.py index 8b935d7744b..87233f9983c 100644 --- a/tests/components/unifi/test_config_flow.py +++ b/tests/components/unifi/test_config_flow.py @@ -8,6 +8,7 @@ from homeassistant.components.unifi.const import ( CONF_BLOCK_CLIENT, CONF_CONTROLLER, CONF_DETECTION_TIME, + CONF_DPI_RESTRICTIONS, CONF_IGNORE_WIRED_BUG, CONF_POE_CLIENTS, CONF_SITE_ID, @@ -72,6 +73,14 @@ WLANS = [ {"name": "SSID 2", "name_combine_enabled": False, "name_combine_suffix": "_IOT"}, ] +DPI_GROUPS = [ + { + "_id": "5ba29dd8e3c58f026e9d7c4a", + "name": "Default", + "site_id": "5ba29dd4e3c58f026e9d7c38", + }, +] + async def test_flow_works(hass, aioclient_mock, mock_discovery): """Test config flow.""" @@ -307,7 +316,12 @@ async def test_flow_fails_unknown_problem(hass, aioclient_mock): async def test_advanced_option_flow(hass): """Test advanced config flow options.""" controller = await setup_unifi_integration( - hass, clients_response=CLIENTS, devices_response=DEVICES, wlans_response=WLANS + hass, + clients_response=CLIENTS, + devices_response=DEVICES, + wlans_response=WLANS, + dpigroup_response=DPI_GROUPS, + dpiapp_response=[], ) result = await hass.config_entries.options.async_init( @@ -336,7 +350,11 @@ async def test_advanced_option_flow(hass): result = await hass.config_entries.options.async_configure( result["flow_id"], - user_input={CONF_BLOCK_CLIENT: [CLIENTS[0]["mac"]], CONF_POE_CLIENTS: False}, + user_input={ + CONF_BLOCK_CLIENT: [CLIENTS[0]["mac"]], + CONF_POE_CLIENTS: False, + CONF_DPI_RESTRICTIONS: False, + }, ) assert result["type"] == data_entry_flow.RESULT_TYPE_FORM @@ -359,6 +377,7 @@ async def test_advanced_option_flow(hass): CONF_DETECTION_TIME: 100, CONF_IGNORE_WIRED_BUG: False, CONF_POE_CLIENTS: False, + CONF_DPI_RESTRICTIONS: False, CONF_BLOCK_CLIENT: [CLIENTS[0]["mac"]], CONF_ALLOW_BANDWIDTH_SENSORS: True, CONF_ALLOW_UPTIME_SENSORS: True, @@ -368,7 +387,11 @@ async def test_advanced_option_flow(hass): async def test_simple_option_flow(hass): """Test simple config flow options.""" controller = await setup_unifi_integration( - hass, clients_response=CLIENTS, wlans_response=WLANS + hass, + clients_response=CLIENTS, + wlans_response=WLANS, + dpigroup_response=DPI_GROUPS, + dpiapp_response=[], ) result = await hass.config_entries.options.async_init( diff --git a/tests/components/unifi/test_controller.py b/tests/components/unifi/test_controller.py index 5fee4a85f9a..83732601cd6 100644 --- a/tests/components/unifi/test_controller.py +++ b/tests/components/unifi/test_controller.py @@ -81,6 +81,8 @@ async def setup_unifi_integration( devices_response=None, clients_all_response=None, wlans_response=None, + dpigroup_response=None, + dpiapp_response=None, known_wireless_clients=None, controllers=None, ): @@ -116,6 +118,14 @@ async def setup_unifi_integration( if wlans_response: mock_wlans_responses.append(wlans_response) + mock_dpigroup_responses = deque() + if dpigroup_response: + mock_dpigroup_responses.append(dpigroup_response) + + mock_dpiapp_responses = deque() + if dpiapp_response: + mock_dpiapp_responses.append(dpiapp_response) + mock_requests = [] async def mock_request(self, method, path, json=None): @@ -129,6 +139,10 @@ async def setup_unifi_integration( return mock_client_all_responses.popleft() if path == "/rest/wlanconf" and mock_wlans_responses: return mock_wlans_responses.popleft() + if path == "/rest/dpigroup" and mock_dpigroup_responses: + return mock_dpigroup_responses.popleft() + if path == "/rest/dpiapp" and mock_dpiapp_responses: + return mock_dpiapp_responses.popleft() return {} with patch("aiounifi.Controller.check_unifi_os", return_value=True), patch( diff --git a/tests/components/unifi/test_sensor.py b/tests/components/unifi/test_sensor.py index 690b9d77899..dc2fea634c9 100644 --- a/tests/components/unifi/test_sensor.py +++ b/tests/components/unifi/test_sensor.py @@ -71,7 +71,7 @@ async def test_no_clients(hass): }, ) - assert len(controller.mock_requests) == 4 + assert len(controller.mock_requests) == 6 assert len(hass.states.async_entity_ids(SENSOR_DOMAIN)) == 0 @@ -88,7 +88,7 @@ async def test_sensors(hass): clients_response=CLIENTS, ) - assert len(controller.mock_requests) == 4 + assert len(controller.mock_requests) == 6 assert len(hass.states.async_entity_ids(SENSOR_DOMAIN)) == 6 wired_client_rx = hass.states.get("sensor.wired_client_name_rx") diff --git a/tests/components/unifi/test_switch.py b/tests/components/unifi/test_switch.py index 6c4fc25d828..903db479d34 100644 --- a/tests/components/unifi/test_switch.py +++ b/tests/components/unifi/test_switch.py @@ -9,6 +9,8 @@ from homeassistant.components.device_tracker import DOMAIN as TRACKER_DOMAIN from homeassistant.components.switch import DOMAIN as SWITCH_DOMAIN from homeassistant.components.unifi.const import ( CONF_BLOCK_CLIENT, + CONF_DPI_RESTRICTIONS, + CONF_POE_CLIENTS, CONF_TRACK_CLIENTS, CONF_TRACK_DEVICES, DOMAIN as UNIFI_DOMAIN, @@ -251,6 +253,35 @@ EVENT_CLIENT_2_CONNECTED = { } +DPI_GROUPS = [ + { + "_id": "5ba29dd8e3c58f026e9d7c4a", + "attr_no_delete": True, + "attr_hidden_id": "Default", + "name": "Default", + "site_id": "name", + }, + { + "_id": "5f976f4ae3c58f018ec7dff6", + "name": "Block Media Streaming", + "site_id": "name", + "dpiapp_ids": ["5f976f62e3c58f018ec7e17d"], + }, +] + +DPI_APPS = [ + { + "_id": "5f976f62e3c58f018ec7e17d", + "apps": [], + "blocked": True, + "cats": ["4"], + "enabled": True, + "log": True, + "site_id": "name", + } +] + + async def test_platform_manually_configured(hass): """Test that we do not discover anything or try to set up a controller.""" assert ( @@ -266,10 +297,14 @@ async def test_no_clients(hass): """Test the update_clients function when no clients are found.""" controller = await setup_unifi_integration( hass, - options={CONF_TRACK_CLIENTS: False, CONF_TRACK_DEVICES: False}, + options={ + CONF_TRACK_CLIENTS: False, + CONF_TRACK_DEVICES: False, + CONF_DPI_RESTRICTIONS: False, + }, ) - assert len(controller.mock_requests) == 4 + assert len(controller.mock_requests) == 6 assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0 @@ -282,7 +317,7 @@ async def test_controller_not_client(hass): devices_response=[DEVICE_1], ) - assert len(controller.mock_requests) == 4 + assert len(controller.mock_requests) == 6 assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0 cloudkey = hass.states.get("switch.cloud_key") assert cloudkey is None @@ -300,7 +335,7 @@ async def test_not_admin(hass): devices_response=[DEVICE_1], ) - assert len(controller.mock_requests) == 4 + assert len(controller.mock_requests) == 6 assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0 @@ -316,10 +351,12 @@ async def test_switches(hass): clients_response=[CLIENT_1, CLIENT_4], devices_response=[DEVICE_1], clients_all_response=[BLOCKED, UNBLOCKED, CLIENT_1], + dpigroup_response=DPI_GROUPS, + dpiapp_response=DPI_APPS, ) - assert len(controller.mock_requests) == 4 - assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 3 + assert len(controller.mock_requests) == 6 + assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 4 switch_1 = hass.states.get("switch.poe_client_1") assert switch_1 is not None @@ -340,11 +377,15 @@ async def test_switches(hass): assert unblocked is not None assert unblocked.state == "on" + dpi_switch = hass.states.get("switch.block_media_streaming") + assert dpi_switch is not None + assert dpi_switch.state == "on" + await hass.services.async_call( SWITCH_DOMAIN, "turn_off", {"entity_id": "switch.block_client_1"}, blocking=True ) - assert len(controller.mock_requests) == 5 - assert controller.mock_requests[4] == { + assert len(controller.mock_requests) == 7 + assert controller.mock_requests[6] == { "json": {"mac": "00:00:00:00:01:01", "cmd": "block-sta"}, "method": "post", "path": "/cmd/stamgr", @@ -353,13 +394,39 @@ async def test_switches(hass): await hass.services.async_call( SWITCH_DOMAIN, "turn_on", {"entity_id": "switch.block_client_1"}, blocking=True ) - assert len(controller.mock_requests) == 6 - assert controller.mock_requests[5] == { + assert len(controller.mock_requests) == 8 + assert controller.mock_requests[7] == { "json": {"mac": "00:00:00:00:01:01", "cmd": "unblock-sta"}, "method": "post", "path": "/cmd/stamgr", } + await hass.services.async_call( + SWITCH_DOMAIN, + "turn_off", + {"entity_id": "switch.block_media_streaming"}, + blocking=True, + ) + assert len(controller.mock_requests) == 9 + assert controller.mock_requests[8] == { + "json": {"enabled": False}, + "method": "put", + "path": "/rest/dpiapp/5f976f62e3c58f018ec7e17d", + } + + await hass.services.async_call( + SWITCH_DOMAIN, + "turn_on", + {"entity_id": "switch.block_media_streaming"}, + blocking=True, + ) + assert len(controller.mock_requests) == 10 + assert controller.mock_requests[9] == { + "json": {"enabled": True}, + "method": "put", + "path": "/rest/dpiapp/5f976f62e3c58f018ec7e17d", + } + async def test_remove_switches(hass): """Test the update_items function with some clients.""" @@ -443,8 +510,8 @@ async def test_block_switches(hass): await hass.services.async_call( SWITCH_DOMAIN, "turn_off", {"entity_id": "switch.block_client_1"}, blocking=True ) - assert len(controller.mock_requests) == 5 - assert controller.mock_requests[4] == { + assert len(controller.mock_requests) == 7 + assert controller.mock_requests[6] == { "json": {"mac": "00:00:00:00:01:01", "cmd": "block-sta"}, "method": "post", "path": "/cmd/stamgr", @@ -453,8 +520,8 @@ async def test_block_switches(hass): await hass.services.async_call( SWITCH_DOMAIN, "turn_on", {"entity_id": "switch.block_client_1"}, blocking=True ) - assert len(controller.mock_requests) == 6 - assert controller.mock_requests[5] == { + assert len(controller.mock_requests) == 8 + assert controller.mock_requests[7] == { "json": {"mac": "00:00:00:00:01:01", "cmd": "unblock-sta"}, "method": "post", "path": "/cmd/stamgr", @@ -469,10 +536,11 @@ async def test_new_client_discovered_on_block_control(hass): CONF_BLOCK_CLIENT: [BLOCKED["mac"]], CONF_TRACK_CLIENTS: False, CONF_TRACK_DEVICES: False, + CONF_DPI_RESTRICTIONS: False, }, ) - assert len(controller.mock_requests) == 4 + assert len(controller.mock_requests) == 6 assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0 blocked = hass.states.get("switch.block_client_1") @@ -541,6 +609,30 @@ async def test_option_block_clients(hass): assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0 +async def test_option_remove_switches(hass): + """Test removal of DPI switch when options updated.""" + controller = await setup_unifi_integration( + hass, + options={ + CONF_TRACK_CLIENTS: False, + CONF_TRACK_DEVICES: False, + }, + clients_response=[CLIENT_1], + devices_response=[DEVICE_1], + dpigroup_response=DPI_GROUPS, + dpiapp_response=DPI_APPS, + ) + assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 2 + + # Disable DPI Switches + hass.config_entries.async_update_entry( + controller.config_entry, + options={CONF_DPI_RESTRICTIONS: False, CONF_POE_CLIENTS: False}, + ) + await hass.async_block_till_done() + assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0 + + async def test_new_client_discovered_on_poe_control(hass): """Test if 2nd update has a new client.""" controller = await setup_unifi_integration( @@ -550,7 +642,7 @@ async def test_new_client_discovered_on_poe_control(hass): devices_response=[DEVICE_1], ) - assert len(controller.mock_requests) == 4 + assert len(controller.mock_requests) == 6 assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 1 controller.api.websocket._data = { @@ -576,9 +668,9 @@ async def test_new_client_discovered_on_poe_control(hass): await hass.services.async_call( SWITCH_DOMAIN, "turn_off", {"entity_id": "switch.poe_client_1"}, blocking=True ) - assert len(controller.mock_requests) == 5 + assert len(controller.mock_requests) == 7 assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 2 - assert controller.mock_requests[4] == { + assert controller.mock_requests[6] == { "json": { "port_overrides": [{"port_idx": 1, "portconf_id": "1a1", "poe_mode": "off"}] }, @@ -589,8 +681,8 @@ async def test_new_client_discovered_on_poe_control(hass): await hass.services.async_call( SWITCH_DOMAIN, "turn_on", {"entity_id": "switch.poe_client_1"}, blocking=True ) - assert len(controller.mock_requests) == 6 - assert controller.mock_requests[4] == { + assert len(controller.mock_requests) == 8 + assert controller.mock_requests[7] == { "json": { "port_overrides": [ {"port_idx": 1, "portconf_id": "1a1", "poe_mode": "auto"} @@ -613,7 +705,7 @@ async def test_ignore_multiple_poe_clients_on_same_port(hass): devices_response=[DEVICE_1], ) - assert len(controller.mock_requests) == 4 + assert len(controller.mock_requests) == 6 assert len(hass.states.async_entity_ids(TRACKER_DOMAIN)) == 3 switch_1 = hass.states.get("switch.poe_client_1") @@ -664,7 +756,7 @@ async def test_restoring_client(hass): clients_all_response=[CLIENT_1], ) - assert len(controller.mock_requests) == 4 + assert len(controller.mock_requests) == 6 assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 2 device_1 = hass.states.get("switch.client_1")