From 982d197ff3c493024846f38d8688985b6acd3707 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Wed, 10 Aug 2022 16:30:58 +0200 Subject: [PATCH] Add number checks to pylint plugin (#76457) * Add number checks to pylint plugin * Adjust ancestor checks * Add tests * Add comments in tests --- pylint/plugins/hass_enforce_type_hints.py | 65 ++++++++++++++++++++++- tests/pylint/test_enforce_type_hints.py | 37 +++++++++++++ 2 files changed, 100 insertions(+), 2 deletions(-) diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py index 257f4fb3613..852c5b544c4 100644 --- a/pylint/plugins/hass_enforce_type_hints.py +++ b/pylint/plugins/hass_enforce_type_hints.py @@ -1465,6 +1465,55 @@ _INHERITANCE_MATCH: dict[str, list[ClassTypeHintMatch]] = { ], ), ], + "number": [ + ClassTypeHintMatch( + base_class="Entity", + matches=_ENTITY_MATCH, + ), + ClassTypeHintMatch( + base_class="NumberEntity", + matches=[ + TypeHintMatch( + function_name="device_class", + return_type=["NumberDeviceClass", "str", None], + ), + TypeHintMatch( + function_name="capability_attributes", + return_type="dict[str, Any]", + ), + TypeHintMatch( + function_name="native_min_value", + return_type="float", + ), + TypeHintMatch( + function_name="native_max_value", + return_type="float", + ), + TypeHintMatch( + function_name="native_step", + return_type=["float", None], + ), + TypeHintMatch( + function_name="mode", + return_type="NumberMode", + ), + TypeHintMatch( + function_name="native_unit_of_measurement", + return_type=["str", None], + ), + TypeHintMatch( + function_name="native_value", + return_type=["float", None], + ), + TypeHintMatch( + function_name="set_native_value", + arg_types={1: "float"}, + return_type=None, + has_async_counterpart=True, + ), + ], + ), + ], "select": [ ClassTypeHintMatch( base_class="Entity", @@ -1599,6 +1648,15 @@ def _is_valid_type( and _is_valid_type(match.group(2), node.slice) ) + # Special case for float in return type + if ( + expected_type == "float" + and in_return + and isinstance(node, nodes.Name) + and node.name in ("float", "int") + ): + return True + # Name occurs when a namespace is not used, eg. "HomeAssistant" if isinstance(node, nodes.Name) and node.name == expected_type: return True @@ -1737,12 +1795,15 @@ class HassTypeHintChecker(BaseChecker): # type: ignore[misc] ): self._class_matchers.extend(property_matches) + self._class_matchers.reverse() + def visit_classdef(self, node: nodes.ClassDef) -> None: """Called when a ClassDef node is visited.""" ancestor: nodes.ClassDef checked_class_methods: set[str] = set() - for ancestor in node.ancestors(): - for class_matches in self._class_matchers: + ancestors = list(node.ancestors()) # cache result for inside loop + for class_matches in self._class_matchers: + for ancestor in ancestors: if ancestor.name == class_matches.base_class: self._visit_class_functions( node, class_matches.matches, checked_class_methods diff --git a/tests/pylint/test_enforce_type_hints.py b/tests/pylint/test_enforce_type_hints.py index d9edde9fdee..b3c233d1c3b 100644 --- a/tests/pylint/test_enforce_type_hints.py +++ b/tests/pylint/test_enforce_type_hints.py @@ -900,3 +900,40 @@ def test_invalid_device_class( ), ): type_hint_checker.visit_classdef(class_node) + + +def test_number_entity(linter: UnittestLinter, type_hint_checker: BaseChecker) -> None: + """Ensure valid hints are accepted for number entity.""" + # Set bypass option + type_hint_checker.config.ignore_missing_annotations = False + + # Ensure that device class is valid despite Entity inheritance + # Ensure that `int` is valid for `float` return type + class_node = astroid.extract_node( + """ + class Entity(): + pass + + class RestoreEntity(Entity): + pass + + class NumberEntity(Entity): + pass + + class MyNumber( #@ + RestoreEntity, NumberEntity + ): + @property + def device_class(self) -> NumberDeviceClass: + pass + + @property + def native_value(self) -> int: + pass + """, + "homeassistant.components.pylint_test.number", + ) + type_hint_checker.visit_module(class_node.parent) + + with assert_no_messages(linter): + type_hint_checker.visit_classdef(class_node)