diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py index f3dd7a106c6..31b396c3196 100644 --- a/pylint/plugins/hass_enforce_type_hints.py +++ b/pylint/plugins/hass_enforce_type_hints.py @@ -25,6 +25,14 @@ class TypeHintMatch: return_type: list[str] | str | None | object +@dataclass +class ClassTypeHintMatch: + """Class for pattern matching.""" + + base_class: str + matches: list[TypeHintMatch] + + _TYPE_HINT_MATCHERS: dict[str, re.Pattern[str]] = { # a_or_b matches items such as "DiscoveryInfoType | None" "a_or_b": re.compile(r"^(\w+) \| (\w+)$"), @@ -368,6 +376,65 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { ], } +_CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = { + "config_flow": [ + ClassTypeHintMatch( + base_class="ConfigFlow", + matches=[ + TypeHintMatch( + function_name="async_step_dhcp", + arg_types={ + 1: "DhcpServiceInfo", + }, + return_type="FlowResult", + ), + TypeHintMatch( + function_name="async_step_hassio", + arg_types={ + 1: "HassioServiceInfo", + }, + return_type="FlowResult", + ), + TypeHintMatch( + function_name="async_step_homekit", + arg_types={ + 1: "ZeroconfServiceInfo", + }, + return_type="FlowResult", + ), + TypeHintMatch( + function_name="async_step_mqtt", + arg_types={ + 1: "MqttServiceInfo", + }, + return_type="FlowResult", + ), + TypeHintMatch( + function_name="async_step_ssdp", + arg_types={ + 1: "SsdpServiceInfo", + }, + return_type="FlowResult", + ), + TypeHintMatch( + function_name="async_step_usb", + arg_types={ + 1: "UsbServiceInfo", + }, + return_type="FlowResult", + ), + TypeHintMatch( + function_name="async_step_zeroconf", + arg_types={ + 1: "ZeroconfServiceInfo", + }, + return_type="FlowResult", + ), + ], + ), + ] +} + def _is_valid_type( expected_type: list[str] | str | None | object, node: astroid.NodeNG @@ -494,10 +561,12 @@ class HassTypeHintChecker(BaseChecker): # type: ignore[misc] def __init__(self, linter: PyLinter | None = None) -> None: super().__init__(linter) self._function_matchers: list[TypeHintMatch] = [] + self._class_matchers: list[ClassTypeHintMatch] = [] def visit_module(self, node: astroid.Module) -> None: """Called when a Module node is visited.""" self._function_matchers = [] + self._class_matchers = [] if (module_platform := _get_module_platform(node.name)) is None: return @@ -505,8 +574,28 @@ class HassTypeHintChecker(BaseChecker): # type: ignore[misc] if module_platform in _PLATFORMS: self._function_matchers.extend(_FUNCTION_MATCH["__any_platform__"]) - if matches := _FUNCTION_MATCH.get(module_platform): - self._function_matchers.extend(matches) + if function_matches := _FUNCTION_MATCH.get(module_platform): + self._function_matchers.extend(function_matches) + + if class_matches := _CLASS_MATCH.get(module_platform): + self._class_matchers = class_matches + + def visit_classdef(self, node: astroid.ClassDef) -> None: + """Called when a ClassDef node is visited.""" + ancestor: astroid.ClassDef + for ancestor in node.ancestors(): + for class_matches in self._class_matchers: + if ancestor.name == class_matches.base_class: + self._visit_class_functions(node, class_matches.matches) + + def _visit_class_functions( + self, node: astroid.ClassDef, matches: list[TypeHintMatch] + ) -> None: + for match in matches: + for function_node in node.mymethods(): + function_name: str | None = function_node.name + if match.function_name == function_name: + self._check_function(function_node, match) def visit_functiondef(self, node: astroid.FunctionDef) -> None: """Called when a FunctionDef node is visited.""" diff --git a/tests/pylint/test_enforce_type_hints.py b/tests/pylint/test_enforce_type_hints.py index 0bd273985e3..fa3d32dcb04 100644 --- a/tests/pylint/test_enforce_type_hints.py +++ b/tests/pylint/test_enforce_type_hints.py @@ -277,3 +277,75 @@ def test_valid_list_dict_str_any( with assert_no_messages(linter): type_hint_checker.visit_asyncfunctiondef(func_node) + + +def test_invalid_config_flow_step( + linter: UnittestLinter, type_hint_checker: BaseChecker +) -> None: + """Ensure invalid hints are rejected for ConfigFlow step.""" + class_node, func_node, arg_node = astroid.extract_node( + """ + class ConfigFlow(): + pass + + class AxisFlowHandler( #@ + ConfigFlow, domain=AXIS_DOMAIN + ): + async def async_step_zeroconf( #@ + self, + device_config: dict #@ + ): + pass + """, + "homeassistant.components.pylint_test.config_flow", + ) + type_hint_checker.visit_module(class_node.parent) + + with assert_adds_messages( + linter, + pylint.testutils.MessageTest( + msg_id="hass-argument-type", + node=arg_node, + args=(2, "ZeroconfServiceInfo"), + line=10, + col_offset=8, + end_line=10, + end_col_offset=27, + ), + pylint.testutils.MessageTest( + msg_id="hass-return-type", + node=func_node, + args="FlowResult", + line=8, + col_offset=4, + end_line=8, + end_col_offset=33, + ), + ): + type_hint_checker.visit_classdef(class_node) + + +def test_valid_config_flow_step( + linter: UnittestLinter, type_hint_checker: BaseChecker +) -> None: + """Ensure valid hints are accepted for ConfigFlow step.""" + class_node = astroid.extract_node( + """ + class ConfigFlow(): + pass + + class AxisFlowHandler( #@ + ConfigFlow, domain=AXIS_DOMAIN + ): + async def async_step_zeroconf( + self, + device_config: ZeroconfServiceInfo + ) -> FlowResult: + pass + """, + "homeassistant.components.pylint_test.config_flow", + ) + type_hint_checker.visit_module(class_node.parent) + + with assert_no_messages(linter): + type_hint_checker.visit_classdef(class_node)