Adjust pylint checks for notify get_service (#77606)

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
This commit is contained in:
epenet 2022-09-07 09:44:15 +02:00 committed by GitHub
parent a82484d7a7
commit 9fb0b3995c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 63 additions and 8 deletions

View File

@ -4,13 +4,20 @@ from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
import re
from typing import TYPE_CHECKING
from astroid import nodes
from astroid.exceptions import NameInferenceError
from pylint.checkers import BaseChecker
from pylint.lint import PyLinter
from homeassistant.const import Platform
if TYPE_CHECKING:
# InferenceResult is available only from astroid >= 2.12.0
# pre-commit should still work on out of date environments
from astroid.typing import InferenceResult
class _Special(Enum):
"""Sentinel values"""
@ -387,7 +394,8 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = {
1: "ConfigType",
2: "DiscoveryInfoType | None",
},
return_type=_Special.UNDEFINED,
return_type=["BaseNotificationService", None],
check_return_type_inheritance=True,
has_async_counterpart=True,
),
],
@ -2534,21 +2542,39 @@ def _is_valid_return_type(match: TypeHintMatch, node: nodes.NodeNG) -> bool:
if (
match.check_return_type_inheritance
and isinstance(match.return_type, str)
and isinstance(match.return_type, (str, list))
and isinstance(node, nodes.Name)
):
ancestor: nodes.ClassDef
for infer_node in node.infer():
if isinstance(infer_node, nodes.ClassDef):
if infer_node.name == match.return_type:
if isinstance(match.return_type, str):
valid_types = {match.return_type}
else:
valid_types = {el for el in match.return_type if isinstance(el, str)}
try:
for infer_node in node.infer():
if _check_ancestry(infer_node, valid_types):
return True
for ancestor in infer_node.ancestors():
if ancestor.name == match.return_type:
except NameInferenceError:
for class_node in node.root().nodes_of_class(nodes.ClassDef):
if class_node.name != node.name:
continue
for infer_node in class_node.infer():
if _check_ancestry(infer_node, valid_types):
return True
return False
def _check_ancestry(infer_node: InferenceResult, valid_types: set[str]) -> bool:
if isinstance(infer_node, nodes.ClassDef):
if infer_node.name in valid_types:
return True
for ancestor in infer_node.ancestors():
if ancestor.name in valid_types:
return True
return False
def _get_all_annotations(node: nodes.FunctionDef) -> list[nodes.NodeNG | None]:
args = node.args
annotations: list[nodes.NodeNG | None] = (

View File

@ -1011,3 +1011,32 @@ def test_vacuum_entity(linter: UnittestLinter, type_hint_checker: BaseChecker) -
with assert_no_messages(linter):
type_hint_checker.visit_classdef(class_node)
def test_notify_get_service(
linter: UnittestLinter, type_hint_checker: BaseChecker
) -> None:
"""Ensure valid hints are accepted for async_get_service."""
func_node = astroid.extract_node(
"""
class BaseNotificationService():
pass
async def async_get_service( #@
hass: HomeAssistant,
config: ConfigType,
discovery_info: DiscoveryInfoType | None = None,
) -> CustomNotificationService:
pass
class CustomNotificationService(BaseNotificationService):
pass
""",
"homeassistant.components.pylint_test.notify",
)
type_hint_checker.visit_module(func_node.parent)
with assert_no_messages(
linter,
):
type_hint_checker.visit_asyncfunctiondef(func_node)