diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 6508de08143..aecdf45dde5 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -216,6 +216,11 @@ class Entity: """Time that a context is considered recent.""" return timedelta(seconds=5) + @property + def entity_registry_enabled_default(self): + """Return if the entity should be enabled when first added to the entity registry.""" + return True + # DO NOT OVERWRITE # These properties and methods are either managed by Home Assistant or they # are used to perform a very specific function. Overwriting these may diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index ea71828f21a..dd19fac05c8 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -8,6 +8,7 @@ from homeassistant.core import callback, valid_entity_id, split_entity_id from homeassistant.exceptions import HomeAssistantError, PlatformNotReady from homeassistant.util.async_ import run_callback_threadsafe, run_coroutine_threadsafe +from .entity_registry import DISABLED_INTEGRATION from .event import async_track_time_interval, async_call_later @@ -333,6 +334,10 @@ class EntityPlatform: if device: device_id = device.id + disabled_by: Optional[str] = None + if not entity.entity_registry_enabled_default: + disabled_by = DISABLED_INTEGRATION + entry = entity_registry.async_get_or_create( self.domain, self.platform_name, @@ -341,6 +346,7 @@ class EntityPlatform: config_entry_id=config_entry_id, device_id=device_id, known_object_ids=self.entities.keys(), + disabled_by=disabled_by, ) if entry.disabled: diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 9a7be9ecc36..97cc213aa66 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -35,6 +35,7 @@ _LOGGER = logging.getLogger(__name__) _UNDEF = object() DISABLED_HASS = "hass" DISABLED_USER = "user" +DISABLED_INTEGRATION = "integration" STORAGE_VERSION = 1 STORAGE_KEY = "core.entity_registry" @@ -53,7 +54,9 @@ class RegistryEntry: disabled_by = attr.ib( type=str, default=None, - validator=attr.validators.in_((DISABLED_HASS, DISABLED_USER, None)), + validator=attr.validators.in_( + (DISABLED_HASS, DISABLED_USER, DISABLED_INTEGRATION, None) + ), ) # type: Optional[str] domain = attr.ib(type=str, init=False, repr=False) @@ -132,6 +135,7 @@ class EntityRegistry: config_entry_id=None, device_id=None, known_object_ids=None, + disabled_by=None, ): """Get entity. Create if it doesn't exist.""" entity_id = self.async_get_entity_id(domain, platform, unique_id) @@ -161,6 +165,7 @@ class EntityRegistry: device_id=device_id, unique_id=unique_id, platform=platform, + disabled_by=disabled_by, ) self.entities[entity_id] = entity _LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id) diff --git a/tests/common.py b/tests/common.py index a139ca83743..f7816bf2192 100644 --- a/tests/common.py +++ b/tests/common.py @@ -908,6 +908,11 @@ class MockEntity(entity.Entity): """Info how it links to a device.""" return self._handle("device_info") + @property + def entity_registry_enabled_default(self): + """Return if the entity should be enabled when first added to the entity registry.""" + return self._handle("entity_registry_enabled_default") + def _handle(self, attr): """Return attribute value.""" if attr in self._values: diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index 0f43c6ab4aa..606a4c82096 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -775,3 +775,22 @@ async def test_device_info_not_overrides(hass): assert device.id == device2.id assert device2.manufacturer == "test-manufacturer" assert device2.model == "test-model" + + +async def test_entity_disabled_by_integration(hass): + """Test entity disabled by integration.""" + component = EntityComponent(_LOGGER, DOMAIN, hass, timedelta(seconds=20)) + + entity_default = MockEntity(unique_id="default") + entity_disabled = MockEntity( + unique_id="disabled", entity_registry_enabled_default=False + ) + + await component.async_add_entities([entity_default, entity_disabled]) + + registry = await hass.helpers.entity_registry.async_get_registry() + + entry_default = registry.async_get_or_create(DOMAIN, DOMAIN, "default") + assert entry_default.disabled_by is None + entry_disabled = registry.async_get_or_create(DOMAIN, DOMAIN, "disabled") + assert entry_disabled.disabled_by == "integration" diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index a3ffcb4d1ff..88131a58de0 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -352,3 +352,17 @@ async def test_update_entity_unique_id_conflict(registry): ) as mock_schedule_save, pytest.raises(ValueError): registry.async_update_entity(entry.entity_id, new_unique_id=entry2.unique_id) assert mock_schedule_save.call_count == 0 + + +async def test_disabled_by(registry): + """Test that we can disable an entry when we create it.""" + entry = registry.async_get_or_create("light", "hue", "5678", disabled_by="hass") + assert entry.disabled_by == "hass" + + entry = registry.async_get_or_create( + "light", "hue", "5678", disabled_by="integration" + ) + assert entry.disabled_by == "hass" + + entry2 = registry.async_get_or_create("light", "hue", "1234") + assert entry2.disabled_by is None