diff --git a/homeassistant/components/rflink/cover.py b/homeassistant/components/rflink/cover.py index 5db82a1d4e8..794542cb9d4 100644 --- a/homeassistant/components/rflink/cover.py +++ b/homeassistant/components/rflink/cover.py @@ -23,6 +23,8 @@ from . import ( _LOGGER = logging.getLogger(__name__) +PARALLEL_UPDATES = 0 + TYPE_STANDARD = "standard" TYPE_INVERTED = "inverted" diff --git a/homeassistant/components/rflink/light.py b/homeassistant/components/rflink/light.py index db616b92fc4..1ed19569585 100644 --- a/homeassistant/components/rflink/light.py +++ b/homeassistant/components/rflink/light.py @@ -31,6 +31,8 @@ from . import ( _LOGGER = logging.getLogger(__name__) +PARALLEL_UPDATES = 0 + TYPE_DIMMABLE = "dimmable" TYPE_SWITCHABLE = "switchable" TYPE_HYBRID = "hybrid" diff --git a/homeassistant/components/rflink/switch.py b/homeassistant/components/rflink/switch.py index 8e0ce9a0c8e..990d76101cc 100644 --- a/homeassistant/components/rflink/switch.py +++ b/homeassistant/components/rflink/switch.py @@ -22,6 +22,8 @@ from . import ( _LOGGER = logging.getLogger(__name__) +PARALLEL_UPDATES = 0 + PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( { vol.Optional( diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index a2a0ae840e0..fa649561e3d 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -568,7 +568,6 @@ class Entity(ABC): # call an requests async def async_request_call(self, coro): """Process request batched.""" - if self.parallel_updates: await self.parallel_updates.acquire() diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 8fedc198fe2..e71b28f1713 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -62,22 +62,42 @@ class EntityPlatform: # Platform is None for the EntityComponent "catch-all" EntityPlatform # which powers entity_component.add_entities if platform is None: - self.parallel_updates = None - self.parallel_updates_semaphore: Optional[asyncio.Semaphore] = None + self.parallel_updates_created = True + self.parallel_updates: Optional[asyncio.Semaphore] = None return - self.parallel_updates = getattr(platform, "PARALLEL_UPDATES", None) - # semaphore will be created on demand - self.parallel_updates_semaphore = None + self.parallel_updates_created = False + self.parallel_updates = None - def _get_parallel_updates_semaphore(self) -> asyncio.Semaphore: - """Get or create a semaphore for parallel updates.""" - if self.parallel_updates_semaphore is None: - self.parallel_updates_semaphore = asyncio.Semaphore( - self.parallel_updates if self.parallel_updates else 1, - loop=self.hass.loop, - ) - return self.parallel_updates_semaphore + @callback + def _get_parallel_updates_semaphore( + self, entity_has_async_update: bool + ) -> Optional[asyncio.Semaphore]: + """Get or create a semaphore for parallel updates. + + Semaphore will be created on demand because we base it off if update method is async or not. + + If parallel updates is set to 0, we skip the semaphore. + If parallel updates is set to a number, we initialize the semaphore to that number. + Default for entities with `async_update` method is 1. Otherwise it's 0. + """ + if self.parallel_updates_created: + return self.parallel_updates + + self.parallel_updates_created = True + + parallel_updates = getattr(self.platform, "PARALLEL_UPDATES", None) + + if parallel_updates is None and not entity_has_async_update: + parallel_updates = 1 + + if parallel_updates == 0: + parallel_updates = None + + if parallel_updates is not None: + self.parallel_updates = asyncio.Semaphore(parallel_updates) + + return self.parallel_updates async def async_setup(self, platform_config, discovery_info=None): """Set up the platform from a config file.""" @@ -282,21 +302,9 @@ class EntityPlatform: entity.hass = self.hass entity.platform = self - - # Async entity - # PARALLEL_UPDATES == None: entity.parallel_updates = None - # PARALLEL_UPDATES == 0: entity.parallel_updates = None - # PARALLEL_UPDATES > 0: entity.parallel_updates = Semaphore(p) - # Sync entity - # PARALLEL_UPDATES == None: entity.parallel_updates = Semaphore(1) - # PARALLEL_UPDATES == 0: entity.parallel_updates = None - # PARALLEL_UPDATES > 0: entity.parallel_updates = Semaphore(p) - if hasattr(entity, "async_update") and not self.parallel_updates: - entity.parallel_updates = None - elif not hasattr(entity, "async_update") and self.parallel_updates == 0: - entity.parallel_updates = None - else: - entity.parallel_updates = self._get_parallel_updates_semaphore() + entity.parallel_updates = self._get_parallel_updates_semaphore( + hasattr(entity, "async_update") + ) # Update properties before we generate the entity_id if update_before_add: diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 51f881181af..46ebc467c0b 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -316,16 +316,15 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non # Check the permissions - # A list with for each platform in platforms a list of entities to call - # the service on. - platforms_entities = [] + # A list with entities to call the service on. + entity_candidates = [] if entity_perms is None: for platform in platforms: if target_all_entities: - platforms_entities.append(list(platform.entities.values())) + entity_candidates.extend(platform.entities.values()) else: - platforms_entities.append( + entity_candidates.extend( [ entity for entity in platform.entities.values() @@ -337,7 +336,7 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non # If we target all entities, we will select all entities the user # is allowed to control. for platform in platforms: - platforms_entities.append( + entity_candidates.extend( [ entity for entity in platform.entities.values() @@ -362,39 +361,20 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non platform_entities.append(entity) - platforms_entities.append(platform_entities) + entity_candidates.extend(platform_entities) if not target_all_entities: - for platform_entities in platforms_entities: - for entity in platform_entities: - entity_ids.remove(entity.entity_id) + for entity in entity_candidates: + entity_ids.remove(entity.entity_id) if entity_ids: _LOGGER.warning( "Unable to find referenced entities %s", ", ".join(sorted(entity_ids)) ) - tasks = [ - _handle_service_platform_call( - hass, func, data, entities, call.context, required_features - ) - for platform, entities in zip(platforms, platforms_entities) - ] + entities = [] - if tasks: - done, pending = await asyncio.wait(tasks) - assert not pending - for future in done: - future.result() # pop exception if have - - -async def _handle_service_platform_call( - hass, func, data, entities, context, required_features -): - """Handle a function call.""" - tasks = [] - - for entity in entities: + for entity in entity_candidates: if not entity.available: continue @@ -404,27 +384,33 @@ async def _handle_service_platform_call( ): continue - entity.async_set_context(context) + entities.append(entity) - if isinstance(func, str): - result = hass.async_add_job(partial(getattr(entity, func), **data)) - else: - result = hass.async_add_job(func, entity, data) + if not entities: + return - # Guard because callback functions do not return a task when passed to async_add_job. - if result is not None: - result = await result - - if asyncio.iscoroutine(result): - _LOGGER.error( - "Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to integration author.", - func, - entity.entity_id, + done, pending = await asyncio.wait( + [ + entity.async_request_call( + _handle_entity_call(hass, entity, func, data, call.context) ) - await result + for entity in entities + ] + ) + assert not pending + for future in done: + future.result() # pop exception if have - if entity.should_poll: - tasks.append(entity.async_update_ha_state(True)) + tasks = [] + + for entity in entities: + if not entity.should_poll: + continue + + # Context expires if the turn on commands took a long time. + # Set context again so it's there when we update + entity.async_set_context(call.context) + tasks.append(entity.async_update_ha_state(True)) if tasks: done, pending = await asyncio.wait(tasks) @@ -433,6 +419,28 @@ async def _handle_service_platform_call( future.result() # pop exception if have +async def _handle_entity_call(hass, entity, func, data, context): + """Handle calling service method.""" + entity.async_set_context(context) + + if isinstance(func, str): + result = hass.async_add_job(partial(getattr(entity, func), **data)) + else: + result = hass.async_add_job(func, entity, data) + + # Guard because callback functions do not return a task when passed to async_add_job. + if result is not None: + await result + + if asyncio.iscoroutine(result): + _LOGGER.error( + "Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to integration author.", + func, + entity.entity_id, + ) + await result + + @bind_hass @ha.callback def async_register_admin_service( @@ -474,6 +482,7 @@ def verify_domain_control(hass: HomeAssistantType, domain: str) -> Callable: return await service_handler(call) user = await hass.auth.async_get_user(call.context.user_id) + if user is None: raise UnknownUser( context=call.context, @@ -482,14 +491,12 @@ def verify_domain_control(hass: HomeAssistantType, domain: str) -> Callable: ) reg = await hass.helpers.entity_registry.async_get_registry() - entities = [ - entity.entity_id - for entity in reg.entities.values() - if entity.platform == domain - ] - for entity_id in entities: - if user.permissions.check_entity(entity_id, POLICY_CONTROL): + for entity in reg.entities.values(): + if entity.platform != domain: + continue + + if user.permissions.check_entity(entity.entity_id, POLICY_CONTROL): return await service_handler(call) raise Unauthorized( diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index 74105689957..ee43f5d4f1d 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -270,8 +270,6 @@ async def test_parallel_updates_async_platform_with_constant(hass): handle = list(component._platforms.values())[-1] - assert handle.parallel_updates == 2 - class AsyncEntity(MockEntity): """Mock entity that has async_update.""" @@ -296,7 +294,6 @@ async def test_parallel_updates_sync_platform(hass): await component.async_setup({DOMAIN: {"platform": "platform"}}) handle = list(component._platforms.values())[-1] - assert handle.parallel_updates is None class SyncEntity(MockEntity): """Mock entity that has update.""" @@ -323,7 +320,6 @@ async def test_parallel_updates_sync_platform_with_constant(hass): await component.async_setup({DOMAIN: {"platform": "platform"}}) handle = list(component._platforms.values())[-1] - assert handle.parallel_updates == 2 class SyncEntity(MockEntity): """Mock entity that has update.""" diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 106fdfabf2d..cc4098a613a 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -39,31 +39,29 @@ from tests.common import ( @pytest.fixture -def mock_service_platform_call(): +def mock_handle_entity_call(): """Mock service platform call.""" with patch( - "homeassistant.helpers.service._handle_service_platform_call", + "homeassistant.helpers.service._handle_entity_call", side_effect=lambda *args: mock_coro(), ) as mock_call: yield mock_call @pytest.fixture -def mock_entities(): +def mock_entities(hass): """Return mock entities in an ordered dict.""" - kitchen = Mock( + kitchen = MockEntity( entity_id="light.kitchen", available=True, should_poll=False, supported_features=1, - platform="test_domain", ) - living_room = Mock( + living_room = MockEntity( entity_id="light.living_room", available=True, should_poll=False, supported_features=0, - platform="test_domain", ) entities = OrderedDict() entities[kitchen.entity_id] = kitchen @@ -374,7 +372,7 @@ async def test_call_context_user_not_exist(hass): assert err.value.context.user_id == "non-existing" -async def test_call_context_target_all(hass, mock_service_platform_call, mock_entities): +async def test_call_context_target_all(hass, mock_handle_entity_call, mock_entities): """Check we only target allowed entities if targeting all.""" with patch( "homeassistant.auth.AuthManager.async_get_user", @@ -398,13 +396,12 @@ async def test_call_context_target_all(hass, mock_service_platform_call, mock_en ), ) - assert len(mock_service_platform_call.mock_calls) == 1 - entities = mock_service_platform_call.mock_calls[0][1][3] - assert entities == [mock_entities["light.kitchen"]] + assert len(mock_handle_entity_call.mock_calls) == 1 + assert mock_handle_entity_call.mock_calls[0][1][1].entity_id == "light.kitchen" async def test_call_context_target_specific( - hass, mock_service_platform_call, mock_entities + hass, mock_handle_entity_call, mock_entities ): """Check targeting specific entities.""" with patch( @@ -429,13 +426,12 @@ async def test_call_context_target_specific( ), ) - assert len(mock_service_platform_call.mock_calls) == 1 - entities = mock_service_platform_call.mock_calls[0][1][3] - assert entities == [mock_entities["light.kitchen"]] + assert len(mock_handle_entity_call.mock_calls) == 1 + assert mock_handle_entity_call.mock_calls[0][1][1].entity_id == "light.kitchen" async def test_call_context_target_specific_no_auth( - hass, mock_service_platform_call, mock_entities + hass, mock_handle_entity_call, mock_entities ): """Check targeting specific entities without auth.""" with pytest.raises(exceptions.Unauthorized) as err: @@ -459,9 +455,7 @@ async def test_call_context_target_specific_no_auth( assert err.value.entity_id == "light.kitchen" -async def test_call_no_context_target_all( - hass, mock_service_platform_call, mock_entities -): +async def test_call_no_context_target_all(hass, mock_handle_entity_call, mock_entities): """Check we target all if no user context given.""" await service.entity_service_call( hass, @@ -472,13 +466,14 @@ async def test_call_no_context_target_all( ), ) - assert len(mock_service_platform_call.mock_calls) == 1 - entities = mock_service_platform_call.mock_calls[0][1][3] - assert entities == list(mock_entities.values()) + assert len(mock_handle_entity_call.mock_calls) == 2 + assert [call[1][1] for call in mock_handle_entity_call.mock_calls] == list( + mock_entities.values() + ) async def test_call_no_context_target_specific( - hass, mock_service_platform_call, mock_entities + hass, mock_handle_entity_call, mock_entities ): """Check we can target specified entities.""" await service.entity_service_call( @@ -492,13 +487,12 @@ async def test_call_no_context_target_specific( ), ) - assert len(mock_service_platform_call.mock_calls) == 1 - entities = mock_service_platform_call.mock_calls[0][1][3] - assert entities == [mock_entities["light.kitchen"]] + assert len(mock_handle_entity_call.mock_calls) == 1 + assert mock_handle_entity_call.mock_calls[0][1][1].entity_id == "light.kitchen" async def test_call_with_match_all( - hass, mock_service_platform_call, mock_entities, caplog + hass, mock_handle_entity_call, mock_entities, caplog ): """Check we only target allowed entities if targeting all.""" await service.entity_service_call( @@ -508,20 +502,13 @@ async def test_call_with_match_all( ha.ServiceCall("test_domain", "test_service", {"entity_id": "all"}), ) - assert len(mock_service_platform_call.mock_calls) == 1 - entities = mock_service_platform_call.mock_calls[0][1][3] - assert entities == [ - mock_entities["light.kitchen"], - mock_entities["light.living_room"], - ] - assert ( - "Not passing an entity ID to a service to target all entities is deprecated" - ) not in caplog.text + assert len(mock_handle_entity_call.mock_calls) == 2 + assert [call[1][1] for call in mock_handle_entity_call.mock_calls] == list( + mock_entities.values() + ) -async def test_call_with_omit_entity_id( - hass, mock_service_platform_call, mock_entities -): +async def test_call_with_omit_entity_id(hass, mock_handle_entity_call, mock_entities): """Check service call if we do not pass an entity ID.""" await service.entity_service_call( hass, @@ -530,9 +517,7 @@ async def test_call_with_omit_entity_id( ha.ServiceCall("test_domain", "test_service"), ) - assert len(mock_service_platform_call.mock_calls) == 1 - entities = mock_service_platform_call.mock_calls[0][1][3] - assert entities == [] + assert len(mock_handle_entity_call.mock_calls) == 0 async def test_register_admin_service(hass, hass_read_only_user, hass_admin_user): @@ -644,96 +629,113 @@ async def test_domain_control_unknown(hass, mock_entities): assert len(calls) == 0 -async def test_domain_control_unauthorized(hass, hass_read_only_user, mock_entities): +async def test_domain_control_unauthorized(hass, hass_read_only_user): """Test domain verification in a service call with an unauthorized user.""" - calls = [] - - async def mock_service_log(call): - """Define a protected service.""" - calls.append(call) - - with patch( - "homeassistant.helpers.entity_registry.async_get_registry", - return_value=mock_coro(Mock(entities=mock_entities)), - ): - protected_mock_service = hass.helpers.service.verify_domain_control( - "test_domain" - )(mock_service_log) - - hass.services.async_register( - "test_domain", "test_service", protected_mock_service, schema=None - ) - - with pytest.raises(exceptions.Unauthorized): - await hass.services.async_call( - "test_domain", - "test_service", - {}, - blocking=True, - context=ha.Context(user_id=hass_read_only_user.id), + mock_registry( + hass, + { + "light.kitchen": ent_reg.RegistryEntry( + entity_id="light.kitchen", unique_id="kitchen", platform="test_domain", ) + }, + ) + + calls = [] + + async def mock_service_log(call): + """Define a protected service.""" + calls.append(call) + + protected_mock_service = hass.helpers.service.verify_domain_control("test_domain")( + mock_service_log + ) + + hass.services.async_register( + "test_domain", "test_service", protected_mock_service, schema=None + ) + + with pytest.raises(exceptions.Unauthorized): + await hass.services.async_call( + "test_domain", + "test_service", + {}, + blocking=True, + context=ha.Context(user_id=hass_read_only_user.id), + ) + + assert len(calls) == 0 -async def test_domain_control_admin(hass, hass_admin_user, mock_entities): +async def test_domain_control_admin(hass, hass_admin_user): """Test domain verification in a service call with an admin user.""" + mock_registry( + hass, + { + "light.kitchen": ent_reg.RegistryEntry( + entity_id="light.kitchen", unique_id="kitchen", platform="test_domain", + ) + }, + ) + calls = [] async def mock_service_log(call): """Define a protected service.""" calls.append(call) - with patch( - "homeassistant.helpers.entity_registry.async_get_registry", - return_value=mock_coro(Mock(entities=mock_entities)), - ): - protected_mock_service = hass.helpers.service.verify_domain_control( - "test_domain" - )(mock_service_log) + protected_mock_service = hass.helpers.service.verify_domain_control("test_domain")( + mock_service_log + ) - hass.services.async_register( - "test_domain", "test_service", protected_mock_service, schema=None - ) + hass.services.async_register( + "test_domain", "test_service", protected_mock_service, schema=None + ) - await hass.services.async_call( - "test_domain", - "test_service", - {}, - blocking=True, - context=ha.Context(user_id=hass_admin_user.id), - ) + await hass.services.async_call( + "test_domain", + "test_service", + {}, + blocking=True, + context=ha.Context(user_id=hass_admin_user.id), + ) - assert len(calls) == 1 + assert len(calls) == 1 -async def test_domain_control_no_user(hass, mock_entities): +async def test_domain_control_no_user(hass): """Test domain verification in a service call with no user.""" + mock_registry( + hass, + { + "light.kitchen": ent_reg.RegistryEntry( + entity_id="light.kitchen", unique_id="kitchen", platform="test_domain", + ) + }, + ) + calls = [] async def mock_service_log(call): """Define a protected service.""" calls.append(call) - with patch( - "homeassistant.helpers.entity_registry.async_get_registry", - return_value=mock_coro(Mock(entities=mock_entities)), - ): - protected_mock_service = hass.helpers.service.verify_domain_control( - "test_domain" - )(mock_service_log) + protected_mock_service = hass.helpers.service.verify_domain_control("test_domain")( + mock_service_log + ) - hass.services.async_register( - "test_domain", "test_service", protected_mock_service, schema=None - ) + hass.services.async_register( + "test_domain", "test_service", protected_mock_service, schema=None + ) - await hass.services.async_call( - "test_domain", - "test_service", - {}, - blocking=True, - context=ha.Context(user_id=None), - ) + await hass.services.async_call( + "test_domain", + "test_service", + {}, + blocking=True, + context=ha.Context(user_id=None), + ) - assert len(calls) == 1 + assert len(calls) == 1 async def test_extract_from_service_available_device(hass):