Disconnect before reconnecting to satellite (#105500)

Disconnect before reconnecting
This commit is contained in:
Michael Hansen 2023-12-11 10:18:46 -06:00 committed by GitHub
parent b71f488d3e
commit 80607f7750
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 8 deletions

View File

@ -71,11 +71,11 @@ class WyomingSatellite:
while self.is_running: while self.is_running:
try: try:
# Check if satellite has been disabled # Check if satellite has been disabled
if not self.device.is_enabled: while not self.device.is_enabled:
await self.on_disabled() await self.on_disabled()
if not self.is_running: if not self.is_running:
# Satellite was stopped while waiting to be enabled # Satellite was stopped while waiting to be enabled
break return
# Connect and run pipeline loop # Connect and run pipeline loop
await self._run_once() await self._run_once()
@ -130,6 +130,7 @@ class WyomingSatellite:
self._audio_queue.put_nowait(None) self._audio_queue.put_nowait(None)
self._enabled_changed_event.set() self._enabled_changed_event.set()
self._enabled_changed_event.clear()
def _pipeline_changed(self) -> None: def _pipeline_changed(self) -> None:
"""Run when device pipeline changes.""" """Run when device pipeline changes."""
@ -255,9 +256,17 @@ class WyomingSatellite:
chunk = AudioChunk.from_event(client_event) chunk = AudioChunk.from_event(client_event)
chunk = self._chunk_converter.convert(chunk) chunk = self._chunk_converter.convert(chunk)
self._audio_queue.put_nowait(chunk.audio) self._audio_queue.put_nowait(chunk.audio)
elif AudioStop.is_type(client_event.type):
# Stop pipeline
_LOGGER.debug("Client requested pipeline to stop")
self._audio_queue.put_nowait(b"")
break
else: else:
_LOGGER.debug("Unexpected event from satellite: %s", client_event) _LOGGER.debug("Unexpected event from satellite: %s", client_event)
# Ensure task finishes
await _pipeline_task
_LOGGER.debug("Pipeline finished") _LOGGER.debug("Pipeline finished")
def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None: def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None:
@ -348,12 +357,23 @@ class WyomingSatellite:
async def _connect(self) -> None: async def _connect(self) -> None:
"""Connect to satellite over TCP.""" """Connect to satellite over TCP."""
await self._disconnect()
_LOGGER.debug( _LOGGER.debug(
"Connecting to satellite at %s:%s", self.service.host, self.service.port "Connecting to satellite at %s:%s", self.service.host, self.service.port
) )
self._client = AsyncTcpClient(self.service.host, self.service.port) self._client = AsyncTcpClient(self.service.host, self.service.port)
await self._client.connect() await self._client.connect()
async def _disconnect(self) -> None:
"""Disconnect if satellite is currently connected."""
if self._client is None:
return
_LOGGER.debug("Disconnecting from satellite")
await self._client.disconnect()
self._client = None
async def _stream_tts(self, media_id: str) -> None: async def _stream_tts(self, media_id: str) -> None:
"""Stream TTS WAV audio to satellite in chunks.""" """Stream TTS WAV audio to satellite in chunks."""
assert self._client is not None assert self._client is not None

View File

@ -322,11 +322,12 @@ async def test_satellite_disabled(hass: HomeAssistant) -> None:
hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService
): ):
satellite = original_make_satellite(hass, config_entry, service) satellite = original_make_satellite(hass, config_entry, service)
satellite.device.is_enabled = False satellite.device.set_is_enabled(False)
return satellite return satellite
async def on_disabled(self): async def on_disabled(self):
self.device.set_is_enabled(True)
on_disabled_event.set() on_disabled_event.set()
with patch( with patch(
@ -368,11 +369,19 @@ async def test_satellite_restart(hass: HomeAssistant) -> None:
async def test_satellite_reconnect(hass: HomeAssistant) -> None: async def test_satellite_reconnect(hass: HomeAssistant) -> None:
"""Test satellite reconnect call after connection refused.""" """Test satellite reconnect call after connection refused."""
on_reconnect_event = asyncio.Event() num_reconnects = 0
reconnect_event = asyncio.Event()
stopped_event = asyncio.Event()
async def on_reconnect(self): async def on_reconnect(self):
nonlocal num_reconnects
num_reconnects += 1
if num_reconnects >= 2:
reconnect_event.set()
self.stop() self.stop()
on_reconnect_event.set()
async def on_stopped(self):
stopped_event.set()
with patch( with patch(
"homeassistant.components.wyoming.data.load_wyoming_info", "homeassistant.components.wyoming.data.load_wyoming_info",
@ -383,10 +392,14 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None:
), patch( ), patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect", "homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect",
on_reconnect, on_reconnect,
), patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
on_stopped,
): ):
await setup_config_entry(hass) await setup_config_entry(hass)
async with asyncio.timeout(1): async with asyncio.timeout(1):
await on_reconnect_event.wait() await reconnect_event.wait()
await stopped_event.wait()
async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None: async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None: