diff --git a/homeassistant/components/mazda/__init__.py b/homeassistant/components/mazda/__init__.py index 469e28eb829..d704cfb7f44 100644 --- a/homeassistant/components/mazda/__init__.py +++ b/homeassistant/components/mazda/__init__.py @@ -69,16 +69,16 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Handle a service call.""" # Get device entry from device registry dev_reg = device_registry.async_get(hass) - device_id = service_call.data.get("device_id") + device_id = service_call.data["device_id"] device_entry = dev_reg.async_get(device_id) # Get vehicle VIN from device identifiers - mazda_identifiers = [ + mazda_identifiers = ( identifier for identifier in device_entry.identifiers if identifier[0] == DOMAIN - ] - vin_identifier = next(iter(mazda_identifiers)) + ) + vin_identifier = next(mazda_identifiers) vin = vin_identifier[1] # Get vehicle ID and API client from hass.data @@ -89,6 +89,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: if vehicle["vin"] == vin: vehicle_id = vehicle["id"] api_client = entry_data[DATA_CLIENT] + break if vehicle_id == 0 or api_client is None: raise HomeAssistantError("Vehicle ID not found") @@ -96,14 +97,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: api_method = getattr(api_client, service_call.service) try: if service_call.service == "send_poi": - latitude = service_call.data.get("latitude") - longitude = service_call.data.get("longitude") - poi_name = service_call.data.get("poi_name") + latitude = service_call.data["latitude"] + longitude = service_call.data["longitude"] + poi_name = service_call.data["poi_name"] await api_method(vehicle_id, latitude, longitude, poi_name) else: await api_method(vehicle_id) except Exception as ex: - _LOGGER.exception("Error occurred during Mazda service call: %s", ex) raise HomeAssistantError(ex) from ex def validate_mazda_device_id(device_id): @@ -119,7 +119,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: for identifier in device_entry.identifiers if identifier[0] == DOMAIN ] - if len(mazda_identifiers) < 1: + if not mazda_identifiers: raise vol.Invalid("Device ID is not a Mazda vehicle") return device_id diff --git a/tests/components/mazda/test_init.py b/tests/components/mazda/test_init.py index 0280e8f34fa..0c47ae8f2e0 100644 --- a/tests/components/mazda/test_init.py +++ b/tests/components/mazda/test_init.py @@ -7,7 +7,7 @@ from pymazda import MazdaAuthenticationException, MazdaException import pytest import voluptuous as vol -from homeassistant.components.mazda.const import DOMAIN, SERVICES +from homeassistant.components.mazda.const import DOMAIN from homeassistant.config_entries import ConfigEntryState from homeassistant.const import ( CONF_EMAIL, @@ -186,7 +186,23 @@ async def test_device_no_nickname(hass): assert reg_device.name == "2021 MAZDA3 2.5 S SE AWD" -async def test_services(hass): +@pytest.mark.parametrize( + "service, service_data, expected_args", + [ + ("start_charging", {}, [12345]), + ("start_engine", {}, [12345]), + ("stop_charging", {}, [12345]), + ("stop_engine", {}, [12345]), + ("turn_off_hazard_lights", {}, [12345]), + ("turn_on_hazard_lights", {}, [12345]), + ( + "send_poi", + {"latitude": 1.2345, "longitude": 2.3456, "poi_name": "Work"}, + [12345, 1.2345, 2.3456, "Work"], + ), + ], +) +async def test_services(hass, service, service_data, expected_args): """Test service calls.""" client_mock = await init_integration(hass) @@ -196,21 +212,13 @@ async def test_services(hass): ) device_id = reg_device.id - for service in SERVICES: - service_data = {"device_id": device_id} - if service == "send_poi": - service_data["latitude"] = 1.2345 - service_data["longitude"] = 2.3456 - service_data["poi_name"] = "Work" + service_data["device_id"] = device_id - await hass.services.async_call(DOMAIN, service, service_data, blocking=True) - await hass.async_block_till_done() + await hass.services.async_call(DOMAIN, service, service_data, blocking=True) + await hass.async_block_till_done() - api_method = getattr(client_mock, service) - if service == "send_poi": - api_method.assert_called_once_with(12345, 1.2345, 2.3456, "Work") - else: - api_method.assert_called_once_with(12345) + api_method = getattr(client_mock, service) + api_method.assert_called_once_with(*expected_args) async def test_service_invalid_device_id(hass):