Adjust pylint checks for notify get_service (#77606)

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
This commit is contained in:
epenet 2022-09-07 09:44:15 +02:00 committed by GitHub
parent a82484d7a7
commit 9fb0b3995c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 63 additions and 8 deletions

View File

@ -4,13 +4,20 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
import re import re
from typing import TYPE_CHECKING
from astroid import nodes from astroid import nodes
from astroid.exceptions import NameInferenceError
from pylint.checkers import BaseChecker from pylint.checkers import BaseChecker
from pylint.lint import PyLinter from pylint.lint import PyLinter
from homeassistant.const import Platform from homeassistant.const import Platform
if TYPE_CHECKING:
# InferenceResult is available only from astroid >= 2.12.0
# pre-commit should still work on out of date environments
from astroid.typing import InferenceResult
class _Special(Enum): class _Special(Enum):
"""Sentinel values""" """Sentinel values"""
@ -387,7 +394,8 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = {
1: "ConfigType", 1: "ConfigType",
2: "DiscoveryInfoType | None", 2: "DiscoveryInfoType | None",
}, },
return_type=_Special.UNDEFINED, return_type=["BaseNotificationService", None],
check_return_type_inheritance=True,
has_async_counterpart=True, has_async_counterpart=True,
), ),
], ],
@ -2534,21 +2542,39 @@ def _is_valid_return_type(match: TypeHintMatch, node: nodes.NodeNG) -> bool:
if ( if (
match.check_return_type_inheritance match.check_return_type_inheritance
and isinstance(match.return_type, str) and isinstance(match.return_type, (str, list))
and isinstance(node, nodes.Name) and isinstance(node, nodes.Name)
): ):
ancestor: nodes.ClassDef if isinstance(match.return_type, str):
for infer_node in node.infer(): valid_types = {match.return_type}
if isinstance(infer_node, nodes.ClassDef): else:
if infer_node.name == match.return_type: valid_types = {el for el in match.return_type if isinstance(el, str)}
try:
for infer_node in node.infer():
if _check_ancestry(infer_node, valid_types):
return True return True
for ancestor in infer_node.ancestors(): except NameInferenceError:
if ancestor.name == match.return_type: for class_node in node.root().nodes_of_class(nodes.ClassDef):
if class_node.name != node.name:
continue
for infer_node in class_node.infer():
if _check_ancestry(infer_node, valid_types):
return True return True
return False return False
def _check_ancestry(infer_node: InferenceResult, valid_types: set[str]) -> bool:
if isinstance(infer_node, nodes.ClassDef):
if infer_node.name in valid_types:
return True
for ancestor in infer_node.ancestors():
if ancestor.name in valid_types:
return True
return False
def _get_all_annotations(node: nodes.FunctionDef) -> list[nodes.NodeNG | None]: def _get_all_annotations(node: nodes.FunctionDef) -> list[nodes.NodeNG | None]:
args = node.args args = node.args
annotations: list[nodes.NodeNG | None] = ( annotations: list[nodes.NodeNG | None] = (

View File

@ -1011,3 +1011,32 @@ def test_vacuum_entity(linter: UnittestLinter, type_hint_checker: BaseChecker) -
with assert_no_messages(linter): with assert_no_messages(linter):
type_hint_checker.visit_classdef(class_node) type_hint_checker.visit_classdef(class_node)
def test_notify_get_service(
linter: UnittestLinter, type_hint_checker: BaseChecker
) -> None:
"""Ensure valid hints are accepted for async_get_service."""
func_node = astroid.extract_node(
"""
class BaseNotificationService():
pass
async def async_get_service( #@
hass: HomeAssistant,
config: ConfigType,
discovery_info: DiscoveryInfoType | None = None,
) -> CustomNotificationService:
pass
class CustomNotificationService(BaseNotificationService):
pass
""",
"homeassistant.components.pylint_test.notify",
)
type_hint_checker.visit_module(func_node.parent)
with assert_no_messages(
linter,
):
type_hint_checker.visit_asyncfunctiondef(func_node)