From 75057949d1a0d84cf906591b622381dd0f5599c5 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Thu, 25 Nov 2021 02:30:02 +0100 Subject: [PATCH] Adjust async_step_discovery methods for BaseServiceInfo (#60285) Co-authored-by: epenet --- homeassistant/config_entries.py | 11 +++++----- tests/components/spotify/test_config_flow.py | 14 +++++++++++-- .../helpers/test_config_entry_oauth2_flow.py | 8 ++++++-- tests/test_config_entries.py | 20 +++++++++---------- 4 files changed, 34 insertions(+), 19 deletions(-) diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 9a6709820ff..9f1731e40d6 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio from collections.abc import Iterable, Mapping from contextvars import ContextVar +import dataclasses from enum import Enum import functools import logging @@ -1360,13 +1361,13 @@ class ConfigFlow(data_entry_flow.FlowHandler): self, discovery_info: ZeroconfServiceInfo ) -> data_entry_flow.FlowResult: """Handle a flow initialized by Homekit discovery.""" - return await self.async_step_discovery(cast(dict, discovery_info)) + return await self.async_step_discovery(dataclasses.asdict(discovery_info)) async def async_step_mqtt( self, discovery_info: MqttServiceInfo ) -> data_entry_flow.FlowResult: """Handle a flow initialized by MQTT discovery.""" - return await self.async_step_discovery(cast(dict, discovery_info)) + return await self.async_step_discovery(dataclasses.asdict(discovery_info)) async def async_step_ssdp( self, discovery_info: DiscoveryInfoType @@ -1378,19 +1379,19 @@ class ConfigFlow(data_entry_flow.FlowHandler): self, discovery_info: ZeroconfServiceInfo ) -> data_entry_flow.FlowResult: """Handle a flow initialized by Zeroconf discovery.""" - return await self.async_step_discovery(cast(dict, discovery_info)) + return await self.async_step_discovery(dataclasses.asdict(discovery_info)) async def async_step_dhcp( self, discovery_info: DhcpServiceInfo ) -> data_entry_flow.FlowResult: """Handle a flow initialized by DHCP discovery.""" - return await self.async_step_discovery(cast(dict, discovery_info)) + return await self.async_step_discovery(dataclasses.asdict(discovery_info)) async def async_step_usb( self, discovery_info: UsbServiceInfo ) -> data_entry_flow.FlowResult: """Handle a flow initialized by USB discovery.""" - return await self.async_step_discovery(cast(dict, discovery_info)) + return await self.async_step_discovery(dataclasses.asdict(discovery_info)) @callback def async_create_entry( # pylint: disable=arguments-differ diff --git a/tests/components/spotify/test_config_flow.py b/tests/components/spotify/test_config_flow.py index 8ff18882e8b..fb0279f9112 100644 --- a/tests/components/spotify/test_config_flow.py +++ b/tests/components/spotify/test_config_flow.py @@ -5,6 +5,7 @@ from unittest.mock import patch from spotipy import SpotifyException from homeassistant import data_entry_flow, setup +from homeassistant.components import zeroconf from homeassistant.components.spotify.const import DOMAIN from homeassistant.config_entries import SOURCE_REAUTH, SOURCE_USER, SOURCE_ZEROCONF from homeassistant.const import CONF_CLIENT_ID, CONF_CLIENT_SECRET @@ -12,6 +13,15 @@ from homeassistant.helpers import config_entry_oauth2_flow from tests.common import MockConfigEntry +BLANK_ZEROCONF_INFO = zeroconf.ZeroconfServiceInfo( + host="1.2.3.4", + hostname="mock_hostname", + name="mock_name", + port=None, + properties={}, + type="mock_type", +) + async def test_abort_if_no_configuration(hass): """Check flow aborts when no configuration is present.""" @@ -23,7 +33,7 @@ async def test_abort_if_no_configuration(hass): assert result["reason"] == "missing_configuration" result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_ZEROCONF} + DOMAIN, context={"source": SOURCE_ZEROCONF}, data=BLANK_ZEROCONF_INFO ) assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT @@ -35,7 +45,7 @@ async def test_zeroconf_abort_if_existing_entry(hass): MockConfigEntry(domain=DOMAIN).add_to_hass(hass) result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_ZEROCONF} + DOMAIN, context={"source": SOURCE_ZEROCONF}, data=BLANK_ZEROCONF_INFO ) assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT diff --git a/tests/helpers/test_config_entry_oauth2_flow.py b/tests/helpers/test_config_entry_oauth2_flow.py index 52dda703f1e..48868a3727b 100644 --- a/tests/helpers/test_config_entry_oauth2_flow.py +++ b/tests/helpers/test_config_entry_oauth2_flow.py @@ -220,7 +220,9 @@ async def test_step_discovery(hass, flow_handler, local_impl): ) result = await hass.config_entries.flow.async_init( - TEST_DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF} + TEST_DOMAIN, + context={"source": config_entries.SOURCE_ZEROCONF}, + data=data_entry_flow.BaseServiceInfo(), ) assert result["type"] == data_entry_flow.RESULT_TYPE_FORM @@ -242,7 +244,9 @@ async def test_abort_discovered_multiple(hass, flow_handler, local_impl): assert result["step_id"] == "pick_implementation" result = await hass.config_entries.flow.async_init( - TEST_DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF} + TEST_DOMAIN, + context={"source": config_entries.SOURCE_ZEROCONF}, + data=data_entry_flow.BaseServiceInfo(), ) assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 85d64de70a2..0b4d7fa8799 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -9,7 +9,7 @@ import pytest from homeassistant import config_entries, data_entry_flow, loader from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP from homeassistant.core import CoreState, callback -from homeassistant.data_entry_flow import RESULT_TYPE_ABORT +from homeassistant.data_entry_flow import RESULT_TYPE_ABORT, BaseServiceInfo from homeassistant.exceptions import ( ConfigEntryAuthFailed, ConfigEntryNotReady, @@ -2350,13 +2350,13 @@ async def test_async_setup_update_entry(hass): @pytest.mark.parametrize( "discovery_source", ( - config_entries.SOURCE_DISCOVERY, - config_entries.SOURCE_SSDP, - config_entries.SOURCE_USB, - config_entries.SOURCE_HOMEKIT, - config_entries.SOURCE_DHCP, - config_entries.SOURCE_ZEROCONF, - config_entries.SOURCE_HASSIO, + (config_entries.SOURCE_DISCOVERY, {}), + (config_entries.SOURCE_SSDP, {}), + (config_entries.SOURCE_USB, BaseServiceInfo()), + (config_entries.SOURCE_HOMEKIT, BaseServiceInfo()), + (config_entries.SOURCE_DHCP, BaseServiceInfo()), + (config_entries.SOURCE_ZEROCONF, BaseServiceInfo()), + (config_entries.SOURCE_HASSIO, {}), ), ) async def test_flow_with_default_discovery(hass, manager, discovery_source): @@ -2382,7 +2382,7 @@ async def test_flow_with_default_discovery(hass, manager, discovery_source): with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): # Create one to be in progress result = await manager.flow.async_init( - "comp", context={"source": discovery_source} + "comp", context={"source": discovery_source[0]}, data=discovery_source[1] ) assert result["type"] == data_entry_flow.RESULT_TYPE_FORM @@ -2403,7 +2403,7 @@ async def test_flow_with_default_discovery(hass, manager, discovery_source): entry = hass.config_entries.async_entries("comp")[0] assert entry.title == "yo" - assert entry.source == discovery_source + assert entry.source == discovery_source[0] assert entry.unique_id is None