diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py index ca7777da959..c67ea37447d 100644 --- a/pylint/plugins/hass_enforce_type_hints.py +++ b/pylint/plugins/hass_enforce_type_hints.py @@ -50,6 +50,9 @@ class TypeHintMatch: kwargs_type: str | None = None """kwargs_type is for the special case `**kwargs`""" has_async_counterpart: bool = False + """Used to check both `function_name` and `async_function_name` functions""" + compulsory: bool = False + """Used to bypass ignore_missing_annotations""" def need_to_check_function(self, node: nodes.FunctionDef) -> bool: """Confirm if function should be checked.""" @@ -184,6 +187,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { }, return_type="bool", has_async_counterpart=True, + compulsory=True, ), TypeHintMatch( function_name="async_setup_entry", @@ -192,6 +196,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { 1: "ConfigEntry", }, return_type="bool", + compulsory=True, ), TypeHintMatch( function_name="async_remove_entry", @@ -200,6 +205,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { 1: "ConfigEntry", }, return_type=None, + compulsory=True, ), TypeHintMatch( function_name="async_unload_entry", @@ -208,6 +214,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { 1: "ConfigEntry", }, return_type="bool", + compulsory=True, ), TypeHintMatch( function_name="async_migrate_entry", @@ -216,6 +223,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { 1: "ConfigEntry", }, return_type="bool", + compulsory=True, ), TypeHintMatch( function_name="async_remove_config_entry_device", @@ -225,6 +233,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { 2: "DeviceEntry", }, return_type="bool", + compulsory=True, ), TypeHintMatch( function_name="async_reset_platform", @@ -233,6 +242,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { 1: "str", }, return_type=None, + compulsory=True, ), ], "__any_platform__": [ @@ -246,6 +256,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { }, return_type=None, has_async_counterpart=True, + compulsory=True, ), TypeHintMatch( function_name="async_setup_entry", @@ -255,6 +266,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { 2: "AddConfigEntryEntitiesCallback", }, return_type=None, + compulsory=True, ), ], "application_credentials": [ @@ -3185,8 +3197,11 @@ class HassTypeHintChecker(BaseChecker): self._class_matchers.reverse() - def _ignore_function( - self, node: nodes.FunctionDef, annotations: list[nodes.NodeNG | None] + def _ignore_function_match( + self, + node: nodes.FunctionDef, + annotations: list[nodes.NodeNG | None], + match: TypeHintMatch, ) -> bool: """Check if we can skip the function validation.""" return ( @@ -3194,6 +3209,8 @@ class HassTypeHintChecker(BaseChecker): not self._in_test_module # some modules have checks forced and self._module_platform not in _FORCE_ANNOTATION_PLATFORMS + # some matches have checks forced + and not match.compulsory # other modules are only checked ignore_missing_annotations and self.linter.config.ignore_missing_annotations and node.returns is None @@ -3236,7 +3253,7 @@ class HassTypeHintChecker(BaseChecker): continue annotations = _get_all_annotations(function_node) - if self._ignore_function(function_node, annotations): + if self._ignore_function_match(function_node, annotations, match): continue self._check_function(function_node, match, annotations) @@ -3245,8 +3262,6 @@ class HassTypeHintChecker(BaseChecker): def visit_functiondef(self, node: nodes.FunctionDef) -> None: """Apply relevant type hint checks on a FunctionDef node.""" annotations = _get_all_annotations(node) - if self._ignore_function(node, annotations): - return # Check method or function matchers. if node.is_method(): @@ -3267,14 +3282,15 @@ class HassTypeHintChecker(BaseChecker): matchers = self._function_matchers # Check that common arguments are correctly typed. - for arg_name, expected_type in _COMMON_ARGUMENTS.items(): - arg_node, annotation = _get_named_annotation(node, arg_name) - if arg_node and not _is_valid_type(expected_type, annotation): - self.add_message( - "hass-argument-type", - node=arg_node, - args=(arg_name, expected_type, node.name), - ) + if not self.linter.config.ignore_missing_annotations: + for arg_name, expected_type in _COMMON_ARGUMENTS.items(): + arg_node, annotation = _get_named_annotation(node, arg_name) + if arg_node and not _is_valid_type(expected_type, annotation): + self.add_message( + "hass-argument-type", + node=arg_node, + args=(arg_name, expected_type, node.name), + ) for match in matchers: if not match.need_to_check_function(node): @@ -3289,6 +3305,8 @@ class HassTypeHintChecker(BaseChecker): match: TypeHintMatch, annotations: list[nodes.NodeNG | None], ) -> None: + if self._ignore_function_match(node, annotations, match): + return # Check that all positional arguments are correctly annotated. if match.arg_types: for key, expected_type in match.arg_types.items(): diff --git a/tests/pylint/test_enforce_type_hints.py b/tests/pylint/test_enforce_type_hints.py index efa3ca9523a..3ceaf4a7e99 100644 --- a/tests/pylint/test_enforce_type_hints.py +++ b/tests/pylint/test_enforce_type_hints.py @@ -98,7 +98,7 @@ def test_regex_a_or_b( "code", [ """ - async def setup( #@ + async def async_turn_on( #@ arg1, arg2 ): pass @@ -114,7 +114,7 @@ def test_ignore_no_annotations( func_node = astroid.extract_node( code, - "homeassistant.components.pylint_test", + "homeassistant.components.pylint_test.light", ) type_hint_checker.visit_module(func_node.parent)