From 2bb0843c309cbf01fb5821f8ede4d1ef2c2fad44 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Mon, 19 May 2025 10:27:07 +0200 Subject: [PATCH] Add ability to mark type hints as compulsory on specific functions (#139730) --- pylint/plugins/hass_enforce_type_hints.py | 44 ++++++++++++++++------- tests/pylint/test_enforce_type_hints.py | 4 +-- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py index 3e18aacaa93..9855f688622 100644 --- a/pylint/plugins/hass_enforce_type_hints.py +++ b/pylint/plugins/hass_enforce_type_hints.py @@ -50,6 +50,9 @@ class TypeHintMatch: kwargs_type: str | None = None """kwargs_type is for the special case `**kwargs`""" has_async_counterpart: bool = False + """`function_name` and `async_function_name` share arguments and return type""" + mandatory: bool = False + """bypass ignore_missing_annotations""" def need_to_check_function(self, node: nodes.FunctionDef) -> bool: """Confirm if function should be checked.""" @@ -184,6 +187,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { }, return_type="bool", has_async_counterpart=True, + mandatory=True, ), TypeHintMatch( function_name="async_setup_entry", @@ -192,6 +196,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { 1: "ConfigEntry", }, return_type="bool", + mandatory=True, ), TypeHintMatch( function_name="async_remove_entry", @@ -200,6 +205,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { 1: "ConfigEntry", }, return_type=None, + mandatory=True, ), TypeHintMatch( function_name="async_unload_entry", @@ -208,6 +214,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { 1: "ConfigEntry", }, return_type="bool", + mandatory=True, ), TypeHintMatch( function_name="async_migrate_entry", @@ -216,6 +223,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { 1: "ConfigEntry", }, return_type="bool", + mandatory=True, ), TypeHintMatch( function_name="async_remove_config_entry_device", @@ -225,6 +233,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { 2: "DeviceEntry", }, return_type="bool", + mandatory=True, ), TypeHintMatch( function_name="async_reset_platform", @@ -233,6 +242,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { 1: "str", }, return_type=None, + mandatory=True, ), ], "__any_platform__": [ @@ -246,6 +256,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { }, return_type=None, has_async_counterpart=True, + mandatory=True, ), TypeHintMatch( function_name="async_setup_entry", @@ -255,6 +266,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { 2: "AddConfigEntryEntitiesCallback", }, return_type=None, + mandatory=True, ), ], "application_credentials": [ @@ -3195,8 +3207,11 @@ class HassTypeHintChecker(BaseChecker): self._class_matchers.reverse() - def _ignore_function( - self, node: nodes.FunctionDef, annotations: list[nodes.NodeNG | None] + def _ignore_function_match( + self, + node: nodes.FunctionDef, + annotations: list[nodes.NodeNG | None], + match: TypeHintMatch, ) -> bool: """Check if we can skip the function validation.""" return ( @@ -3204,6 +3219,8 @@ class HassTypeHintChecker(BaseChecker): not self._in_test_module # some modules have checks forced and self._module_platform not in _FORCE_ANNOTATION_PLATFORMS + # some matches have checks forced + and not match.mandatory # other modules are only checked ignore_missing_annotations and self.linter.config.ignore_missing_annotations and node.returns is None @@ -3246,7 +3263,7 @@ class HassTypeHintChecker(BaseChecker): continue annotations = _get_all_annotations(function_node) - if self._ignore_function(function_node, annotations): + if self._ignore_function_match(function_node, annotations, match): continue self._check_function(function_node, match, annotations) @@ -3255,8 +3272,6 @@ class HassTypeHintChecker(BaseChecker): def visit_functiondef(self, node: nodes.FunctionDef) -> None: """Apply relevant type hint checks on a FunctionDef node.""" annotations = _get_all_annotations(node) - if self._ignore_function(node, annotations): - return # Check method or function matchers. if node.is_method(): @@ -3277,14 +3292,15 @@ class HassTypeHintChecker(BaseChecker): matchers = self._function_matchers # Check that common arguments are correctly typed. - for arg_name, expected_type in _COMMON_ARGUMENTS.items(): - arg_node, annotation = _get_named_annotation(node, arg_name) - if arg_node and not _is_valid_type(expected_type, annotation): - self.add_message( - "hass-argument-type", - node=arg_node, - args=(arg_name, expected_type, node.name), - ) + if not self.linter.config.ignore_missing_annotations: + for arg_name, expected_type in _COMMON_ARGUMENTS.items(): + arg_node, annotation = _get_named_annotation(node, arg_name) + if arg_node and not _is_valid_type(expected_type, annotation): + self.add_message( + "hass-argument-type", + node=arg_node, + args=(arg_name, expected_type, node.name), + ) for match in matchers: if not match.need_to_check_function(node): @@ -3299,6 +3315,8 @@ class HassTypeHintChecker(BaseChecker): match: TypeHintMatch, annotations: list[nodes.NodeNG | None], ) -> None: + if self._ignore_function_match(node, annotations, match): + return # Check that all positional arguments are correctly annotated. if match.arg_types: for key, expected_type in match.arg_types.items(): diff --git a/tests/pylint/test_enforce_type_hints.py b/tests/pylint/test_enforce_type_hints.py index c9748cc61f8..9179a545256 100644 --- a/tests/pylint/test_enforce_type_hints.py +++ b/tests/pylint/test_enforce_type_hints.py @@ -99,7 +99,7 @@ def test_regex_a_or_b( "code", [ """ - async def setup( #@ + async def async_turn_on( #@ arg1, arg2 ): pass @@ -115,7 +115,7 @@ def test_ignore_no_annotations( func_node = astroid.extract_node( code, - "homeassistant.components.pylint_test", + "homeassistant.components.pylint_test.light", ) type_hint_checker.visit_module(func_node.parent)