Convert discovery helper to use dispatcher (#47008)

This commit is contained in:
Paulus Schoutsen 2021-02-24 13:37:31 -08:00 committed by GitHub
parent 5ab11df551
commit 557ec374f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 79 additions and 149 deletions

View File

@ -1,11 +1,4 @@
""" """Starts a service to scan in intervals for new devices."""
Starts a service to scan in intervals for new devices.
Will emit EVENT_PLATFORM_DISCOVERED whenever a new service has been discovered.
Knows which components handle certain types, will make sure they are
loaded before the EVENT_PLATFORM_DISCOVERED is fired.
"""
from datetime import timedelta from datetime import timedelta
import json import json
import logging import logging

View File

@ -6,7 +6,6 @@ from aiohttp.hdrs import CONTENT_TYPE
import requests import requests
import voluptuous as vol import voluptuous as vol
from homeassistant.components.discovery import SERVICE_OCTOPRINT
from homeassistant.const import ( from homeassistant.const import (
CONF_API_KEY, CONF_API_KEY,
CONF_BINARY_SENSORS, CONF_BINARY_SENSORS,
@ -22,7 +21,6 @@ from homeassistant.const import (
TEMP_CELSIUS, TEMP_CELSIUS,
TIME_SECONDS, TIME_SECONDS,
) )
from homeassistant.helpers import discovery
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.discovery import load_platform from homeassistant.helpers.discovery import load_platform
from homeassistant.util import slugify as util_slugify from homeassistant.util import slugify as util_slugify
@ -132,12 +130,6 @@ def setup(hass, config):
printers = hass.data[DOMAIN] = {} printers = hass.data[DOMAIN] = {}
success = False success = False
def device_discovered(service, info):
"""Get called when an Octoprint server has been discovered."""
_LOGGER.debug("Found an Octoprint server: %s", info)
discovery.listen(hass, SERVICE_OCTOPRINT, device_discovered)
if DOMAIN not in config: if DOMAIN not in config:
# Skip the setup if there is no configuration present # Skip the setup if there is no configuration present
return True return True

View File

@ -210,7 +210,6 @@ EVENT_HOMEASSISTANT_STARTED = "homeassistant_started"
EVENT_HOMEASSISTANT_STOP = "homeassistant_stop" EVENT_HOMEASSISTANT_STOP = "homeassistant_stop"
EVENT_HOMEASSISTANT_FINAL_WRITE = "homeassistant_final_write" EVENT_HOMEASSISTANT_FINAL_WRITE = "homeassistant_final_write"
EVENT_LOGBOOK_ENTRY = "logbook_entry" EVENT_LOGBOOK_ENTRY = "logbook_entry"
EVENT_PLATFORM_DISCOVERED = "platform_discovered"
EVENT_SERVICE_REGISTERED = "service_registered" EVENT_SERVICE_REGISTERED = "service_registered"
EVENT_SERVICE_REMOVED = "service_removed" EVENT_SERVICE_REMOVED = "service_removed"
EVENT_STATE_CHANGED = "state_changed" EVENT_STATE_CHANGED = "state_changed"
@ -313,9 +312,6 @@ CONF_UNIT_SYSTEM_IMPERIAL: str = "imperial"
# Electrical attributes # Electrical attributes
ATTR_VOLTAGE = "voltage" ATTR_VOLTAGE = "voltage"
# Contains the information that is discovered
ATTR_DISCOVERED = "discovered"
# Location of the device/sensor # Location of the device/sensor
ATTR_LOCATION = "location" ATTR_LOCATION = "location"

View File

@ -5,62 +5,55 @@ There are two different types of discoveries that can be fired/listened for.
- listen_platform/discover_platform is for platforms. These are used by - listen_platform/discover_platform is for platforms. These are used by
components to allow discovery of their platforms. components to allow discovery of their platforms.
""" """
from typing import Any, Callable, Collection, Dict, Optional, Union from typing import Any, Callable, Dict, Optional, TypedDict
from homeassistant import core, setup from homeassistant import core, setup
from homeassistant.const import ATTR_DISCOVERED, ATTR_SERVICE, EVENT_PLATFORM_DISCOVERED
from homeassistant.core import CALLBACK_TYPE from homeassistant.core import CALLBACK_TYPE
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util.async_ import run_callback_threadsafe
from .dispatcher import async_dispatcher_connect, async_dispatcher_send
from .typing import ConfigType, DiscoveryInfoType
SIGNAL_PLATFORM_DISCOVERED = "discovery.platform_discovered_{}"
EVENT_LOAD_PLATFORM = "load_platform.{}" EVENT_LOAD_PLATFORM = "load_platform.{}"
ATTR_PLATFORM = "platform" ATTR_PLATFORM = "platform"
ATTR_DISCOVERED = "discovered"
# mypy: disallow-any-generics # mypy: disallow-any-generics
@bind_hass class DiscoveryDict(TypedDict):
def listen( """Discovery data."""
hass: core.HomeAssistant,
service: Union[str, Collection[str]],
callback: CALLBACK_TYPE,
) -> None:
"""Set up listener for discovery of specific service.
Service can be a string or a list/tuple. service: str
""" platform: Optional[str]
run_callback_threadsafe(hass.loop, async_listen, hass, service, callback).result() discovered: Optional[DiscoveryInfoType]
@core.callback @core.callback
@bind_hass @bind_hass
def async_listen( def async_listen(
hass: core.HomeAssistant, hass: core.HomeAssistant,
service: Union[str, Collection[str]], service: str,
callback: CALLBACK_TYPE, callback: CALLBACK_TYPE,
) -> None: ) -> None:
"""Set up listener for discovery of specific service. """Set up listener for discovery of specific service.
Service can be a string or a list/tuple. Service can be a string or a list/tuple.
""" """
if isinstance(service, str):
service = (service,)
else:
service = tuple(service)
job = core.HassJob(callback) job = core.HassJob(callback)
async def discovery_event_listener(event: core.Event) -> None: async def discovery_event_listener(discovered: DiscoveryDict) -> None:
"""Listen for discovery events.""" """Listen for discovery events."""
if ATTR_SERVICE in event.data and event.data[ATTR_SERVICE] in service: task = hass.async_run_hass_job(
task = hass.async_run_hass_job( job, discovered["service"], discovered["discovered"]
job, event.data[ATTR_SERVICE], event.data.get(ATTR_DISCOVERED) )
) if task:
if task: await task
await task
hass.bus.async_listen(EVENT_PLATFORM_DISCOVERED, discovery_event_listener) async_dispatcher_connect(
hass, SIGNAL_PLATFORM_DISCOVERED.format(service), discovery_event_listener
)
@bind_hass @bind_hass
@ -91,22 +84,13 @@ async def async_discover(
if component is not None and component not in hass.config.components: if component is not None and component not in hass.config.components:
await setup.async_setup_component(hass, component, hass_config) await setup.async_setup_component(hass, component, hass_config)
data: Dict[str, Any] = {ATTR_SERVICE: service} data: DiscoveryDict = {
"service": service,
"platform": None,
"discovered": discovered,
}
if discovered is not None: async_dispatcher_send(hass, SIGNAL_PLATFORM_DISCOVERED.format(service), data)
data[ATTR_DISCOVERED] = discovered
hass.bus.async_fire(EVENT_PLATFORM_DISCOVERED, data)
@bind_hass
def listen_platform(
hass: core.HomeAssistant, component: str, callback: CALLBACK_TYPE
) -> None:
"""Register a platform loader listener."""
run_callback_threadsafe(
hass.loop, async_listen_platform, hass, component, callback
).result()
@bind_hass @bind_hass
@ -122,21 +106,20 @@ def async_listen_platform(
service = EVENT_LOAD_PLATFORM.format(component) service = EVENT_LOAD_PLATFORM.format(component)
job = core.HassJob(callback) job = core.HassJob(callback)
async def discovery_platform_listener(event: core.Event) -> None: async def discovery_platform_listener(discovered: DiscoveryDict) -> None:
"""Listen for platform discovery events.""" """Listen for platform discovery events."""
if event.data.get(ATTR_SERVICE) != service: platform = discovered["platform"]
return
platform = event.data.get(ATTR_PLATFORM)
if not platform: if not platform:
return return
task = hass.async_run_hass_job(job, platform, event.data.get(ATTR_DISCOVERED)) task = hass.async_run_hass_job(job, platform, discovered.get("discovered"))
if task: if task:
await task await task
hass.bus.async_listen(EVENT_PLATFORM_DISCOVERED, discovery_platform_listener) async_dispatcher_connect(
hass, SIGNAL_PLATFORM_DISCOVERED.format(service), discovery_platform_listener
)
@bind_hass @bind_hass
@ -147,16 +130,7 @@ def load_platform(
discovered: DiscoveryInfoType, discovered: DiscoveryInfoType,
hass_config: ConfigType, hass_config: ConfigType,
) -> None: ) -> None:
"""Load a component and platform dynamically. """Load a component and platform dynamically."""
Target components will be loaded and an EVENT_PLATFORM_DISCOVERED will be
fired to load the platform. The event will contain:
{ ATTR_SERVICE = EVENT_LOAD_PLATFORM + '.' + <<component>>
ATTR_PLATFORM = <<platform>>
ATTR_DISCOVERED = <<discovery info>> }
Use `listen_platform` to register a callback for these events.
"""
hass.add_job( hass.add_job(
async_load_platform( # type: ignore async_load_platform( # type: ignore
hass, component, platform, discovered, hass_config hass, component, platform, discovered, hass_config
@ -174,18 +148,10 @@ async def async_load_platform(
) -> None: ) -> None:
"""Load a component and platform dynamically. """Load a component and platform dynamically.
Target components will be loaded and an EVENT_PLATFORM_DISCOVERED will be Use `async_listen_platform` to register a callback for these events.
fired to load the platform. The event will contain:
{ ATTR_SERVICE = EVENT_LOAD_PLATFORM + '.' + <<component>>
ATTR_PLATFORM = <<platform>>
ATTR_DISCOVERED = <<discovery info>> }
Use `listen_platform` to register a callback for these events.
Warning: Do not await this inside a setup method to avoid a dead lock. Warning: Do not await this inside a setup method to avoid a dead lock.
Use `hass.async_create_task(async_load_platform(..))` instead. Use `hass.async_create_task(async_load_platform(..))` instead.
This method is a coroutine.
""" """
assert hass_config, "You need to pass in the real hass config" assert hass_config, "You need to pass in the real hass config"
@ -194,16 +160,16 @@ async def async_load_platform(
if component not in hass.config.components: if component not in hass.config.components:
setup_success = await setup.async_setup_component(hass, component, hass_config) setup_success = await setup.async_setup_component(hass, component, hass_config)
# No need to fire event if we could not set up component # No need to send signal if we could not set up component
if not setup_success: if not setup_success:
return return
data: Dict[str, Any] = { service = EVENT_LOAD_PLATFORM.format(component)
ATTR_SERVICE: EVENT_LOAD_PLATFORM.format(component),
ATTR_PLATFORM: platform, data: DiscoveryDict = {
"service": service,
"platform": platform,
"discovered": discovered,
} }
if discovered is not None: async_dispatcher_send(hass, SIGNAL_PLATFORM_DISCOVERED.format(service), data)
data[ATTR_DISCOVERED] = discovered
hass.bus.async_fire(EVENT_PLATFORM_DISCOVERED, data)

View File

@ -36,11 +36,8 @@ from homeassistant.components.device_automation import ( # noqa: F401
from homeassistant.components.mqtt.models import Message from homeassistant.components.mqtt.models import Message
from homeassistant.config import async_process_component_config from homeassistant.config import async_process_component_config
from homeassistant.const import ( from homeassistant.const import (
ATTR_DISCOVERED,
ATTR_SERVICE,
DEVICE_DEFAULT_NAME, DEVICE_DEFAULT_NAME,
EVENT_HOMEASSISTANT_CLOSE, EVENT_HOMEASSISTANT_CLOSE,
EVENT_PLATFORM_DISCOVERED,
EVENT_STATE_CHANGED, EVENT_STATE_CHANGED,
EVENT_TIME_CHANGED, EVENT_TIME_CHANGED,
STATE_OFF, STATE_OFF,
@ -387,21 +384,6 @@ def async_fire_time_changed(hass, datetime_, fire_all=False):
fire_time_changed = threadsafe_callback_factory(async_fire_time_changed) fire_time_changed = threadsafe_callback_factory(async_fire_time_changed)
def fire_service_discovered(hass, service, info):
"""Fire the MQTT message."""
hass.bus.fire(
EVENT_PLATFORM_DISCOVERED, {ATTR_SERVICE: service, ATTR_DISCOVERED: info}
)
@ha.callback
def async_fire_service_discovered(hass, service, info):
"""Fire the MQTT message."""
hass.bus.async_fire(
EVENT_PLATFORM_DISCOVERED, {ATTR_SERVICE: service, ATTR_DISCOVERED: info}
)
def load_fixture(filename): def load_fixture(filename):
"""Load a fixture.""" """Load a fixture."""
path = os.path.join(os.path.dirname(__file__), "fixtures", filename) path = os.path.join(os.path.dirname(__file__), "fixtures", filename)

View File

@ -4,6 +4,8 @@ from unittest.mock import patch
from homeassistant import setup from homeassistant import setup
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers import discovery from homeassistant.helpers import discovery
from homeassistant.helpers.dispatcher import dispatcher_send
from homeassistant.util.async_ import run_callback_threadsafe
from tests.common import ( from tests.common import (
MockModule, MockModule,
@ -31,23 +33,22 @@ class TestHelpersDiscovery:
"""Test discovery listen/discover combo.""" """Test discovery listen/discover combo."""
helpers = self.hass.helpers helpers = self.hass.helpers
calls_single = [] calls_single = []
calls_multi = []
@callback @callback
def callback_single(service, info): def callback_single(service, info):
"""Service discovered callback.""" """Service discovered callback."""
calls_single.append((service, info)) calls_single.append((service, info))
@callback self.hass.add_job(
def callback_multi(service, info): helpers.discovery.async_listen, "test service", callback_single
"""Service discovered callback.""" )
calls_multi.append((service, info))
helpers.discovery.listen("test service", callback_single) self.hass.add_job(
helpers.discovery.listen(["test service", "another service"], callback_multi) helpers.discovery.async_discover,
"test service",
helpers.discovery.discover( "discovery info",
"test service", "discovery info", "test_component", {} "test_component",
{},
) )
self.hass.block_till_done() self.hass.block_till_done()
@ -56,15 +57,6 @@ class TestHelpersDiscovery:
assert len(calls_single) == 1 assert len(calls_single) == 1
assert calls_single[0] == ("test service", "discovery info") assert calls_single[0] == ("test service", "discovery info")
helpers.discovery.discover(
"another service", "discovery info", "test_component", {}
)
self.hass.block_till_done()
assert len(calls_single) == 1
assert len(calls_multi) == 2
assert ["test service", "another service"] == [info[0] for info in calls_multi]
@patch("homeassistant.setup.async_setup_component", return_value=mock_coro(True)) @patch("homeassistant.setup.async_setup_component", return_value=mock_coro(True))
def test_platform(self, mock_setup_component): def test_platform(self, mock_setup_component):
"""Test discover platform method.""" """Test discover platform method."""
@ -75,7 +67,13 @@ class TestHelpersDiscovery:
"""Platform callback method.""" """Platform callback method."""
calls.append((platform, info)) calls.append((platform, info))
discovery.listen_platform(self.hass, "test_component", platform_callback) run_callback_threadsafe(
self.hass.loop,
discovery.async_listen_platform,
self.hass,
"test_component",
platform_callback,
).result()
discovery.load_platform( discovery.load_platform(
self.hass, self.hass,
@ -105,13 +103,10 @@ class TestHelpersDiscovery:
assert len(calls) == 1 assert len(calls) == 1
assert calls[0] == ("test_platform", "discovery info") assert calls[0] == ("test_platform", "discovery info")
self.hass.bus.fire( dispatcher_send(
discovery.EVENT_PLATFORM_DISCOVERED, self.hass,
{ discovery.SIGNAL_PLATFORM_DISCOVERED,
discovery.ATTR_SERVICE: discovery.EVENT_LOAD_PLATFORM.format( {"service": discovery.EVENT_LOAD_PLATFORM.format("test_component")},
"test_component"
)
},
) )
self.hass.block_till_done() self.hass.block_till_done()
@ -179,10 +174,12 @@ class TestHelpersDiscovery:
""" """
component_calls = [] component_calls = []
def component1_setup(hass, config): async def component1_setup(hass, config):
"""Set up mock component.""" """Set up mock component."""
print("component1 setup") print("component1 setup")
discovery.discover(hass, "test_component2", {}, "test_component2", {}) await discovery.async_discover(
hass, "test_component2", {}, "test_component2", {}
)
return True return True
def component2_setup(hass, config): def component2_setup(hass, config):
@ -191,7 +188,7 @@ class TestHelpersDiscovery:
return True return True
mock_integration( mock_integration(
self.hass, MockModule("test_component1", setup=component1_setup) self.hass, MockModule("test_component1", async_setup=component1_setup)
) )
mock_integration( mock_integration(

View File

@ -441,10 +441,14 @@ class TestSetup:
"""Test all init work done till start.""" """Test all init work done till start."""
call_order = [] call_order = []
def component1_setup(hass, config): async def component1_setup(hass, config):
"""Set up mock component.""" """Set up mock component."""
discovery.discover(hass, "test_component2", {}, "test_component2", {}) await discovery.async_discover(
discovery.discover(hass, "test_component3", {}, "test_component3", {}) hass, "test_component2", {}, "test_component2", {}
)
await discovery.async_discover(
hass, "test_component3", {}, "test_component3", {}
)
return True return True
def component_track_setup(hass, config): def component_track_setup(hass, config):
@ -453,7 +457,7 @@ class TestSetup:
return True return True
mock_integration( mock_integration(
self.hass, MockModule("test_component1", setup=component1_setup) self.hass, MockModule("test_component1", async_setup=component1_setup)
) )
mock_integration( mock_integration(