From b6126e78211632f6c24c446680fdfa88f82521f3 Mon Sep 17 00:00:00 2001 From: Josh Pettersen <12600312+bubonicbob@users.noreply.github.com> Date: Tue, 30 Jan 2024 14:38:45 -0800 Subject: [PATCH] Convert gather calls into TaskGroups (#109010) --- .../components/powerwall/__init__.py | 84 ++++++++++--------- 1 file changed, 45 insertions(+), 39 deletions(-) diff --git a/homeassistant/components/powerwall/__init__.py b/homeassistant/components/powerwall/__init__.py index 29e890e6027..79e612deb4c 100644 --- a/homeassistant/components/powerwall/__init__.py +++ b/homeassistant/components/powerwall/__init__.py @@ -222,28 +222,31 @@ async def _login_and_fetch_base_info( async def _call_base_info(power_wall: Powerwall, host: str) -> PowerwallBaseInfo: """Return PowerwallBaseInfo for the device.""" - ( - gateway_din, - site_info, - status, - device_type, - serial_numbers, - ) = await asyncio.gather( - power_wall.get_gateway_din(), - power_wall.get_site_info(), - power_wall.get_status(), - power_wall.get_device_type(), - power_wall.get_serial_numbers(), - ) + try: + async with asyncio.TaskGroup() as tg: + gateway_din = tg.create_task(power_wall.get_gateway_din()) + site_info = tg.create_task(power_wall.get_site_info()) + status = tg.create_task(power_wall.get_status()) + device_type = tg.create_task(power_wall.get_device_type()) + serial_numbers = tg.create_task(power_wall.get_serial_numbers()) + + # Mimic the behavior of asyncio.gather by reraising the first caught exception since + # this is what is expected by the caller of this method + # + # While it would have been cleaner to use asyncio.gather in the first place instead of + # TaskGroup but in cases where you have more than 6 tasks, the linter fails due to + # missing typing information. + except BaseExceptionGroup as e: + raise e.exceptions[0] from None # Serial numbers MUST be sorted to ensure the unique_id is always the same # for backwards compatibility. return PowerwallBaseInfo( - gateway_din=gateway_din.upper(), - site_info=site_info, - status=status, - device_type=device_type, - serial_numbers=sorted(serial_numbers), + gateway_din=gateway_din.result().upper(), + site_info=site_info.result(), + status=status.result(), + device_type=device_type.result(), + serial_numbers=sorted(serial_numbers.result()), url=f"https://{host}", ) @@ -258,29 +261,32 @@ async def get_backup_reserve_percentage(power_wall: Powerwall) -> Optional[float async def _fetch_powerwall_data(power_wall: Powerwall) -> PowerwallData: """Process and update powerwall data.""" - ( - backup_reserve, - charge, - site_master, - meters, - grid_services_active, - grid_status, - ) = await asyncio.gather( - get_backup_reserve_percentage(power_wall), - power_wall.get_charge(), - power_wall.get_sitemaster(), - power_wall.get_meters(), - power_wall.is_grid_services_active(), - power_wall.get_grid_status(), - ) + + try: + async with asyncio.TaskGroup() as tg: + backup_reserve = tg.create_task(get_backup_reserve_percentage(power_wall)) + charge = tg.create_task(power_wall.get_charge()) + site_master = tg.create_task(power_wall.get_sitemaster()) + meters = tg.create_task(power_wall.get_meters()) + grid_services_active = tg.create_task(power_wall.is_grid_services_active()) + grid_status = tg.create_task(power_wall.get_grid_status()) + + # Mimic the behavior of asyncio.gather by reraising the first caught exception since + # this is what is expected by the caller of this method + # + # While it would have been cleaner to use asyncio.gather in the first place instead of + # TaskGroup but in cases where you have more than 6 tasks, the linter fails due to + # missing typing information. + except BaseExceptionGroup as e: + raise e.exceptions[0] from None return PowerwallData( - charge=charge, - site_master=site_master, - meters=meters, - grid_services_active=grid_services_active, - grid_status=grid_status, - backup_reserve=backup_reserve, + charge=charge.result(), + site_master=site_master.result(), + meters=meters.result(), + grid_services_active=grid_services_active.result(), + grid_status=grid_status.result(), + backup_reserve=backup_reserve.result(), )