diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py index a1bf260e968..307510c6621 100644 --- a/pylint/plugins/hass_enforce_type_hints.py +++ b/pylint/plugins/hass_enforce_type_hints.py @@ -21,10 +21,12 @@ class TypeHintMatch: function_name: str return_type: list[str] | str | None | object - # arg_types is for positional arguments arg_types: dict[int, str] | None = None - # kwarg_types is for the special case `**kwargs` + """arg_types is for positional arguments""" + named_arg_types: dict[str, str] | None = None + """named_arg_types is for named or keyword arguments""" kwargs_type: str | None = None + """kwargs_type is for the special case `**kwargs`""" check_return_type_inheritance: bool = False @@ -448,6 +450,111 @@ _CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = { # Overriding properties and functions are normally checked by mypy, and will only # be checked by pylint when --ignore-missing-annotations is False _INHERITANCE_MATCH: dict[str, list[ClassTypeHintMatch]] = { + "fan": [ + ClassTypeHintMatch( + base_class="FanEntity", + matches=[ + TypeHintMatch( + function_name="is_on", + return_type=["bool", None], + ), + TypeHintMatch( + function_name="percentage", + return_type=["int", None], + ), + TypeHintMatch( + function_name="speed_count", + return_type="int", + ), + TypeHintMatch( + function_name="percentage_step", + return_type="float", + ), + TypeHintMatch( + function_name="current_direction", + return_type=["str", None], + ), + TypeHintMatch( + function_name="oscillating", + return_type=["bool", None], + ), + TypeHintMatch( + function_name="capability_attributes", + return_type="dict[str]", + ), + TypeHintMatch( + function_name="supported_features", + return_type="int", + ), + TypeHintMatch( + function_name="preset_mode", + return_type=["str", None], + ), + TypeHintMatch( + function_name="preset_modes", + return_type=["list[str]", None], + ), + TypeHintMatch( + function_name="set_percentage", + arg_types={1: "int"}, + return_type=None, + ), + TypeHintMatch( + function_name="async_set_percentage", + arg_types={1: "int"}, + return_type=None, + ), + TypeHintMatch( + function_name="set_preset_mode", + arg_types={1: "str"}, + return_type=None, + ), + TypeHintMatch( + function_name="async_set_preset_mode", + arg_types={1: "str"}, + return_type=None, + ), + TypeHintMatch( + function_name="set_direction", + arg_types={1: "str"}, + return_type=None, + ), + TypeHintMatch( + function_name="async_set_direction", + arg_types={1: "str"}, + return_type=None, + ), + TypeHintMatch( + function_name="turn_on", + named_arg_types={ + "percentage": "int | None", + "preset_mode": "str | None", + }, + kwargs_type="Any", + return_type=None, + ), + TypeHintMatch( + function_name="async_turn_on", + named_arg_types={ + "percentage": "int | None", + "preset_mode": "str | None", + }, + kwargs_type="Any", + return_type=None, + ), + TypeHintMatch( + function_name="oscillate", + arg_types={1: "bool"}, + return_type=None, + ), + TypeHintMatch( + function_name="async_oscillate", + arg_types={1: "bool"}, + return_type=None, + ), + ], + ), + ], "lock": [ ClassTypeHintMatch( base_class="LockEntity", @@ -619,6 +726,21 @@ def _get_all_annotations(node: nodes.FunctionDef) -> list[nodes.NodeNG | None]: return annotations +def _get_named_annotation( + node: nodes.FunctionDef, key: str +) -> tuple[nodes.NodeNG, nodes.NodeNG] | tuple[None, None]: + args = node.args + for index, arg_node in enumerate(args.args): + if key == arg_node.name: + return arg_node, args.annotations[index] + + for index, arg_node in enumerate(args.kwonlyargs): + if key == arg_node.name: + return arg_node, args.kwonlyargs_annotations[index] + + return None, None + + def _has_valid_annotations( annotations: list[nodes.NodeNG | None], ) -> bool: @@ -742,6 +864,17 @@ class HassTypeHintChecker(BaseChecker): # type: ignore[misc] args=(key + 1, expected_type), ) + # Check that all keyword arguments are correctly annotated. + if match.named_arg_types is not None: + for arg_name, expected_type in match.named_arg_types.items(): + arg_node, annotation = _get_named_annotation(node, arg_name) + if arg_node and not _is_valid_type(expected_type, annotation): + self.add_message( + "hass-argument-type", + node=arg_node, + args=(arg_name, expected_type), + ) + # Check that kwargs is correctly annotated. if match.kwargs_type and not _is_valid_type( match.kwargs_type, node.args.kwargannotation diff --git a/tests/pylint/test_enforce_type_hints.py b/tests/pylint/test_enforce_type_hints.py index 262ff93afa8..54c7cf6ec4c 100644 --- a/tests/pylint/test_enforce_type_hints.py +++ b/tests/pylint/test_enforce_type_hints.py @@ -565,3 +565,73 @@ def test_ignore_invalid_entity_properties( with assert_no_messages(linter): type_hint_checker.visit_classdef(class_node) + + +def test_named_arguments( + linter: UnittestLinter, type_hint_checker: BaseChecker +) -> None: + """Check missing entity properties when ignore_missing_annotations is False.""" + # Set bypass option + type_hint_checker.config.ignore_missing_annotations = False + + class_node, func_node, percentage_node, preset_mode_node = astroid.extract_node( + """ + class FanEntity(): + pass + + class MyFan( #@ + FanEntity + ): + async def async_turn_on( #@ + self, + percentage, #@ + *, + preset_mode: str, #@ + **kwargs + ) -> bool: + pass + """, + "homeassistant.components.pylint_test.fan", + ) + type_hint_checker.visit_module(class_node.parent) + + with assert_adds_messages( + linter, + pylint.testutils.MessageTest( + msg_id="hass-argument-type", + node=percentage_node, + args=("percentage", "int | None"), + line=10, + col_offset=8, + end_line=10, + end_col_offset=18, + ), + pylint.testutils.MessageTest( + msg_id="hass-argument-type", + node=preset_mode_node, + args=("preset_mode", "str | None"), + line=12, + col_offset=8, + end_line=12, + end_col_offset=24, + ), + pylint.testutils.MessageTest( + msg_id="hass-argument-type", + node=func_node, + args=("kwargs", "Any"), + line=8, + col_offset=4, + end_line=8, + end_col_offset=27, + ), + pylint.testutils.MessageTest( + msg_id="hass-return-type", + node=func_node, + args="None", + line=8, + col_offset=4, + end_line=8, + end_col_offset=27, + ), + ): + type_hint_checker.visit_classdef(class_node)