Convert gather calls into TaskGroups (#109010)

This commit is contained in:
Josh Pettersen 2024-01-30 14:38:45 -08:00 committed by GitHub
parent bea7dd756a
commit b6126e7821
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -222,28 +222,31 @@ async def _login_and_fetch_base_info(
async def _call_base_info(power_wall: Powerwall, host: str) -> PowerwallBaseInfo: async def _call_base_info(power_wall: Powerwall, host: str) -> PowerwallBaseInfo:
"""Return PowerwallBaseInfo for the device.""" """Return PowerwallBaseInfo for the device."""
( try:
gateway_din, async with asyncio.TaskGroup() as tg:
site_info, gateway_din = tg.create_task(power_wall.get_gateway_din())
status, site_info = tg.create_task(power_wall.get_site_info())
device_type, status = tg.create_task(power_wall.get_status())
serial_numbers, device_type = tg.create_task(power_wall.get_device_type())
) = await asyncio.gather( serial_numbers = tg.create_task(power_wall.get_serial_numbers())
power_wall.get_gateway_din(),
power_wall.get_site_info(), # Mimic the behavior of asyncio.gather by reraising the first caught exception since
power_wall.get_status(), # this is what is expected by the caller of this method
power_wall.get_device_type(), #
power_wall.get_serial_numbers(), # While it would have been cleaner to use asyncio.gather in the first place instead of
) # TaskGroup but in cases where you have more than 6 tasks, the linter fails due to
# missing typing information.
except BaseExceptionGroup as e:
raise e.exceptions[0] from None
# Serial numbers MUST be sorted to ensure the unique_id is always the same # Serial numbers MUST be sorted to ensure the unique_id is always the same
# for backwards compatibility. # for backwards compatibility.
return PowerwallBaseInfo( return PowerwallBaseInfo(
gateway_din=gateway_din.upper(), gateway_din=gateway_din.result().upper(),
site_info=site_info, site_info=site_info.result(),
status=status, status=status.result(),
device_type=device_type, device_type=device_type.result(),
serial_numbers=sorted(serial_numbers), serial_numbers=sorted(serial_numbers.result()),
url=f"https://{host}", url=f"https://{host}",
) )
@ -258,29 +261,32 @@ async def get_backup_reserve_percentage(power_wall: Powerwall) -> Optional[float
async def _fetch_powerwall_data(power_wall: Powerwall) -> PowerwallData: async def _fetch_powerwall_data(power_wall: Powerwall) -> PowerwallData:
"""Process and update powerwall data.""" """Process and update powerwall data."""
(
backup_reserve, try:
charge, async with asyncio.TaskGroup() as tg:
site_master, backup_reserve = tg.create_task(get_backup_reserve_percentage(power_wall))
meters, charge = tg.create_task(power_wall.get_charge())
grid_services_active, site_master = tg.create_task(power_wall.get_sitemaster())
grid_status, meters = tg.create_task(power_wall.get_meters())
) = await asyncio.gather( grid_services_active = tg.create_task(power_wall.is_grid_services_active())
get_backup_reserve_percentage(power_wall), grid_status = tg.create_task(power_wall.get_grid_status())
power_wall.get_charge(),
power_wall.get_sitemaster(), # Mimic the behavior of asyncio.gather by reraising the first caught exception since
power_wall.get_meters(), # this is what is expected by the caller of this method
power_wall.is_grid_services_active(), #
power_wall.get_grid_status(), # While it would have been cleaner to use asyncio.gather in the first place instead of
) # TaskGroup but in cases where you have more than 6 tasks, the linter fails due to
# missing typing information.
except BaseExceptionGroup as e:
raise e.exceptions[0] from None
return PowerwallData( return PowerwallData(
charge=charge, charge=charge.result(),
site_master=site_master, site_master=site_master.result(),
meters=meters, meters=meters.result(),
grid_services_active=grid_services_active, grid_services_active=grid_services_active.result(),
grid_status=grid_status, grid_status=grid_status.result(),
backup_reserve=backup_reserve, backup_reserve=backup_reserve.result(),
) )