mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 03:07:37 +00:00
Enforce type hints for config_flow (#72756)
* Enforce type hints for config_flow * Keep astroid migration for another PR * Defer elif case * Adjust tests * Use ancestors * Match on single base_class * Invert for loops * Review comments * slots is new in 3.10
This commit is contained in:
parent
5d2326386d
commit
4c7837a576
@ -25,6 +25,14 @@ class TypeHintMatch:
|
|||||||
return_type: list[str] | str | None | object
|
return_type: list[str] | str | None | object
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ClassTypeHintMatch:
|
||||||
|
"""Class for pattern matching."""
|
||||||
|
|
||||||
|
base_class: str
|
||||||
|
matches: list[TypeHintMatch]
|
||||||
|
|
||||||
|
|
||||||
_TYPE_HINT_MATCHERS: dict[str, re.Pattern[str]] = {
|
_TYPE_HINT_MATCHERS: dict[str, re.Pattern[str]] = {
|
||||||
# a_or_b matches items such as "DiscoveryInfoType | None"
|
# a_or_b matches items such as "DiscoveryInfoType | None"
|
||||||
"a_or_b": re.compile(r"^(\w+) \| (\w+)$"),
|
"a_or_b": re.compile(r"^(\w+) \| (\w+)$"),
|
||||||
@ -368,6 +376,65 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = {
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = {
|
||||||
|
"config_flow": [
|
||||||
|
ClassTypeHintMatch(
|
||||||
|
base_class="ConfigFlow",
|
||||||
|
matches=[
|
||||||
|
TypeHintMatch(
|
||||||
|
function_name="async_step_dhcp",
|
||||||
|
arg_types={
|
||||||
|
1: "DhcpServiceInfo",
|
||||||
|
},
|
||||||
|
return_type="FlowResult",
|
||||||
|
),
|
||||||
|
TypeHintMatch(
|
||||||
|
function_name="async_step_hassio",
|
||||||
|
arg_types={
|
||||||
|
1: "HassioServiceInfo",
|
||||||
|
},
|
||||||
|
return_type="FlowResult",
|
||||||
|
),
|
||||||
|
TypeHintMatch(
|
||||||
|
function_name="async_step_homekit",
|
||||||
|
arg_types={
|
||||||
|
1: "ZeroconfServiceInfo",
|
||||||
|
},
|
||||||
|
return_type="FlowResult",
|
||||||
|
),
|
||||||
|
TypeHintMatch(
|
||||||
|
function_name="async_step_mqtt",
|
||||||
|
arg_types={
|
||||||
|
1: "MqttServiceInfo",
|
||||||
|
},
|
||||||
|
return_type="FlowResult",
|
||||||
|
),
|
||||||
|
TypeHintMatch(
|
||||||
|
function_name="async_step_ssdp",
|
||||||
|
arg_types={
|
||||||
|
1: "SsdpServiceInfo",
|
||||||
|
},
|
||||||
|
return_type="FlowResult",
|
||||||
|
),
|
||||||
|
TypeHintMatch(
|
||||||
|
function_name="async_step_usb",
|
||||||
|
arg_types={
|
||||||
|
1: "UsbServiceInfo",
|
||||||
|
},
|
||||||
|
return_type="FlowResult",
|
||||||
|
),
|
||||||
|
TypeHintMatch(
|
||||||
|
function_name="async_step_zeroconf",
|
||||||
|
arg_types={
|
||||||
|
1: "ZeroconfServiceInfo",
|
||||||
|
},
|
||||||
|
return_type="FlowResult",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _is_valid_type(
|
def _is_valid_type(
|
||||||
expected_type: list[str] | str | None | object, node: astroid.NodeNG
|
expected_type: list[str] | str | None | object, node: astroid.NodeNG
|
||||||
@ -494,10 +561,12 @@ class HassTypeHintChecker(BaseChecker): # type: ignore[misc]
|
|||||||
def __init__(self, linter: PyLinter | None = None) -> None:
|
def __init__(self, linter: PyLinter | None = None) -> None:
|
||||||
super().__init__(linter)
|
super().__init__(linter)
|
||||||
self._function_matchers: list[TypeHintMatch] = []
|
self._function_matchers: list[TypeHintMatch] = []
|
||||||
|
self._class_matchers: list[ClassTypeHintMatch] = []
|
||||||
|
|
||||||
def visit_module(self, node: astroid.Module) -> None:
|
def visit_module(self, node: astroid.Module) -> None:
|
||||||
"""Called when a Module node is visited."""
|
"""Called when a Module node is visited."""
|
||||||
self._function_matchers = []
|
self._function_matchers = []
|
||||||
|
self._class_matchers = []
|
||||||
|
|
||||||
if (module_platform := _get_module_platform(node.name)) is None:
|
if (module_platform := _get_module_platform(node.name)) is None:
|
||||||
return
|
return
|
||||||
@ -505,8 +574,28 @@ class HassTypeHintChecker(BaseChecker): # type: ignore[misc]
|
|||||||
if module_platform in _PLATFORMS:
|
if module_platform in _PLATFORMS:
|
||||||
self._function_matchers.extend(_FUNCTION_MATCH["__any_platform__"])
|
self._function_matchers.extend(_FUNCTION_MATCH["__any_platform__"])
|
||||||
|
|
||||||
if matches := _FUNCTION_MATCH.get(module_platform):
|
if function_matches := _FUNCTION_MATCH.get(module_platform):
|
||||||
self._function_matchers.extend(matches)
|
self._function_matchers.extend(function_matches)
|
||||||
|
|
||||||
|
if class_matches := _CLASS_MATCH.get(module_platform):
|
||||||
|
self._class_matchers = class_matches
|
||||||
|
|
||||||
|
def visit_classdef(self, node: astroid.ClassDef) -> None:
|
||||||
|
"""Called when a ClassDef node is visited."""
|
||||||
|
ancestor: astroid.ClassDef
|
||||||
|
for ancestor in node.ancestors():
|
||||||
|
for class_matches in self._class_matchers:
|
||||||
|
if ancestor.name == class_matches.base_class:
|
||||||
|
self._visit_class_functions(node, class_matches.matches)
|
||||||
|
|
||||||
|
def _visit_class_functions(
|
||||||
|
self, node: astroid.ClassDef, matches: list[TypeHintMatch]
|
||||||
|
) -> None:
|
||||||
|
for match in matches:
|
||||||
|
for function_node in node.mymethods():
|
||||||
|
function_name: str | None = function_node.name
|
||||||
|
if match.function_name == function_name:
|
||||||
|
self._check_function(function_node, match)
|
||||||
|
|
||||||
def visit_functiondef(self, node: astroid.FunctionDef) -> None:
|
def visit_functiondef(self, node: astroid.FunctionDef) -> None:
|
||||||
"""Called when a FunctionDef node is visited."""
|
"""Called when a FunctionDef node is visited."""
|
||||||
|
@ -277,3 +277,75 @@ def test_valid_list_dict_str_any(
|
|||||||
|
|
||||||
with assert_no_messages(linter):
|
with assert_no_messages(linter):
|
||||||
type_hint_checker.visit_asyncfunctiondef(func_node)
|
type_hint_checker.visit_asyncfunctiondef(func_node)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_config_flow_step(
|
||||||
|
linter: UnittestLinter, type_hint_checker: BaseChecker
|
||||||
|
) -> None:
|
||||||
|
"""Ensure invalid hints are rejected for ConfigFlow step."""
|
||||||
|
class_node, func_node, arg_node = astroid.extract_node(
|
||||||
|
"""
|
||||||
|
class ConfigFlow():
|
||||||
|
pass
|
||||||
|
|
||||||
|
class AxisFlowHandler( #@
|
||||||
|
ConfigFlow, domain=AXIS_DOMAIN
|
||||||
|
):
|
||||||
|
async def async_step_zeroconf( #@
|
||||||
|
self,
|
||||||
|
device_config: dict #@
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
""",
|
||||||
|
"homeassistant.components.pylint_test.config_flow",
|
||||||
|
)
|
||||||
|
type_hint_checker.visit_module(class_node.parent)
|
||||||
|
|
||||||
|
with assert_adds_messages(
|
||||||
|
linter,
|
||||||
|
pylint.testutils.MessageTest(
|
||||||
|
msg_id="hass-argument-type",
|
||||||
|
node=arg_node,
|
||||||
|
args=(2, "ZeroconfServiceInfo"),
|
||||||
|
line=10,
|
||||||
|
col_offset=8,
|
||||||
|
end_line=10,
|
||||||
|
end_col_offset=27,
|
||||||
|
),
|
||||||
|
pylint.testutils.MessageTest(
|
||||||
|
msg_id="hass-return-type",
|
||||||
|
node=func_node,
|
||||||
|
args="FlowResult",
|
||||||
|
line=8,
|
||||||
|
col_offset=4,
|
||||||
|
end_line=8,
|
||||||
|
end_col_offset=33,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
type_hint_checker.visit_classdef(class_node)
|
||||||
|
|
||||||
|
|
||||||
|
def test_valid_config_flow_step(
|
||||||
|
linter: UnittestLinter, type_hint_checker: BaseChecker
|
||||||
|
) -> None:
|
||||||
|
"""Ensure valid hints are accepted for ConfigFlow step."""
|
||||||
|
class_node = astroid.extract_node(
|
||||||
|
"""
|
||||||
|
class ConfigFlow():
|
||||||
|
pass
|
||||||
|
|
||||||
|
class AxisFlowHandler( #@
|
||||||
|
ConfigFlow, domain=AXIS_DOMAIN
|
||||||
|
):
|
||||||
|
async def async_step_zeroconf(
|
||||||
|
self,
|
||||||
|
device_config: ZeroconfServiceInfo
|
||||||
|
) -> FlowResult:
|
||||||
|
pass
|
||||||
|
""",
|
||||||
|
"homeassistant.components.pylint_test.config_flow",
|
||||||
|
)
|
||||||
|
type_hint_checker.visit_module(class_node.parent)
|
||||||
|
|
||||||
|
with assert_no_messages(linter):
|
||||||
|
type_hint_checker.visit_classdef(class_node)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user