Adjust async_step_discovery methods for BaseServiceInfo (#60285)

Co-authored-by: epenet <epenet@users.noreply.github.com>
This commit is contained in:
epenet 2021-11-25 02:30:02 +01:00 committed by GitHub
parent 0920e74aa2
commit 75057949d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 34 additions and 19 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from contextvars import ContextVar from contextvars import ContextVar
import dataclasses
from enum import Enum from enum import Enum
import functools import functools
import logging import logging
@ -1360,13 +1361,13 @@ class ConfigFlow(data_entry_flow.FlowHandler):
self, discovery_info: ZeroconfServiceInfo self, discovery_info: ZeroconfServiceInfo
) -> data_entry_flow.FlowResult: ) -> data_entry_flow.FlowResult:
"""Handle a flow initialized by Homekit discovery.""" """Handle a flow initialized by Homekit discovery."""
return await self.async_step_discovery(cast(dict, discovery_info)) return await self.async_step_discovery(dataclasses.asdict(discovery_info))
async def async_step_mqtt( async def async_step_mqtt(
self, discovery_info: MqttServiceInfo self, discovery_info: MqttServiceInfo
) -> data_entry_flow.FlowResult: ) -> data_entry_flow.FlowResult:
"""Handle a flow initialized by MQTT discovery.""" """Handle a flow initialized by MQTT discovery."""
return await self.async_step_discovery(cast(dict, discovery_info)) return await self.async_step_discovery(dataclasses.asdict(discovery_info))
async def async_step_ssdp( async def async_step_ssdp(
self, discovery_info: DiscoveryInfoType self, discovery_info: DiscoveryInfoType
@ -1378,19 +1379,19 @@ class ConfigFlow(data_entry_flow.FlowHandler):
self, discovery_info: ZeroconfServiceInfo self, discovery_info: ZeroconfServiceInfo
) -> data_entry_flow.FlowResult: ) -> data_entry_flow.FlowResult:
"""Handle a flow initialized by Zeroconf discovery.""" """Handle a flow initialized by Zeroconf discovery."""
return await self.async_step_discovery(cast(dict, discovery_info)) return await self.async_step_discovery(dataclasses.asdict(discovery_info))
async def async_step_dhcp( async def async_step_dhcp(
self, discovery_info: DhcpServiceInfo self, discovery_info: DhcpServiceInfo
) -> data_entry_flow.FlowResult: ) -> data_entry_flow.FlowResult:
"""Handle a flow initialized by DHCP discovery.""" """Handle a flow initialized by DHCP discovery."""
return await self.async_step_discovery(cast(dict, discovery_info)) return await self.async_step_discovery(dataclasses.asdict(discovery_info))
async def async_step_usb( async def async_step_usb(
self, discovery_info: UsbServiceInfo self, discovery_info: UsbServiceInfo
) -> data_entry_flow.FlowResult: ) -> data_entry_flow.FlowResult:
"""Handle a flow initialized by USB discovery.""" """Handle a flow initialized by USB discovery."""
return await self.async_step_discovery(cast(dict, discovery_info)) return await self.async_step_discovery(dataclasses.asdict(discovery_info))
@callback @callback
def async_create_entry( # pylint: disable=arguments-differ def async_create_entry( # pylint: disable=arguments-differ

View File

@ -5,6 +5,7 @@ from unittest.mock import patch
from spotipy import SpotifyException from spotipy import SpotifyException
from homeassistant import data_entry_flow, setup from homeassistant import data_entry_flow, setup
from homeassistant.components import zeroconf
from homeassistant.components.spotify.const import DOMAIN from homeassistant.components.spotify.const import DOMAIN
from homeassistant.config_entries import SOURCE_REAUTH, SOURCE_USER, SOURCE_ZEROCONF from homeassistant.config_entries import SOURCE_REAUTH, SOURCE_USER, SOURCE_ZEROCONF
from homeassistant.const import CONF_CLIENT_ID, CONF_CLIENT_SECRET from homeassistant.const import CONF_CLIENT_ID, CONF_CLIENT_SECRET
@ -12,6 +13,15 @@ from homeassistant.helpers import config_entry_oauth2_flow
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
BLANK_ZEROCONF_INFO = zeroconf.ZeroconfServiceInfo(
host="1.2.3.4",
hostname="mock_hostname",
name="mock_name",
port=None,
properties={},
type="mock_type",
)
async def test_abort_if_no_configuration(hass): async def test_abort_if_no_configuration(hass):
"""Check flow aborts when no configuration is present.""" """Check flow aborts when no configuration is present."""
@ -23,7 +33,7 @@ async def test_abort_if_no_configuration(hass):
assert result["reason"] == "missing_configuration" assert result["reason"] == "missing_configuration"
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_ZEROCONF} DOMAIN, context={"source": SOURCE_ZEROCONF}, data=BLANK_ZEROCONF_INFO
) )
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
@ -35,7 +45,7 @@ async def test_zeroconf_abort_if_existing_entry(hass):
MockConfigEntry(domain=DOMAIN).add_to_hass(hass) MockConfigEntry(domain=DOMAIN).add_to_hass(hass)
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_ZEROCONF} DOMAIN, context={"source": SOURCE_ZEROCONF}, data=BLANK_ZEROCONF_INFO
) )
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT

View File

@ -220,7 +220,9 @@ async def test_step_discovery(hass, flow_handler, local_impl):
) )
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
TEST_DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF} TEST_DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=data_entry_flow.BaseServiceInfo(),
) )
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
@ -242,7 +244,9 @@ async def test_abort_discovered_multiple(hass, flow_handler, local_impl):
assert result["step_id"] == "pick_implementation" assert result["step_id"] == "pick_implementation"
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
TEST_DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF} TEST_DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=data_entry_flow.BaseServiceInfo(),
) )
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT

View File

@ -9,7 +9,7 @@ import pytest
from homeassistant import config_entries, data_entry_flow, loader from homeassistant import config_entries, data_entry_flow, loader
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import CoreState, callback from homeassistant.core import CoreState, callback
from homeassistant.data_entry_flow import RESULT_TYPE_ABORT from homeassistant.data_entry_flow import RESULT_TYPE_ABORT, BaseServiceInfo
from homeassistant.exceptions import ( from homeassistant.exceptions import (
ConfigEntryAuthFailed, ConfigEntryAuthFailed,
ConfigEntryNotReady, ConfigEntryNotReady,
@ -2350,13 +2350,13 @@ async def test_async_setup_update_entry(hass):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"discovery_source", "discovery_source",
( (
config_entries.SOURCE_DISCOVERY, (config_entries.SOURCE_DISCOVERY, {}),
config_entries.SOURCE_SSDP, (config_entries.SOURCE_SSDP, {}),
config_entries.SOURCE_USB, (config_entries.SOURCE_USB, BaseServiceInfo()),
config_entries.SOURCE_HOMEKIT, (config_entries.SOURCE_HOMEKIT, BaseServiceInfo()),
config_entries.SOURCE_DHCP, (config_entries.SOURCE_DHCP, BaseServiceInfo()),
config_entries.SOURCE_ZEROCONF, (config_entries.SOURCE_ZEROCONF, BaseServiceInfo()),
config_entries.SOURCE_HASSIO, (config_entries.SOURCE_HASSIO, {}),
), ),
) )
async def test_flow_with_default_discovery(hass, manager, discovery_source): async def test_flow_with_default_discovery(hass, manager, discovery_source):
@ -2382,7 +2382,7 @@ async def test_flow_with_default_discovery(hass, manager, discovery_source):
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
# Create one to be in progress # Create one to be in progress
result = await manager.flow.async_init( result = await manager.flow.async_init(
"comp", context={"source": discovery_source} "comp", context={"source": discovery_source[0]}, data=discovery_source[1]
) )
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
@ -2403,7 +2403,7 @@ async def test_flow_with_default_discovery(hass, manager, discovery_source):
entry = hass.config_entries.async_entries("comp")[0] entry = hass.config_entries.async_entries("comp")[0]
assert entry.title == "yo" assert entry.title == "yo"
assert entry.source == discovery_source assert entry.source == discovery_source[0]
assert entry.unique_id is None assert entry.unique_id is None