"""NextBus data update coordinator."""
from datetime import timedelta
import logging
from typing import Any, cast

from py_nextbus import NextBusClient
from py_nextbus.client import NextBusFormatError, NextBusHTTPError, RouteStop

from homeassistant.core import HomeAssistant
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed

from .const import DOMAIN
from .util import listify

_LOGGER = logging.getLogger(__name__)


class NextBusDataUpdateCoordinator(DataUpdateCoordinator):
    """Class to manage fetching NextBus data.

    One coordinator is created per agency. Callers register the (stop, route)
    pairs they are interested in; each refresh requests predictions for all
    registered pairs in a single API call and indexes the results by
    ``RouteStop`` for lookup via ``get_prediction_data``.
    """

    def __init__(self, hass: HomeAssistant, agency: str) -> None:
        """Initialize a global coordinator for fetching data for a given agency.

        Args:
            hass: The Home Assistant instance the coordinator is bound to.
            agency: Agency identifier passed to the NextBus client.
        """
        super().__init__(
            hass,
            _LOGGER,
            name=DOMAIN,
            update_interval=timedelta(seconds=30),
        )
        self.client = NextBusClient(output_format="json", agency=agency)
        self._agency = agency
        # Route/stop pairs requested on every refresh.
        self._stop_routes: set[RouteStop] = set()
        # Latest prediction payload per route/stop pair; rebuilt on each refresh.
        self._predictions: dict[RouteStop, dict[str, Any]] = {}

    def add_stop_route(self, stop_tag: str, route_tag: str) -> None:
        """Tell coordinator to start tracking a given stop and route."""
        self._stop_routes.add(RouteStop(route_tag, stop_tag))

    def remove_stop_route(self, stop_tag: str, route_tag: str) -> None:
        """Tell coordinator to stop tracking a given stop and route.

        Removing a pair that is not tracked is a no-op, so removal is safe to
        call more than once for the same pair.
        """
        # discard() rather than remove(): the set holds a pair only once even
        # if it was added several times, so a repeated removal must not raise.
        self._stop_routes.discard(RouteStop(route_tag, stop_tag))

    def get_prediction_data(
        self, stop_tag: str, route_tag: str
    ) -> dict[str, Any] | None:
        """Get prediction result for a given stop and route.

        Returns:
            The prediction entry from the most recent refresh, or ``None`` if
            that refresh contained no entry for the pair.
        """
        return self._predictions.get(RouteStop(route_tag, stop_tag))

    def _calc_predictions(self, data: dict[str, Any]) -> None:
        """Rebuild the route/stop prediction index from an API response.

        Entries are read from ``data["predictions"]`` (normalized through
        ``listify``) and keyed by their ``routeTag`` and ``stopTag`` values.
        The previous index is replaced, not merged.
        """
        self._predictions = {
            RouteStop(prediction["routeTag"], prediction["stopTag"]): prediction
            for prediction in listify(data.get("predictions", []))
        }

    def get_attribution(self) -> str | None:
        """Get attribution from api results.

        Returns:
            The ``copyright`` value of the last fetched payload, or ``None``
            when no payload is available yet or it carries no such key.
        """
        # self.data holds no payload until a refresh has succeeded.
        if not self.data:
            return None
        return self.data.get("copyright")

    def has_routes(self) -> bool:
        """Check if this coordinator is tracking any routes."""
        return bool(self._stop_routes)

    async def _async_update_data(self) -> dict[str, Any]:
        """Fetch data from NextBus.

        Returns:
            The API response for all tracked route/stop pairs.

        Raises:
            UpdateFailed: If the client raises ``NextBusHTTPError`` or
                ``NextBusFormatError``.
        """
        # Snapshot the tracked pairs: the request runs as an executor job while
        # add_stop_route/remove_stop_route may still mutate the live set.
        route_stops = set(self._stop_routes)
        self.logger.debug("Updating data from API. Routes: %s", route_stops)

        def _update_data() -> dict[str, Any]:
            """Run the client request for the snapshotted route/stop pairs."""
            self.logger.debug("Updating data from API (executor)")
            try:
                response = self.client.get_predictions_for_multi_stops(route_stops)
            except (NextBusHTTPError, NextBusFormatError) as ex:
                raise UpdateFailed("Failed updating nextbus data", ex) from ex
            # Casting here because we expect dict and not a str due to the
            # input format selected being JSON
            return cast(dict[str, Any], response)

        data = await self.hass.async_add_executor_job(_update_data)
        # Rebuild the index after the executor job has returned, so
        # self._predictions is not rewritten from inside that job.
        self._calc_predictions(data)
        return data