diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py index d8d3c76a028..8194bb72ca5 100644 --- a/pylint/plugins/hass_enforce_type_hints.py +++ b/pylint/plugins/hass_enforce_type_hints.py @@ -22,6 +22,7 @@ class TypeHintMatch: function_name: str arg_types: dict[int, str] return_type: list[str] | str | None | object + check_return_type_inheritance: bool = False @dataclass @@ -380,6 +381,14 @@ _CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = { ClassTypeHintMatch( base_class="ConfigFlow", matches=[ + TypeHintMatch( + function_name="async_get_options_flow", + arg_types={ + 0: "ConfigEntry", + }, + return_type="OptionsFlow", + check_return_type_inheritance=True, + ), TypeHintMatch( function_name="async_step_dhcp", arg_types={ @@ -504,6 +513,32 @@ def _is_valid_type( return isinstance(node, nodes.Attribute) and node.attrname == expected_type +def _is_valid_return_type(match: TypeHintMatch, node: nodes.NodeNG) -> bool: + if _is_valid_type(match.return_type, node): + return True + + if isinstance(node, nodes.BinOp): + return _is_valid_return_type(match, node.left) and _is_valid_return_type( + match, node.right + ) + + if ( + match.check_return_type_inheritance + and isinstance(match.return_type, str) + and isinstance(node, nodes.Name) + ): + ancestor: nodes.ClassDef + for infer_node in node.infer(): + if isinstance(infer_node, nodes.ClassDef): + if infer_node.name == match.return_type: + return True + for ancestor in infer_node.ancestors(): + if ancestor.name == match.return_type: + return True + + return False + + def _get_all_annotations(node: nodes.FunctionDef) -> list[nodes.NodeNG | None]: args = node.args annotations: list[nodes.NodeNG | None] = ( @@ -619,8 +654,10 @@ class HassTypeHintChecker(BaseChecker): # type: ignore[misc] ) # Check the return type. - if not _is_valid_type(return_type := match.return_type, node.returns): - self.add_message("hass-return-type", node=node, args=return_type or "None") + if not _is_valid_return_type(match, node.returns): + self.add_message( + "hass-return-type", node=node, args=match.return_type or "None" + ) def register(linter: PyLinter) -> None: diff --git a/tests/pylint/test_enforce_type_hints.py b/tests/pylint/test_enforce_type_hints.py index fa3d32dcb04..86b06a894d0 100644 --- a/tests/pylint/test_enforce_type_hints.py +++ b/tests/pylint/test_enforce_type_hints.py @@ -349,3 +349,88 @@ def test_valid_config_flow_step( with assert_no_messages(linter): type_hint_checker.visit_classdef(class_node) + + +def test_invalid_config_flow_async_get_options_flow( + linter: UnittestLinter, type_hint_checker: BaseChecker +) -> None: + """Ensure invalid hints are rejected for ConfigFlow async_get_options_flow.""" + class_node, func_node, arg_node = astroid.extract_node( + """ + class ConfigFlow(): + pass + + class AxisOptionsFlow(): + pass + + class AxisFlowHandler( #@ + ConfigFlow, domain=AXIS_DOMAIN + ): + def async_get_options_flow( #@ + config_entry #@ + ) -> AxisOptionsFlow: + return AxisOptionsFlow(config_entry) + """, + "homeassistant.components.pylint_test.config_flow", + ) + type_hint_checker.visit_module(class_node.parent) + + with assert_adds_messages( + linter, + pylint.testutils.MessageTest( + msg_id="hass-argument-type", + node=arg_node, + args=(1, "ConfigEntry"), + line=12, + col_offset=8, + end_line=12, + end_col_offset=20, + ), + pylint.testutils.MessageTest( + msg_id="hass-return-type", + node=func_node, + args="OptionsFlow", + line=11, + col_offset=4, + end_line=11, + end_col_offset=30, + ), + ): + type_hint_checker.visit_classdef(class_node) + + +def test_valid_config_flow_async_get_options_flow( + linter: UnittestLinter, type_hint_checker: BaseChecker +) -> None: + """Ensure valid hints are accepted for ConfigFlow async_get_options_flow.""" + class_node = astroid.extract_node( + """ + class ConfigFlow(): + pass + + class OptionsFlow(): + pass + + class AxisOptionsFlow(OptionsFlow): + pass + + class OtherOptionsFlow(OptionsFlow): + pass + + class AxisFlowHandler( #@ + ConfigFlow, domain=AXIS_DOMAIN + ): + def async_get_options_flow( + config_entry: ConfigEntry + ) -> AxisOptionsFlow | OtherOptionsFlow | OptionsFlow: + if self.use_other: + return OtherOptionsFlow(config_entry) + return AxisOptionsFlow(config_entry) + + """, + "homeassistant.components.pylint_test.config_flow", + ) + type_hint_checker.visit_module(class_node.parent) + + with assert_no_messages(linter): + type_hint_checker.visit_classdef(class_node)