diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py index 2418b97c198..717534626af 100644 --- a/pylint/plugins/hass_enforce_type_hints.py +++ b/pylint/plugins/hass_enforce_type_hints.py @@ -4,13 +4,20 @@ from __future__ import annotations from dataclasses import dataclass from enum import Enum import re +from typing import TYPE_CHECKING from astroid import nodes +from astroid.exceptions import NameInferenceError from pylint.checkers import BaseChecker from pylint.lint import PyLinter from homeassistant.const import Platform +if TYPE_CHECKING: + # InferenceResult is available only from astroid >= 2.12.0 + # pre-commit should still work on out of date environments + from astroid.typing import InferenceResult + class _Special(Enum): """Sentinel values""" @@ -387,7 +394,8 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { 1: "ConfigType", 2: "DiscoveryInfoType | None", }, - return_type=_Special.UNDEFINED, + return_type=["BaseNotificationService", None], + check_return_type_inheritance=True, has_async_counterpart=True, ), ], @@ -2534,21 +2542,39 @@ def _is_valid_return_type(match: TypeHintMatch, node: nodes.NodeNG) -> bool: if ( match.check_return_type_inheritance - and isinstance(match.return_type, str) + and isinstance(match.return_type, (str, list)) and isinstance(node, nodes.Name) ): - ancestor: nodes.ClassDef - for infer_node in node.infer(): - if isinstance(infer_node, nodes.ClassDef): - if infer_node.name == match.return_type: + if isinstance(match.return_type, str): + valid_types = {match.return_type} + else: + valid_types = {el for el in match.return_type if isinstance(el, str)} + + try: + for infer_node in node.infer(): + if _check_ancestry(infer_node, valid_types): return True - for ancestor in infer_node.ancestors(): - if ancestor.name == match.return_type: + except NameInferenceError: + for class_node in node.root().nodes_of_class(nodes.ClassDef): + if class_node.name != node.name: + continue + for infer_node in class_node.infer(): + if _check_ancestry(infer_node, valid_types): return True return False +def _check_ancestry(infer_node: InferenceResult, valid_types: set[str]) -> bool: + if isinstance(infer_node, nodes.ClassDef): + if infer_node.name in valid_types: + return True + for ancestor in infer_node.ancestors(): + if ancestor.name in valid_types: + return True + return False + + def _get_all_annotations(node: nodes.FunctionDef) -> list[nodes.NodeNG | None]: args = node.args annotations: list[nodes.NodeNG | None] = ( diff --git a/tests/pylint/test_enforce_type_hints.py b/tests/pylint/test_enforce_type_hints.py index ebea738edc4..9abd2c89a74 100644 --- a/tests/pylint/test_enforce_type_hints.py +++ b/tests/pylint/test_enforce_type_hints.py @@ -1011,3 +1011,32 @@ def test_vacuum_entity(linter: UnittestLinter, type_hint_checker: BaseChecker) - with assert_no_messages(linter): type_hint_checker.visit_classdef(class_node) + + +def test_notify_get_service( + linter: UnittestLinter, type_hint_checker: BaseChecker +) -> None: + """Ensure valid hints are accepted for async_get_service.""" + func_node = astroid.extract_node( + """ + class BaseNotificationService(): + pass + + async def async_get_service( #@ + hass: HomeAssistant, + config: ConfigType, + discovery_info: DiscoveryInfoType | None = None, + ) -> CustomNotificationService: + pass + + class CustomNotificationService(BaseNotificationService): + pass + """, + "homeassistant.components.pylint_test.notify", + ) + type_hint_checker.visit_module(func_node.parent) + + with assert_no_messages( + linter, + ): + type_hint_checker.visit_asyncfunctiondef(func_node)