diff --git a/homeassistant/components/gpsd/__init__.py b/homeassistant/components/gpsd/__init__.py index a0e3db2e404..0550148d2a7 100644 --- a/homeassistant/components/gpsd/__init__.py +++ b/homeassistant/components/gpsd/__init__.py @@ -2,19 +2,45 @@ from __future__ import annotations +from gps3.agps3threaded import AGPS3mechanism + from homeassistant.config_entries import ConfigEntry -from homeassistant.const import Platform +from homeassistant.const import CONF_HOST, CONF_PORT, Platform from homeassistant.core import HomeAssistant PLATFORMS: list[Platform] = [Platform.SENSOR] +type GPSDConfigEntry = ConfigEntry[AGPS3mechanism] -async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + +async def async_setup_entry(hass: HomeAssistant, entry: GPSDConfigEntry) -> bool: """Set up GPSD from a config entry.""" + agps_thread = AGPS3mechanism() + entry.runtime_data = agps_thread + + def setup_agps() -> None: + host = entry.data.get(CONF_HOST) + port = entry.data.get(CONF_PORT) + agps_thread.stream_data(host, port) + agps_thread.run_thread() + + await hass.async_add_executor_job(setup_agps) + await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) + return True -async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: +async def async_unload_entry(hass: HomeAssistant, entry: GPSDConfigEntry) -> bool: """Unload a config entry.""" - return await hass.config_entries.async_unload_platforms(entry, PLATFORMS) + if unload_ok := await hass.config_entries.async_unload_platforms(entry, PLATFORMS): + agps_thread = entry.runtime_data + await hass.async_add_executor_job( + lambda: agps_thread.stream_data( + host=entry.data.get(CONF_HOST), + port=entry.data.get(CONF_PORT), + enable=False, + ) + ) + + return unload_ok diff --git a/homeassistant/components/gpsd/config_flow.py b/homeassistant/components/gpsd/config_flow.py index 10fb8a3a252..59c95d0ddbf 100644 --- a/homeassistant/components/gpsd/config_flow.py +++ b/homeassistant/components/gpsd/config_flow.py @@ -27,6 +27,18 @@ class GPSDConfigFlow(ConfigFlow, domain=DOMAIN): VERSION = 1 + @staticmethod + def test_connection(host: str, port: int) -> bool: + """Test socket connection.""" + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.connect((host, port)) + sock.shutdown(2) + except OSError: + return False + else: + return True + async def async_step_import(self, import_data: dict[str, Any]) -> ConfigFlowResult: """Import a config entry from configuration.yaml.""" return await self.async_step_user(import_data) @@ -38,11 +50,11 @@ class GPSDConfigFlow(ConfigFlow, domain=DOMAIN): if user_input is not None: self._async_abort_entries_match(user_input) - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - try: - sock.connect((user_input[CONF_HOST], user_input[CONF_PORT])) - sock.shutdown(2) - except OSError: + connected = await self.hass.async_add_executor_job( + self.test_connection, user_input[CONF_HOST], user_input[CONF_PORT] + ) + + if not connected: return self.async_abort(reason="cannot_connect") port = "" diff --git a/homeassistant/components/gpsd/sensor.py b/homeassistant/components/gpsd/sensor.py index 5a978f9f66e..e67287ae134 100644 --- a/homeassistant/components/gpsd/sensor.py +++ b/homeassistant/components/gpsd/sensor.py @@ -20,7 +20,7 @@ from homeassistant.components.sensor import ( SensorEntity, SensorEntityDescription, ) -from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry +from homeassistant.config_entries import SOURCE_IMPORT from homeassistant.const import ( ATTR_LATITUDE, ATTR_LONGITUDE, @@ -37,6 +37,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType +from . import GPSDConfigEntry from .const import DOMAIN _LOGGER = logging.getLogger(__name__) @@ -81,15 +82,14 @@ PLATFORM_SCHEMA = SENSOR_PLATFORM_SCHEMA.extend( async def async_setup_entry( hass: HomeAssistant, - config_entry: ConfigEntry, + config_entry: GPSDConfigEntry, async_add_entities: AddEntitiesCallback, ) -> None: """Set up the GPSD component.""" async_add_entities( [ GpsdSensor( - config_entry.data[CONF_HOST], - config_entry.data[CONF_PORT], + config_entry.runtime_data, config_entry.entry_id, description, ) @@ -135,8 +135,7 @@ class GpsdSensor(SensorEntity): def __init__( self, - host: str, - port: int, + agps_thread: AGPS3mechanism, unique_id: str, description: GpsdSensorDescription, ) -> None: @@ -148,9 +147,7 @@ class GpsdSensor(SensorEntity): ) self._attr_unique_id = f"{unique_id}-{self.entity_description.key}" - self.agps_thread = AGPS3mechanism() - self.agps_thread.stream_data(host=host, port=port) - self.agps_thread.run_thread() + self.agps_thread = agps_thread @property def native_value(self) -> str | None: diff --git a/tests/components/gpsd/test_config_flow.py b/tests/components/gpsd/test_config_flow.py index 6f330571076..2d68a704119 100644 --- a/tests/components/gpsd/test_config_flow.py +++ b/tests/components/gpsd/test_config_flow.py @@ -43,10 +43,7 @@ async def test_form(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None: async def test_connection_error(hass: HomeAssistant) -> None: """Test connection to host error.""" - with patch("socket.socket") as mock_socket: - mock_connect = mock_socket.return_value.connect - mock_connect.side_effect = OSError - + with patch("socket.socket", side_effect=OSError): result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER},