Improve type hints in stream (#51837)

* Improve type hints in stream

* Fix import locations

* Add stream to .strict-typing

Co-authored-by: Ruslan Sayfutdinov <ruslan@sayfutdinov.com>
This commit is contained in:
uvjustin 2021-06-14 23:59:25 +08:00 committed by GitHub
parent 7cd57dd156
commit 97e77ab229
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 135 additions and 85 deletions

View File

@ -65,6 +65,7 @@ homeassistant.components.sensor.*
homeassistant.components.slack.* homeassistant.components.slack.*
homeassistant.components.sonos.media_player homeassistant.components.sonos.media_player
homeassistant.components.ssdp.* homeassistant.components.ssdp.*
homeassistant.components.stream.*
homeassistant.components.sun.* homeassistant.components.sun.*
homeassistant.components.switch.* homeassistant.components.switch.*
homeassistant.components.synology_dsm.* homeassistant.components.synology_dsm.*

View File

@ -16,16 +16,19 @@ to always keep workers active.
""" """
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping
import logging import logging
import re import re
import secrets import secrets
import threading import threading
import time import time
from types import MappingProxyType from types import MappingProxyType
from typing import cast
from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import callback from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.typing import ConfigType
from .const import ( from .const import (
ATTR_ENDPOINTS, ATTR_ENDPOINTS,
@ -40,18 +43,21 @@ from .const import (
) )
from .core import PROVIDERS, IdleTimer, StreamOutput from .core import PROVIDERS, IdleTimer, StreamOutput
from .hls import async_setup_hls from .hls import async_setup_hls
from .recorder import RecorderOutput
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
STREAM_SOURCE_RE = re.compile("//.*:.*@") STREAM_SOURCE_RE = re.compile("//.*:.*@")
def redact_credentials(data): def redact_credentials(data: str) -> str:
"""Redact credentials from string data.""" """Redact credentials from string data."""
return STREAM_SOURCE_RE.sub("//****:****@", data) return STREAM_SOURCE_RE.sub("//****:****@", data)
def create_stream(hass, stream_source, options=None): def create_stream(
hass: HomeAssistant, stream_source: str, options: dict[str, str]
) -> Stream:
"""Create a stream with the specified identfier based on the source url. """Create a stream with the specified identfier based on the source url.
The stream_source is typically an rtsp url and options are passed into The stream_source is typically an rtsp url and options are passed into
@ -60,9 +66,6 @@ def create_stream(hass, stream_source, options=None):
if DOMAIN not in hass.config.components: if DOMAIN not in hass.config.components:
raise HomeAssistantError("Stream integration is not set up.") raise HomeAssistantError("Stream integration is not set up.")
if options is None:
options = {}
# For RTSP streams, prefer TCP # For RTSP streams, prefer TCP
if isinstance(stream_source, str) and stream_source[:7] == "rtsp://": if isinstance(stream_source, str) and stream_source[:7] == "rtsp://":
options = { options = {
@ -76,7 +79,7 @@ def create_stream(hass, stream_source, options=None):
return stream return stream
async def async_setup(hass, config): async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up stream.""" """Set up stream."""
# Set log level to error for libav # Set log level to error for libav
logging.getLogger("libav").setLevel(logging.ERROR) logging.getLogger("libav").setLevel(logging.ERROR)
@ -98,7 +101,7 @@ async def async_setup(hass, config):
async_setup_recorder(hass) async_setup_recorder(hass)
@callback @callback
def shutdown(event): def shutdown(event: Event) -> None:
"""Stop all stream workers.""" """Stop all stream workers."""
for stream in hass.data[DOMAIN][ATTR_STREAMS]: for stream in hass.data[DOMAIN][ATTR_STREAMS]:
stream.keepalive = False stream.keepalive = False
@ -113,41 +116,43 @@ async def async_setup(hass, config):
class Stream: class Stream:
"""Represents a single stream.""" """Represents a single stream."""
def __init__(self, hass, source, options=None): def __init__(
self, hass: HomeAssistant, source: str, options: dict[str, str]
) -> None:
"""Initialize a stream.""" """Initialize a stream."""
self.hass = hass self.hass = hass
self.source = source self.source = source
self.options = options self.options = options
self.keepalive = False self.keepalive = False
self.access_token = None self.access_token: str | None = None
self._thread = None self._thread: threading.Thread | None = None
self._thread_quit = threading.Event() self._thread_quit = threading.Event()
self._outputs: dict[str, StreamOutput] = {} self._outputs: dict[str, StreamOutput] = {}
self._fast_restart_once = False self._fast_restart_once = False
if self.options is None:
self.options = {}
def endpoint_url(self, fmt: str) -> str: def endpoint_url(self, fmt: str) -> str:
"""Start the stream and returns a url for the output format.""" """Start the stream and returns a url for the output format."""
if fmt not in self._outputs: if fmt not in self._outputs:
raise ValueError(f"Stream is not configured for format '{fmt}'") raise ValueError(f"Stream is not configured for format '{fmt}'")
if not self.access_token: if not self.access_token:
self.access_token = secrets.token_hex() self.access_token = secrets.token_hex()
return self.hass.data[DOMAIN][ATTR_ENDPOINTS][fmt].format(self.access_token) endpoint_fmt: str = self.hass.data[DOMAIN][ATTR_ENDPOINTS][fmt]
return endpoint_fmt.format(self.access_token)
def outputs(self): def outputs(self) -> Mapping[str, StreamOutput]:
"""Return a copy of the stream outputs.""" """Return a copy of the stream outputs."""
# A copy is returned so the caller can iterate through the outputs # A copy is returned so the caller can iterate through the outputs
# without concern about self._outputs being modified from another thread. # without concern about self._outputs being modified from another thread.
return MappingProxyType(self._outputs.copy()) return MappingProxyType(self._outputs.copy())
def add_provider(self, fmt, timeout=OUTPUT_IDLE_TIMEOUT): def add_provider(
self, fmt: str, timeout: int = OUTPUT_IDLE_TIMEOUT
) -> StreamOutput:
"""Add provider output stream.""" """Add provider output stream."""
if not self._outputs.get(fmt): if not self._outputs.get(fmt):
@callback @callback
def idle_callback(): def idle_callback() -> None:
if ( if (
not self.keepalive or fmt == RECORDER_PROVIDER not self.keepalive or fmt == RECORDER_PROVIDER
) and fmt in self._outputs: ) and fmt in self._outputs:
@ -160,7 +165,7 @@ class Stream:
self._outputs[fmt] = provider self._outputs[fmt] = provider
return self._outputs[fmt] return self._outputs[fmt]
def remove_provider(self, provider): def remove_provider(self, provider: StreamOutput) -> None:
"""Remove provider output stream.""" """Remove provider output stream."""
if provider.name in self._outputs: if provider.name in self._outputs:
self._outputs[provider.name].cleanup() self._outputs[provider.name].cleanup()
@ -169,12 +174,12 @@ class Stream:
if not self._outputs: if not self._outputs:
self.stop() self.stop()
def check_idle(self): def check_idle(self) -> None:
"""Reset access token if all providers are idle.""" """Reset access token if all providers are idle."""
if all(p.idle for p in self._outputs.values()): if all(p.idle for p in self._outputs.values()):
self.access_token = None self.access_token = None
def start(self): def start(self) -> None:
"""Start a stream.""" """Start a stream."""
if self._thread is None or not self._thread.is_alive(): if self._thread is None or not self._thread.is_alive():
if self._thread is not None: if self._thread is not None:
@ -189,14 +194,14 @@ class Stream:
self._thread.start() self._thread.start()
_LOGGER.info("Started stream: %s", redact_credentials(str(self.source))) _LOGGER.info("Started stream: %s", redact_credentials(str(self.source)))
def update_source(self, new_source): def update_source(self, new_source: str) -> None:
"""Restart the stream with a new stream source.""" """Restart the stream with a new stream source."""
_LOGGER.debug("Updating stream source %s", new_source) _LOGGER.debug("Updating stream source %s", new_source)
self.source = new_source self.source = new_source
self._fast_restart_once = True self._fast_restart_once = True
self._thread_quit.set() self._thread_quit.set()
def _run_worker(self): def _run_worker(self) -> None:
"""Handle consuming streams and restart keepalive streams.""" """Handle consuming streams and restart keepalive streams."""
# Keep import here so that we can import stream integration without installing reqs # Keep import here so that we can import stream integration without installing reqs
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
@ -229,17 +234,17 @@ class Stream:
) )
self._worker_finished() self._worker_finished()
def _worker_finished(self): def _worker_finished(self) -> None:
"""Schedule cleanup of all outputs.""" """Schedule cleanup of all outputs."""
@callback @callback
def remove_outputs(): def remove_outputs() -> None:
for provider in self.outputs().values(): for provider in self.outputs().values():
self.remove_provider(provider) self.remove_provider(provider)
self.hass.loop.call_soon_threadsafe(remove_outputs) self.hass.loop.call_soon_threadsafe(remove_outputs)
def stop(self): def stop(self) -> None:
"""Remove outputs and access token.""" """Remove outputs and access token."""
self._outputs = {} self._outputs = {}
self.access_token = None self.access_token = None
@ -247,7 +252,7 @@ class Stream:
if not self.keepalive: if not self.keepalive:
self._stop() self._stop()
def _stop(self): def _stop(self) -> None:
"""Stop worker thread.""" """Stop worker thread."""
if self._thread is not None: if self._thread is not None:
self._thread_quit.set() self._thread_quit.set()
@ -255,7 +260,9 @@ class Stream:
self._thread = None self._thread = None
_LOGGER.info("Stopped stream: %s", redact_credentials(str(self.source))) _LOGGER.info("Stopped stream: %s", redact_credentials(str(self.source)))
async def async_record(self, video_path, duration=30, lookback=5): async def async_record(
self, video_path: str, duration: int = 30, lookback: int = 5
) -> None:
"""Make a .mp4 recording from a provided stream.""" """Make a .mp4 recording from a provided stream."""
# Check for file access # Check for file access
@ -265,10 +272,13 @@ class Stream:
# Add recorder # Add recorder
recorder = self.outputs().get(RECORDER_PROVIDER) recorder = self.outputs().get(RECORDER_PROVIDER)
if recorder: if recorder:
assert isinstance(recorder, RecorderOutput)
raise HomeAssistantError( raise HomeAssistantError(
f"Stream already recording to {recorder.video_path}!" f"Stream already recording to {recorder.video_path}!"
) )
recorder = self.add_provider(RECORDER_PROVIDER, timeout=duration) recorder = cast(
RecorderOutput, self.add_provider(RECORDER_PROVIDER, timeout=duration)
)
recorder.video_path = video_path recorder.video_path = video_path
self.start() self.start()

View File

@ -4,18 +4,21 @@ from __future__ import annotations
import asyncio import asyncio
from collections import deque from collections import deque
import datetime import datetime
from typing import Callable from typing import TYPE_CHECKING
from aiohttp import web from aiohttp import web
import attr import attr
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http.view import HomeAssistantView
from homeassistant.core import HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.helpers.event import async_call_later from homeassistant.helpers.event import async_call_later
from homeassistant.util.decorator import Registry from homeassistant.util.decorator import Registry
from .const import ATTR_STREAMS, DOMAIN from .const import ATTR_STREAMS, DOMAIN
if TYPE_CHECKING:
from . import Stream
PROVIDERS = Registry() PROVIDERS = Registry()
@ -59,34 +62,34 @@ class IdleTimer:
""" """
def __init__( def __init__(
self, hass: HomeAssistant, timeout: int, idle_callback: Callable[[], None] self, hass: HomeAssistant, timeout: int, idle_callback: CALLBACK_TYPE
) -> None: ) -> None:
"""Initialize IdleTimer.""" """Initialize IdleTimer."""
self._hass = hass self._hass = hass
self._timeout = timeout self._timeout = timeout
self._callback = idle_callback self._callback = idle_callback
self._unsub = None self._unsub: CALLBACK_TYPE | None = None
self.idle = False self.idle = False
def start(self): def start(self) -> None:
"""Start the idle timer if not already started.""" """Start the idle timer if not already started."""
self.idle = False self.idle = False
if self._unsub is None: if self._unsub is None:
self._unsub = async_call_later(self._hass, self._timeout, self.fire) self._unsub = async_call_later(self._hass, self._timeout, self.fire)
def awake(self): def awake(self) -> None:
"""Keep the idle time alive by resetting the timeout.""" """Keep the idle time alive by resetting the timeout."""
self.idle = False self.idle = False
# Reset idle timeout # Reset idle timeout
self.clear() self.clear()
self._unsub = async_call_later(self._hass, self._timeout, self.fire) self._unsub = async_call_later(self._hass, self._timeout, self.fire)
def clear(self): def clear(self) -> None:
"""Clear and disable the timer if it has not already fired.""" """Clear and disable the timer if it has not already fired."""
if self._unsub is not None: if self._unsub is not None:
self._unsub() self._unsub()
def fire(self, _now=None): def fire(self, _now: datetime.datetime) -> None:
"""Invoke the idle timeout callback, called when the alarm fires.""" """Invoke the idle timeout callback, called when the alarm fires."""
self.idle = True self.idle = True
self._unsub = None self._unsub = None
@ -97,7 +100,10 @@ class StreamOutput:
"""Represents a stream output.""" """Represents a stream output."""
def __init__( def __init__(
self, hass: HomeAssistant, idle_timer: IdleTimer, deque_maxlen: int = None self,
hass: HomeAssistant,
idle_timer: IdleTimer,
deque_maxlen: int | None = None,
) -> None: ) -> None:
"""Initialize a stream output.""" """Initialize a stream output."""
self._hass = hass self._hass = hass
@ -172,7 +178,7 @@ class StreamOutput:
self._event.set() self._event.set()
self._event.clear() self._event.clear()
def cleanup(self): def cleanup(self) -> None:
"""Handle cleanup.""" """Handle cleanup."""
self._event.set() self._event.set()
self.idle_timer.clear() self.idle_timer.clear()
@ -190,7 +196,9 @@ class StreamView(HomeAssistantView):
requires_auth = False requires_auth = False
platform = None platform = None
async def get(self, request, token, sequence=None): async def get(
self, request: web.Request, token: str, sequence: str = ""
) -> web.StreamResponse:
"""Start a GET request.""" """Start a GET request."""
hass = request.app["hass"] hass = request.app["hass"]
@ -207,6 +215,8 @@ class StreamView(HomeAssistantView):
return await self.handle(request, stream, sequence) return await self.handle(request, stream, sequence)
async def handle(self, request, stream, sequence): async def handle(
self, request: web.Request, stream: Stream, sequence: str
) -> web.StreamResponse:
"""Handle the stream request.""" """Handle the stream request."""
raise NotImplementedError() raise NotImplementedError()

View File

@ -1,7 +1,11 @@
"""Provide functionality to stream HLS.""" """Provide functionality to stream HLS."""
from __future__ import annotations
from typing import TYPE_CHECKING
from aiohttp import web from aiohttp import web
from homeassistant.core import callback from homeassistant.core import HomeAssistant, callback
from .const import ( from .const import (
EXT_X_START, EXT_X_START,
@ -10,12 +14,15 @@ from .const import (
MAX_SEGMENTS, MAX_SEGMENTS,
NUM_PLAYLIST_SEGMENTS, NUM_PLAYLIST_SEGMENTS,
) )
from .core import PROVIDERS, HomeAssistant, IdleTimer, StreamOutput, StreamView from .core import PROVIDERS, IdleTimer, StreamOutput, StreamView
from .fmp4utils import get_codec_string from .fmp4utils import get_codec_string
if TYPE_CHECKING:
from . import Stream
@callback @callback
def async_setup_hls(hass): def async_setup_hls(hass: HomeAssistant) -> str:
"""Set up api endpoints.""" """Set up api endpoints."""
hass.http.register_view(HlsPlaylistView()) hass.http.register_view(HlsPlaylistView())
hass.http.register_view(HlsSegmentView()) hass.http.register_view(HlsSegmentView())
@ -32,12 +39,13 @@ class HlsMasterPlaylistView(StreamView):
cors_allowed = True cors_allowed = True
@staticmethod @staticmethod
def render(track): def render(track: StreamOutput) -> str:
"""Render M3U8 file.""" """Render M3U8 file."""
# Need to calculate max bandwidth as input_container.bit_rate doesn't seem to work # Need to calculate max bandwidth as input_container.bit_rate doesn't seem to work
# Calculate file size / duration and use a small multiplier to account for variation # Calculate file size / duration and use a small multiplier to account for variation
# hls spec already allows for 25% variation # hls spec already allows for 25% variation
segment = track.get_segment(track.sequences[-2]) if not (segment := track.get_segment(track.sequences[-2])):
return ""
bandwidth = round( bandwidth = round(
(len(segment.init) + sum(len(part.data) for part in segment.parts)) (len(segment.init) + sum(len(part.data) for part in segment.parts))
* 8 * 8
@ -52,7 +60,9 @@ class HlsMasterPlaylistView(StreamView):
] ]
return "\n".join(lines) + "\n" return "\n".join(lines) + "\n"
async def handle(self, request, stream, sequence): async def handle(
self, request: web.Request, stream: Stream, sequence: str
) -> web.Response:
"""Return m3u8 playlist.""" """Return m3u8 playlist."""
track = stream.add_provider(HLS_PROVIDER) track = stream.add_provider(HLS_PROVIDER)
stream.start() stream.start()
@ -73,7 +83,7 @@ class HlsPlaylistView(StreamView):
cors_allowed = True cors_allowed = True
@staticmethod @staticmethod
def render(track): def render(track: StreamOutput) -> str:
"""Render playlist.""" """Render playlist."""
# NUM_PLAYLIST_SEGMENTS+1 because most recent is probably not yet complete # NUM_PLAYLIST_SEGMENTS+1 because most recent is probably not yet complete
segments = list(track.get_segments())[-(NUM_PLAYLIST_SEGMENTS + 1) :] segments = list(track.get_segments())[-(NUM_PLAYLIST_SEGMENTS + 1) :]
@ -130,7 +140,9 @@ class HlsPlaylistView(StreamView):
return "\n".join(playlist) + "\n" return "\n".join(playlist) + "\n"
async def handle(self, request, stream, sequence): async def handle(
self, request: web.Request, stream: Stream, sequence: str
) -> web.Response:
"""Return m3u8 playlist.""" """Return m3u8 playlist."""
track = stream.add_provider(HLS_PROVIDER) track = stream.add_provider(HLS_PROVIDER)
stream.start() stream.start()
@ -154,7 +166,9 @@ class HlsInitView(StreamView):
name = "api:stream:hls:init" name = "api:stream:hls:init"
cors_allowed = True cors_allowed = True
async def handle(self, request, stream, sequence): async def handle(
self, request: web.Request, stream: Stream, sequence: str
) -> web.Response:
"""Return init.mp4.""" """Return init.mp4."""
track = stream.add_provider(HLS_PROVIDER) track = stream.add_provider(HLS_PROVIDER)
if not (segments := track.get_segments()): if not (segments := track.get_segments()):
@ -170,7 +184,9 @@ class HlsSegmentView(StreamView):
name = "api:stream:hls:segment" name = "api:stream:hls:segment"
cors_allowed = True cors_allowed = True
async def handle(self, request, stream, sequence): async def handle(
self, request: web.Request, stream: Stream, sequence: str
) -> web.Response:
"""Return fmp4 segment.""" """Return fmp4 segment."""
track = stream.add_provider(HLS_PROVIDER) track = stream.add_provider(HLS_PROVIDER)
track.idle_timer.awake() track.idle_timer.awake()

View File

@ -23,11 +23,11 @@ _LOGGER = logging.getLogger(__name__)
@callback @callback
def async_setup_recorder(hass): def async_setup_recorder(hass: HomeAssistant) -> None:
"""Only here so Provider Registry works.""" """Only here so Provider Registry works."""
def recorder_save_worker(file_out: str, segments: deque[Segment]): def recorder_save_worker(file_out: str, segments: deque[Segment]) -> None:
"""Handle saving stream.""" """Handle saving stream."""
if not segments: if not segments:
@ -121,7 +121,7 @@ class RecorderOutput(StreamOutput):
def __init__(self, hass: HomeAssistant, idle_timer: IdleTimer) -> None: def __init__(self, hass: HomeAssistant, idle_timer: IdleTimer) -> None:
"""Initialize recorder output.""" """Initialize recorder output."""
super().__init__(hass, idle_timer) super().__init__(hass, idle_timer)
self.video_path = None self.video_path: str
@property @property
def name(self) -> str: def name(self) -> str:
@ -132,7 +132,7 @@ class RecorderOutput(StreamOutput):
"""Prepend segments to existing list.""" """Prepend segments to existing list."""
self._segments.extendleft(reversed(segments)) self._segments.extendleft(reversed(segments))
def cleanup(self): def cleanup(self) -> None:
"""Write recording and clean up.""" """Write recording and clean up."""
_LOGGER.debug("Starting recorder worker thread") _LOGGER.debug("Starting recorder worker thread")
thread = threading.Thread( thread = threading.Thread(

View File

@ -7,7 +7,7 @@ from fractions import Fraction
from io import BytesIO from io import BytesIO
import logging import logging
from threading import Event from threading import Event
from typing import Callable, cast from typing import Any, Callable, cast
import av import av
@ -45,9 +45,9 @@ class SegmentBuffer:
self._memory_file: BytesIO = cast(BytesIO, None) self._memory_file: BytesIO = cast(BytesIO, None)
self._av_output: av.container.OutputContainer = None self._av_output: av.container.OutputContainer = None
self._input_video_stream: av.video.VideoStream = None self._input_video_stream: av.video.VideoStream = None
self._input_audio_stream = None # av.audio.AudioStream | None self._input_audio_stream: Any | None = None # av.audio.AudioStream | None
self._output_video_stream: av.video.VideoStream = None self._output_video_stream: av.video.VideoStream = None
self._output_audio_stream = None # av.audio.AudioStream | None self._output_audio_stream: Any | None = None # av.audio.AudioStream | None
self._segment: Segment | None = None self._segment: Segment | None = None
self._segment_last_write_pos: int = cast(int, None) self._segment_last_write_pos: int = cast(int, None)
self._part_start_dts: int = cast(int, None) self._part_start_dts: int = cast(int, None)
@ -82,7 +82,7 @@ class SegmentBuffer:
def set_streams( def set_streams(
self, self,
video_stream: av.video.VideoStream, video_stream: av.video.VideoStream,
audio_stream, audio_stream: Any,
# no type hint for audio_stream until https://github.com/PyAV-Org/PyAV/pull/775 is merged # no type hint for audio_stream until https://github.com/PyAV-Org/PyAV/pull/775 is merged
) -> None: ) -> None:
"""Initialize output buffer with streams from container.""" """Initialize output buffer with streams from container."""
@ -206,7 +206,10 @@ class SegmentBuffer:
def stream_worker( # noqa: C901 def stream_worker( # noqa: C901
source: str, options: dict, segment_buffer: SegmentBuffer, quit_event: Event source: str,
options: dict[str, str],
segment_buffer: SegmentBuffer,
quit_event: Event,
) -> None: ) -> None:
"""Handle consuming streams.""" """Handle consuming streams."""
@ -259,7 +262,7 @@ def stream_worker( # noqa: C901
found_audio = False found_audio = False
try: try:
container_packets = container.demux((video_stream, audio_stream)) container_packets = container.demux((video_stream, audio_stream))
first_packet = None first_packet: av.Packet | None = None
# Get to first video keyframe # Get to first video keyframe
while first_packet is None: while first_packet is None:
packet = next(container_packets) packet = next(container_packets)
@ -315,7 +318,6 @@ def stream_worker( # noqa: C901
_LOGGER.warning( _LOGGER.warning(
"Audio stream not found" "Audio stream not found"
) # Some streams declare an audio stream and never send any packets ) # Some streams declare an audio stream and never send any packets
audio_stream = None
except (av.AVError, StopIteration) as ex: except (av.AVError, StopIteration) as ex:
_LOGGER.error( _LOGGER.error(

View File

@ -726,6 +726,17 @@ no_implicit_optional = true
warn_return_any = true warn_return_any = true
warn_unreachable = true warn_unreachable = true
[mypy-homeassistant.components.stream.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.sun.*] [mypy-homeassistant.components.sun.*]
check_untyped_defs = true check_untyped_defs = true
disallow_incomplete_defs = true disallow_incomplete_defs = true

View File

@ -107,7 +107,7 @@ async def test_hls_stream(hass, hls_stream, stream_worker_sync):
# Setup demo HLS track # Setup demo HLS track
source = generate_h264_video() source = generate_h264_video()
stream = create_stream(hass, source) stream = create_stream(hass, source, {})
# Request stream # Request stream
stream.add_provider(HLS_PROVIDER) stream.add_provider(HLS_PROVIDER)
@ -148,7 +148,7 @@ async def test_stream_timeout(hass, hass_client, stream_worker_sync):
# Setup demo HLS track # Setup demo HLS track
source = generate_h264_video() source = generate_h264_video()
stream = create_stream(hass, source) stream = create_stream(hass, source, {})
# Request stream # Request stream
stream.add_provider(HLS_PROVIDER) stream.add_provider(HLS_PROVIDER)
@ -190,7 +190,7 @@ async def test_stream_timeout_after_stop(hass, hass_client, stream_worker_sync):
# Setup demo HLS track # Setup demo HLS track
source = generate_h264_video() source = generate_h264_video()
stream = create_stream(hass, source) stream = create_stream(hass, source, {})
# Request stream # Request stream
stream.add_provider(HLS_PROVIDER) stream.add_provider(HLS_PROVIDER)
@ -212,7 +212,7 @@ async def test_stream_keepalive(hass):
# Setup demo HLS track # Setup demo HLS track
source = "test_stream_keepalive_source" source = "test_stream_keepalive_source"
stream = create_stream(hass, source) stream = create_stream(hass, source, {})
track = stream.add_provider(HLS_PROVIDER) track = stream.add_provider(HLS_PROVIDER)
track.num_segments = 2 track.num_segments = 2
@ -247,7 +247,7 @@ async def test_hls_playlist_view_no_output(hass, hass_client, hls_stream):
"""Test rendering the hls playlist with no output segments.""" """Test rendering the hls playlist with no output segments."""
await async_setup_component(hass, "stream", {"stream": {}}) await async_setup_component(hass, "stream", {"stream": {}})
stream = create_stream(hass, STREAM_SOURCE) stream = create_stream(hass, STREAM_SOURCE, {})
stream.add_provider(HLS_PROVIDER) stream.add_provider(HLS_PROVIDER)
hls_client = await hls_stream(stream) hls_client = await hls_stream(stream)
@ -261,7 +261,7 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync):
"""Test rendering the hls playlist with 1 and 2 output segments.""" """Test rendering the hls playlist with 1 and 2 output segments."""
await async_setup_component(hass, "stream", {"stream": {}}) await async_setup_component(hass, "stream", {"stream": {}})
stream = create_stream(hass, STREAM_SOURCE) stream = create_stream(hass, STREAM_SOURCE, {})
stream_worker_sync.pause() stream_worker_sync.pause()
hls = stream.add_provider(HLS_PROVIDER) hls = stream.add_provider(HLS_PROVIDER)
@ -295,7 +295,7 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync):
"""Test rendering the hls playlist with more segments than the segment deque can hold.""" """Test rendering the hls playlist with more segments than the segment deque can hold."""
await async_setup_component(hass, "stream", {"stream": {}}) await async_setup_component(hass, "stream", {"stream": {}})
stream = create_stream(hass, STREAM_SOURCE) stream = create_stream(hass, STREAM_SOURCE, {})
stream_worker_sync.pause() stream_worker_sync.pause()
hls = stream.add_provider(HLS_PROVIDER) hls = stream.add_provider(HLS_PROVIDER)
@ -347,7 +347,7 @@ async def test_hls_playlist_view_discontinuity(hass, hls_stream, stream_worker_s
"""Test a discontinuity across segments in the stream with 3 segments.""" """Test a discontinuity across segments in the stream with 3 segments."""
await async_setup_component(hass, "stream", {"stream": {}}) await async_setup_component(hass, "stream", {"stream": {}})
stream = create_stream(hass, STREAM_SOURCE) stream = create_stream(hass, STREAM_SOURCE, {})
stream_worker_sync.pause() stream_worker_sync.pause()
hls = stream.add_provider(HLS_PROVIDER) hls = stream.add_provider(HLS_PROVIDER)
@ -389,7 +389,7 @@ async def test_hls_max_segments_discontinuity(hass, hls_stream, stream_worker_sy
"""Test a discontinuity with more segments than the segment deque can hold.""" """Test a discontinuity with more segments than the segment deque can hold."""
await async_setup_component(hass, "stream", {"stream": {}}) await async_setup_component(hass, "stream", {"stream": {}})
stream = create_stream(hass, STREAM_SOURCE) stream = create_stream(hass, STREAM_SOURCE, {})
stream_worker_sync.pause() stream_worker_sync.pause()
hls = stream.add_provider(HLS_PROVIDER) hls = stream.add_provider(HLS_PROVIDER)

View File

@ -34,7 +34,7 @@ async def test_record_stream(hass, hass_client, record_worker_sync):
# Setup demo track # Setup demo track
source = generate_h264_video() source = generate_h264_video()
stream = create_stream(hass, source) stream = create_stream(hass, source, {})
with patch.object(hass.config, "is_allowed_path", return_value=True): with patch.object(hass.config, "is_allowed_path", return_value=True):
await stream.async_record("/example/path") await stream.async_record("/example/path")
@ -56,7 +56,7 @@ async def test_record_lookback(
await async_setup_component(hass, "stream", {"stream": {}}) await async_setup_component(hass, "stream", {"stream": {}})
source = generate_h264_video() source = generate_h264_video()
stream = create_stream(hass, source) stream = create_stream(hass, source, {})
# Start an HLS feed to enable lookback # Start an HLS feed to enable lookback
stream.add_provider(HLS_PROVIDER) stream.add_provider(HLS_PROVIDER)
@ -85,7 +85,7 @@ async def test_recorder_timeout(hass, hass_client, stream_worker_sync):
# Setup demo track # Setup demo track
source = generate_h264_video() source = generate_h264_video()
stream = create_stream(hass, source) stream = create_stream(hass, source, {})
with patch.object(hass.config, "is_allowed_path", return_value=True): with patch.object(hass.config, "is_allowed_path", return_value=True):
await stream.async_record("/example/path") await stream.async_record("/example/path")
recorder = stream.add_provider(RECORDER_PROVIDER) recorder = stream.add_provider(RECORDER_PROVIDER)
@ -111,7 +111,7 @@ async def test_record_path_not_allowed(hass, hass_client):
# Setup demo track # Setup demo track
source = generate_h264_video() source = generate_h264_video()
stream = create_stream(hass, source) stream = create_stream(hass, source, {})
with patch.object( with patch.object(
hass.config, "is_allowed_path", return_value=False hass.config, "is_allowed_path", return_value=False
), pytest.raises(HomeAssistantError): ), pytest.raises(HomeAssistantError):
@ -203,7 +203,7 @@ async def test_record_stream_audio(
source = generate_h264_video( source = generate_h264_video(
container_format="mov", audio_codec=a_codec container_format="mov", audio_codec=a_codec
) # mov can store PCM ) # mov can store PCM
stream = create_stream(hass, source) stream = create_stream(hass, source, {})
with patch.object(hass.config, "is_allowed_path", return_value=True): with patch.object(hass.config, "is_allowed_path", return_value=True):
await stream.async_record("/example/path") await stream.async_record("/example/path")
recorder = stream.add_provider(RECORDER_PROVIDER) recorder = stream.add_provider(RECORDER_PROVIDER)
@ -234,7 +234,7 @@ async def test_record_stream_audio(
async def test_recorder_log(hass, caplog): async def test_recorder_log(hass, caplog):
"""Test starting a stream to record logs the url without username and password.""" """Test starting a stream to record logs the url without username and password."""
await async_setup_component(hass, "stream", {"stream": {}}) await async_setup_component(hass, "stream", {"stream": {}})
stream = create_stream(hass, "https://abcd:efgh@foo.bar") stream = create_stream(hass, "https://abcd:efgh@foo.bar", {})
with patch.object(hass.config, "is_allowed_path", return_value=True): with patch.object(hass.config, "is_allowed_path", return_value=True):
await stream.async_record("/example/path") await stream.async_record("/example/path")
assert "https://abcd:efgh@foo.bar" not in caplog.text assert "https://abcd:efgh@foo.bar" not in caplog.text

View File

@ -220,7 +220,7 @@ class MockFlushPart:
async def async_decode_stream(hass, packets, py_av=None): async def async_decode_stream(hass, packets, py_av=None):
"""Start a stream worker that decodes incoming stream packets into output segments.""" """Start a stream worker that decodes incoming stream packets into output segments."""
stream = Stream(hass, STREAM_SOURCE) stream = Stream(hass, STREAM_SOURCE, {})
stream.add_provider(HLS_PROVIDER) stream.add_provider(HLS_PROVIDER)
if not py_av: if not py_av:
@ -244,7 +244,7 @@ async def async_decode_stream(hass, packets, py_av=None):
async def test_stream_open_fails(hass): async def test_stream_open_fails(hass):
"""Test failure on stream open.""" """Test failure on stream open."""
stream = Stream(hass, STREAM_SOURCE) stream = Stream(hass, STREAM_SOURCE, {})
stream.add_provider(HLS_PROVIDER) stream.add_provider(HLS_PROVIDER)
with patch("av.open") as av_open: with patch("av.open") as av_open:
av_open.side_effect = av.error.InvalidDataError(-2, "error") av_open.side_effect = av.error.InvalidDataError(-2, "error")
@ -565,7 +565,7 @@ async def test_stream_stopped_while_decoding(hass):
worker_open = threading.Event() worker_open = threading.Event()
worker_wake = threading.Event() worker_wake = threading.Event()
stream = Stream(hass, STREAM_SOURCE) stream = Stream(hass, STREAM_SOURCE, {})
stream.add_provider(HLS_PROVIDER) stream.add_provider(HLS_PROVIDER)
py_av = MockPyAv() py_av = MockPyAv()
@ -592,7 +592,7 @@ async def test_update_stream_source(hass):
worker_open = threading.Event() worker_open = threading.Event()
worker_wake = threading.Event() worker_wake = threading.Event()
stream = Stream(hass, STREAM_SOURCE) stream = Stream(hass, STREAM_SOURCE, {})
stream.add_provider(HLS_PROVIDER) stream.add_provider(HLS_PROVIDER)
# Note that keepalive is not set here. The stream is "restarted" even though # Note that keepalive is not set here. The stream is "restarted" even though
# it is not stopping due to failure. # it is not stopping due to failure.
@ -636,7 +636,7 @@ async def test_update_stream_source(hass):
async def test_worker_log(hass, caplog): async def test_worker_log(hass, caplog):
"""Test that the worker logs the url without username and password.""" """Test that the worker logs the url without username and password."""
stream = Stream(hass, "https://abcd:efgh@foo.bar") stream = Stream(hass, "https://abcd:efgh@foo.bar", {})
stream.add_provider(HLS_PROVIDER) stream.add_provider(HLS_PROVIDER)
with patch("av.open") as av_open: with patch("av.open") as av_open:
av_open.side_effect = av.error.InvalidDataError(-2, "error") av_open.side_effect = av.error.InvalidDataError(-2, "error")
@ -654,7 +654,7 @@ async def test_durations(hass, record_worker_sync):
await async_setup_component(hass, "stream", {"stream": {}}) await async_setup_component(hass, "stream", {"stream": {}})
source = generate_h264_video() source = generate_h264_video()
stream = create_stream(hass, source) stream = create_stream(hass, source, {})
# use record_worker_sync to grab output segments # use record_worker_sync to grab output segments
with patch.object(hass.config, "is_allowed_path", return_value=True): with patch.object(hass.config, "is_allowed_path", return_value=True):
@ -693,7 +693,7 @@ async def test_has_keyframe(hass, record_worker_sync):
await async_setup_component(hass, "stream", {"stream": {}}) await async_setup_component(hass, "stream", {"stream": {}})
source = generate_h264_video() source = generate_h264_video()
stream = create_stream(hass, source) stream = create_stream(hass, source, {})
# use record_worker_sync to grab output segments # use record_worker_sync to grab output segments
with patch.object(hass.config, "is_allowed_path", return_value=True): with patch.object(hass.config, "is_allowed_path", return_value=True):