diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py index ca7777da959..4f9f7603328 100644 --- a/pylint/plugins/hass_enforce_type_hints.py +++ b/pylint/plugins/hass_enforce_type_hints.py @@ -597,6 +597,16 @@ _CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = { ), ], ), + ClassTypeHintMatch( + base_class="ConfigSubentryFlow", + matches=[ + TypeHintMatch( + function_name="async_step_*", + arg_types={}, + return_type="SubentryFlowResult", + ), + ], + ), ], } # Overriding properties and functions are normally checked by mypy, and will only diff --git a/tests/pylint/test_enforce_type_hints.py b/tests/pylint/test_enforce_type_hints.py index efa3ca9523a..c9748cc61f8 100644 --- a/tests/pylint/test_enforce_type_hints.py +++ b/tests/pylint/test_enforce_type_hints.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Callable import re from types import ModuleType from unittest.mock import patch @@ -375,12 +376,11 @@ def test_invalid_config_flow_step( type_hint_checker.visit_classdef(class_node) -def test_invalid_custom_config_flow_step( - linter: UnittestLinter, type_hint_checker: BaseChecker -) -> None: - """Ensure invalid hints are rejected for ConfigFlow step.""" - class_node, func_node, arg_node = astroid.extract_node( - """ +@pytest.mark.parametrize( + ("code", "expected_messages_fn"), + [ + ( + """ class FlowHandler(): pass @@ -392,34 +392,79 @@ def test_invalid_custom_config_flow_step( ): async def async_step_axis_specific( #@ self, - device_config: dict #@ + device_config: dict ): pass - """, +""", + lambda func_node: [ + pylint.testutils.MessageTest( + msg_id="hass-return-type", + node=func_node, + args=("ConfigFlowResult", "async_step_axis_specific"), + line=11, + col_offset=4, + end_line=11, + end_col_offset=38, + ), + ], + ), + ( + """ + class FlowHandler(): + pass + + class ConfigSubentryFlow(FlowHandler): + pass + + class CustomSubentryFlowHandler(ConfigSubentryFlow): #@ + async def async_step_user( #@ + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: + pass +""", + lambda func_node: [ + pylint.testutils.MessageTest( + msg_id="hass-return-type", + node=func_node, + args=("SubentryFlowResult", "async_step_user"), + line=9, + col_offset=4, + end_line=9, + end_col_offset=29, + ), + ], + ), + ], + ids=[ + "Config flow", + "Config subentry flow", + ], +) +def test_invalid_flow_step( + linter: UnittestLinter, + type_hint_checker: BaseChecker, + code: str, + expected_messages_fn: Callable[ + [astroid.NodeNG], tuple[pylint.testutils.MessageTest, ...] + ], +) -> None: + """Ensure invalid hints are rejected for flow step.""" + class_node, func_node = astroid.extract_node( + code, "homeassistant.components.pylint_test.config_flow", ) type_hint_checker.visit_module(class_node.parent) with assert_adds_messages( linter, - pylint.testutils.MessageTest( - msg_id="hass-return-type", - node=func_node, - args=("ConfigFlowResult", "async_step_axis_specific"), - line=11, - col_offset=4, - end_line=11, - end_col_offset=38, - ), + *expected_messages_fn(func_node), ): type_hint_checker.visit_classdef(class_node) -def test_valid_config_flow_step( - linter: UnittestLinter, type_hint_checker: BaseChecker -) -> None: - """Ensure valid hints are accepted for ConfigFlow step.""" - class_node = astroid.extract_node( +@pytest.mark.parametrize( + "code", + [ """ class FlowHandler(): pass @@ -436,6 +481,33 @@ def test_valid_config_flow_step( ) -> ConfigFlowResult: pass """, + """ + class FlowHandler(): + pass + + class ConfigSubentryFlow(FlowHandler): + pass + + class CustomSubentryFlowHandler(ConfigSubentryFlow): #@ + async def async_step_user( + self, user_input: dict[str, Any] | None = None + ) -> SubentryFlowResult: + pass +""", + ], + ids=[ + "Config flow", + "Config subentry flow", + ], +) +def test_valid_flow_step( + linter: UnittestLinter, + type_hint_checker: BaseChecker, + code: str, +) -> None: + """Ensure valid hints are accepted for flow step.""" + class_node = astroid.extract_node( + code, "homeassistant.components.pylint_test.config_flow", ) type_hint_checker.visit_module(class_node.parent)