diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py index ba0a511c571..8dbb041fa99 100644 --- a/pylint/plugins/hass_enforce_type_hints.py +++ b/pylint/plugins/hass_enforce_type_hints.py @@ -42,7 +42,6 @@ class TypeHintMatch: """named_arg_types is for named or keyword arguments""" kwargs_type: str | None = None """kwargs_type is for the special case `**kwargs`""" - check_return_type_inheritance: bool = False has_async_counterpart: bool = False def need_to_check_function(self, node: nodes.FunctionDef) -> bool: @@ -398,7 +397,6 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { 1: "ConfigType", }, return_type=["DeviceScanner", None], - check_return_type_inheritance=True, has_async_counterpart=True, ), ], @@ -466,7 +464,6 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = { 2: "DiscoveryInfoType | None", }, return_type=["BaseNotificationService", None], - check_return_type_inheritance=True, has_async_counterpart=True, ), ], @@ -493,7 +490,6 @@ _CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = { 0: "ConfigEntry", }, return_type="OptionsFlow", - check_return_type_inheritance=True, ), TypeHintMatch( function_name="async_step_dhcp", @@ -681,7 +677,6 @@ _RESTORE_ENTITY_MATCH: list[TypeHintMatch] = [ TypeHintMatch( function_name="extra_restore_state_data", return_type=["ExtraStoredData", None], - check_return_type_inheritance=True, ), ] _TOGGLE_ENTITY_MATCH: list[TypeHintMatch] = [ @@ -2842,15 +2837,13 @@ def _is_valid_return_type(match: TypeHintMatch, node: nodes.NodeNG) -> bool: match, node.right ) - if ( - match.check_return_type_inheritance - and isinstance(match.return_type, (str, list)) - and isinstance(node, nodes.Name) - ): + if isinstance(match.return_type, (str, list)) and isinstance(node, nodes.Name): if isinstance(match.return_type, str): valid_types = {match.return_type} else: valid_types = {el for el in match.return_type if isinstance(el, str)} + if "Mapping[str, Any]" in valid_types: + valid_types.add("TypedDict") try: for infer_node in node.infer(): diff --git a/tests/pylint/test_enforce_type_hints.py b/tests/pylint/test_enforce_type_hints.py index 365ccc111d0..c580658b542 100644 --- a/tests/pylint/test_enforce_type_hints.py +++ b/tests/pylint/test_enforce_type_hints.py @@ -724,6 +724,7 @@ def test_invalid_mapping_return_type( "-> Mapping[str, bool | int]", "-> dict[str, Any]", "-> dict[str, str]", + "-> CustomTypedDict", ], ) def test_valid_mapping_return_type( @@ -737,6 +738,11 @@ def test_valid_mapping_return_type( class_node = astroid.extract_node( f""" + from typing import TypedDict + + class CustomTypedDict(TypedDict): + pass + class Entity(): pass