diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py index 9ec2fa83806..6555510ff58 100644 --- a/pylint/plugins/hass_enforce_type_hints.py +++ b/pylint/plugins/hass_enforce_type_hints.py @@ -435,6 +435,117 @@ _CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = { } # Overriding properties and functions are normally checked by mypy, and will only # be checked by pylint when --ignore-missing-annotations is False +_ENTITY_MATCH: list[TypeHintMatch] = [ + TypeHintMatch( + function_name="should_poll", + return_type="bool", + ), + TypeHintMatch( + function_name="unique_id", + return_type=["str", None], + ), + TypeHintMatch( + function_name="name", + return_type=["str", None], + ), + TypeHintMatch( + function_name="state", + return_type=["StateType", None, "str", "int", "float"], + ), + TypeHintMatch( + function_name="capability_attributes", + return_type=["Mapping[str, Any]", None], + ), + TypeHintMatch( + function_name="state_attributes", + return_type=["dict[str, Any]", None], + ), + TypeHintMatch( + function_name="device_state_attributes", + return_type=["Mapping[str, Any]", None], + ), + TypeHintMatch( + function_name="extra_state_attributes", + return_type=["Mapping[str, Any]", None], + ), + TypeHintMatch( + function_name="device_info", + return_type=["DeviceInfo", None], + ), + TypeHintMatch( + function_name="device_class", + return_type=["str", None], + ), + TypeHintMatch( + function_name="unit_of_measurement", + return_type=["str", None], + ), + TypeHintMatch( + function_name="icon", + return_type=["str", None], + ), + TypeHintMatch( + function_name="entity_picture", + return_type=["str", None], + ), + TypeHintMatch( + function_name="available", + return_type="bool", + ), + TypeHintMatch( + function_name="assumed_state", + return_type="bool", + ), + TypeHintMatch( + function_name="force_update", + return_type="bool", + ), + TypeHintMatch( + function_name="supported_features", + return_type=["int", None], + ), + TypeHintMatch( + function_name="context_recent_time", + return_type="timedelta", + ), + TypeHintMatch( + function_name="entity_registry_enabled_default", + return_type="bool", + ), + TypeHintMatch( + function_name="entity_registry_visible_default", + return_type="bool", + ), + TypeHintMatch( + function_name="attribution", + return_type=["str", None], + ), + TypeHintMatch( + function_name="entity_category", + return_type=["EntityCategory", None], + ), + TypeHintMatch( + function_name="async_removed_from_registry", + return_type=None, + ), + TypeHintMatch( + function_name="async_added_to_hass", + return_type=None, + ), + TypeHintMatch( + function_name="async_will_remove_from_hass", + return_type=None, + ), + TypeHintMatch( + function_name="async_registry_entry_updated", + return_type=None, + ), + TypeHintMatch( + function_name="update", + return_type=None, + has_async_counterpart=True, + ), +] _TOGGLE_ENTITY_MATCH: list[TypeHintMatch] = [ TypeHintMatch( function_name="is_on", @@ -461,6 +572,10 @@ _TOGGLE_ENTITY_MATCH: list[TypeHintMatch] = [ ] _INHERITANCE_MATCH: dict[str, list[ClassTypeHintMatch]] = { "fan": [ + ClassTypeHintMatch( + base_class="Entity", + matches=_ENTITY_MATCH, + ), ClassTypeHintMatch( base_class="ToggleEntity", matches=_TOGGLE_ENTITY_MATCH, @@ -488,14 +603,6 @@ _INHERITANCE_MATCH: dict[str, list[ClassTypeHintMatch]] = { function_name="oscillating", return_type=["bool", None], ), - TypeHintMatch( - function_name="capability_attributes", - return_type="dict[str]", - ), - TypeHintMatch( - function_name="supported_features", - return_type="int", - ), TypeHintMatch( function_name="preset_mode", return_type=["str", None], @@ -542,6 +649,10 @@ _INHERITANCE_MATCH: dict[str, list[ClassTypeHintMatch]] = { ), ], "lock": [ + ClassTypeHintMatch( + base_class="Entity", + matches=_ENTITY_MATCH, + ), ClassTypeHintMatch( base_class="LockEntity", matches=[ @@ -594,7 +705,9 @@ _INHERITANCE_MATCH: dict[str, list[ClassTypeHintMatch]] = { def _is_valid_type( - expected_type: list[str] | str | None | object, node: nodes.NodeNG + expected_type: list[str] | str | None | object, + node: nodes.NodeNG, + in_return: bool = False, ) -> bool: """Check the argument node against the expected type.""" if expected_type is UNDEFINED: @@ -602,7 +715,7 @@ def _is_valid_type( if isinstance(expected_type, list): for expected_type_item in expected_type: - if _is_valid_type(expected_type_item, node): + if _is_valid_type(expected_type_item, node, in_return): return True return False @@ -638,6 +751,18 @@ def _is_valid_type( # Special case for xxx[yyy, zzz]` if match := _TYPE_HINT_MATCHERS["x_of_y_comma_z"].match(expected_type): + # Handle special case of Mapping[xxx, Any] + if in_return and match.group(1) == "Mapping" and match.group(3) == "Any": + return ( + isinstance(node, nodes.Subscript) + and isinstance(node.value, nodes.Name) + # We accept dict when Mapping is needed + and node.value.name in ("Mapping", "dict") + and isinstance(node.slice, nodes.Tuple) + and _is_valid_type(match.group(2), node.slice.elts[0]) + # Ignore second item + # and _is_valid_type(match.group(3), node.slice.elts[1]) + ) return ( isinstance(node, nodes.Subscript) and _is_valid_type(match.group(1), node.value) @@ -663,7 +788,7 @@ def _is_valid_type( def _is_valid_return_type(match: TypeHintMatch, node: nodes.NodeNG) -> bool: - if _is_valid_type(match.return_type, node): + if _is_valid_type(match.return_type, node, True): return True if isinstance(node, nodes.BinOp): diff --git a/tests/pylint/test_enforce_type_hints.py b/tests/pylint/test_enforce_type_hints.py index 54c7cf6ec4c..8b4b8d4d058 100644 --- a/tests/pylint/test_enforce_type_hints.py +++ b/tests/pylint/test_enforce_type_hints.py @@ -635,3 +635,105 @@ def test_named_arguments( ), ): type_hint_checker.visit_classdef(class_node) + + +@pytest.mark.parametrize( + "return_hint", + [ + "", + "-> Mapping[int, int]", + "-> dict[int, Any]", + ], +) +def test_invalid_mapping_return_type( + linter: UnittestLinter, + type_hint_checker: BaseChecker, + return_hint: str, +) -> None: + """Check that Mapping[xxx, Any] doesn't accept invalid Mapping or dict.""" + # Set bypass option + type_hint_checker.config.ignore_missing_annotations = False + + class_node, property_node = astroid.extract_node( + f""" + class Entity(): + pass + + class ToggleEntity(Entity): + pass + + class FanEntity(ToggleEntity): + pass + + class MyFanA( #@ + FanEntity + ): + @property + def capability_attributes( #@ + self + ){return_hint}: + pass + """, + "homeassistant.components.pylint_test.fan", + ) + type_hint_checker.visit_module(class_node.parent) + + with assert_adds_messages( + linter, + pylint.testutils.MessageTest( + msg_id="hass-return-type", + node=property_node, + args=["Mapping[str, Any]", None], + line=15, + col_offset=4, + end_line=15, + end_col_offset=29, + ), + ): + type_hint_checker.visit_classdef(class_node) + + +@pytest.mark.parametrize( + "return_hint", + [ + "-> Mapping[str, Any]", + "-> Mapping[str, bool | int]", + "-> dict[str, Any]", + "-> dict[str, str]", + ], +) +def test_valid_mapping_return_type( + linter: UnittestLinter, + type_hint_checker: BaseChecker, + return_hint: str, +) -> None: + """Check that Mapping[xxx, Any] accepts both Mapping and dict.""" + # Set bypass option + type_hint_checker.config.ignore_missing_annotations = False + + class_node = astroid.extract_node( + f""" + class Entity(): + pass + + class ToggleEntity(Entity): + pass + + class FanEntity(ToggleEntity): + pass + + class MyFanA( #@ + FanEntity + ): + @property + def capability_attributes( + self + ){return_hint}: + pass + """, + "homeassistant.components.pylint_test.fan", + ) + type_hint_checker.visit_module(class_node.parent) + + with assert_no_messages(linter): + type_hint_checker.visit_classdef(class_node)