mirror of
https://github.com/home-assistant/supervisor.git
synced 2025-12-04 15:08:12 +00:00
Add registry manifest fetcher for size-based pull progress
Fetch image manifests directly from container registries before pulling to get accurate layer sizes upfront. This enables size-weighted progress tracking where each layer contributes proportionally to its byte size, rather than equal weight per layer. Key changes: - Add RegistryManifestFetcher that handles auth discovery via WWW-Authenticate headers, token fetching with optional credentials, and multi-arch manifest list resolution - Update ImagePullProgress to accept manifest layer sizes via set_manifest() and calculate size-weighted progress - Fall back to count-based progress when manifest fetch fails - Pre-populate layer sizes from manifest when creating layer trackers The manifest fetcher supports ghcr.io, Docker Hub, and private registries by using credentials from Docker config when available. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -213,9 +213,26 @@ class DockerInterface(JobGroup, ABC):
|
||||
raise ValueError("Cannot pull without an image!")
|
||||
|
||||
image_arch = arch or self.sys_arch.supervisor
|
||||
platform = MAP_ARCH[image_arch]
|
||||
pull_progress = ImagePullProgress()
|
||||
current_job = self.sys_jobs.current
|
||||
|
||||
# Try to fetch manifest for accurate size-based progress
|
||||
# This is optional - if it fails, we fall back to count-based progress
|
||||
try:
|
||||
manifest = await self.sys_docker.manifest_fetcher.get_manifest(
|
||||
image, str(version), platform=platform
|
||||
)
|
||||
if manifest:
|
||||
pull_progress.set_manifest(manifest)
|
||||
_LOGGER.debug(
|
||||
"Using manifest for progress: %d layers, %d bytes",
|
||||
manifest.layer_count,
|
||||
manifest.total_size,
|
||||
)
|
||||
except Exception as err: # noqa: BLE001
|
||||
_LOGGER.debug("Could not fetch manifest for progress: %s", err)
|
||||
|
||||
async def process_pull_event(event: PullLogEntry) -> None:
|
||||
"""Process pull event and update job progress."""
|
||||
if event.job_id != current_job.uuid:
|
||||
@@ -244,7 +261,7 @@ class DockerInterface(JobGroup, ABC):
|
||||
current_job.uuid,
|
||||
image,
|
||||
str(version),
|
||||
platform=MAP_ARCH[image_arch],
|
||||
platform=platform,
|
||||
auth=credentials,
|
||||
)
|
||||
|
||||
|
||||
@@ -50,6 +50,7 @@ from ..exceptions import (
|
||||
from ..utils.common import FileConfiguration
|
||||
from ..validate import SCHEMA_DOCKER_CONFIG
|
||||
from .const import DOCKER_HUB, DOCKER_HUB_LEGACY, LABEL_MANAGED
|
||||
from .manifest import RegistryManifestFetcher
|
||||
from .monitor import DockerMonitor
|
||||
from .network import DockerNetwork
|
||||
from .utils import get_registry_from_image
|
||||
@@ -258,6 +259,9 @@ class DockerAPI(CoreSysAttributes):
|
||||
self._info: DockerInfo | None = None
|
||||
self.config: DockerConfig = DockerConfig()
|
||||
self._monitor: DockerMonitor = DockerMonitor(coresys)
|
||||
self._manifest_fetcher: RegistryManifestFetcher = RegistryManifestFetcher(
|
||||
coresys
|
||||
)
|
||||
|
||||
async def post_init(self) -> Self:
|
||||
"""Post init actions that must be done in event loop."""
|
||||
@@ -323,6 +327,11 @@ class DockerAPI(CoreSysAttributes):
|
||||
"""Return docker events monitor."""
|
||||
return self._monitor
|
||||
|
||||
@property
|
||||
def manifest_fetcher(self) -> RegistryManifestFetcher:
|
||||
"""Return manifest fetcher for registry access."""
|
||||
return self._manifest_fetcher
|
||||
|
||||
async def load(self) -> None:
|
||||
"""Start docker events monitor."""
|
||||
await self.monitor.load()
|
||||
|
||||
354
supervisor/docker/manifest.py
Normal file
354
supervisor/docker/manifest.py
Normal file
@@ -0,0 +1,354 @@
|
||||
"""Docker registry manifest fetcher.
|
||||
|
||||
Fetches image manifests directly from container registries to get layer sizes
|
||||
before pulling an image. This enables accurate size-based progress tracking.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .const import DOCKER_HUB, IMAGE_WITH_HOST
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..coresys import CoreSys
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# Default registry for images without explicit host
|
||||
DEFAULT_REGISTRY = "registry-1.docker.io"
|
||||
|
||||
# Media types for manifest requests
|
||||
MANIFEST_MEDIA_TYPES = (
|
||||
"application/vnd.docker.distribution.manifest.v2+json",
|
||||
"application/vnd.oci.image.manifest.v1+json",
|
||||
"application/vnd.docker.distribution.manifest.list.v2+json",
|
||||
"application/vnd.oci.image.index.v1+json",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class ImageManifest:
    """Container image manifest with layer information."""

    # Digest of the image config object from the manifest ("" when absent).
    digest: str
    # Sum of all layer sizes in bytes, as reported by the registry manifest.
    total_size: int
    # NOTE: keys are short digests (first 12 hex chars after "sha256:"),
    # matching the layer IDs Docker emits in pull progress events.
    layers: dict[str, int]  # digest -> size in bytes

    @property
    def layer_count(self) -> int:
        """Return number of layers."""
        return len(self.layers)
|
||||
|
||||
|
||||
def parse_image_reference(image: str, tag: str) -> tuple[str, str, str]:
    """Parse image reference into (registry, repository, tag).

    Examples:
        ghcr.io/home-assistant/home-assistant:2025.1.0
            -> (ghcr.io, home-assistant/home-assistant, 2025.1.0)
        homeassistant/home-assistant:latest
            -> (registry-1.docker.io, homeassistant/home-assistant, latest)
        alpine:3.18
            -> (registry-1.docker.io, library/alpine, 3.18)

    """
    host_match = IMAGE_WITH_HOST.match(image)

    if host_match is None:
        # No explicit registry host: the image lives on Docker Hub.
        # Docker Hub requires "library/" prefix for official images.
        repository = image if "/" in image else f"library/{image}"
        return DEFAULT_REGISTRY, repository, tag

    host = host_match.group(1)
    # Strip the "<host>/" prefix to obtain the repository path.
    return host, image[len(host) + 1 :], tag
|
||||
|
||||
|
||||
class RegistryManifestFetcher:
    """Fetches manifests from container registries.

    Implements the registry v2 flow: probe the registry to discover the
    token endpoint via the WWW-Authenticate header, obtain a bearer token
    (using stored credentials when available), then fetch the manifest and
    resolve multi-arch manifest lists to a platform-specific manifest.
    """

    def __init__(self, coresys: CoreSys) -> None:
        """Initialize the fetcher."""
        self.coresys = coresys
        # HTTP session is created lazily on first use (see _get_session).
        self._session: aiohttp.ClientSession | None = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Get or create aiohttp session."""
        # Recreate when never created or externally closed, so the fetcher
        # stays usable after close().
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    async def close(self) -> None:
        """Close the session."""
        if self._session and not self._session.closed:
            await self._session.close()
        self._session = None

    def _get_credentials(self, registry: str) -> tuple[str, str] | None:
        """Get credentials for registry from Docker config.

        Returns (username, password) tuple or None if no credentials.
        """
        registries = self.coresys.docker.config.registries

        # Map registry hostname to config key
        # Docker Hub can be stored as "hub.docker.com" in config
        if registry == DEFAULT_REGISTRY:
            if DOCKER_HUB in registries:
                creds = registries[DOCKER_HUB]
                return creds.get("username"), creds.get("password")
        elif registry in registries:
            creds = registries[registry]
            return creds.get("username"), creds.get("password")

        return None

    async def _get_auth_token(
        self,
        session: aiohttp.ClientSession,
        registry: str,
        repository: str,
    ) -> str | None:
        """Get authentication token for registry.

        Uses the WWW-Authenticate header from a 401 response to discover
        the token endpoint, then requests a token with appropriate scope.

        Returns the bearer token string, or None when no auth is required
        or when token discovery fails (callers proceed without a token).
        """
        # First, make an unauthenticated request to get WWW-Authenticate header.
        # The "latest" tag here only triggers the auth challenge; the response
        # body is discarded and the real manifest is fetched separately.
        manifest_url = f"https://{registry}/v2/{repository}/manifests/latest"

        try:
            async with session.get(manifest_url) as resp:
                if resp.status == 200:
                    # No auth required
                    return None

                if resp.status != 401:
                    _LOGGER.warning(
                        "Unexpected status %d from registry %s", resp.status, registry
                    )
                    return None

                www_auth = resp.headers.get("WWW-Authenticate", "")
        except aiohttp.ClientError as err:
            _LOGGER.warning("Failed to connect to registry %s: %s", registry, err)
            return None

        # Parse WWW-Authenticate: Bearer realm="...",service="...",scope="..."
        if not www_auth.startswith("Bearer "):
            _LOGGER.warning("Unsupported auth type from %s: %s", registry, www_auth)
            return None

        # NOTE(review): only double-quoted key="value" pairs are matched;
        # unquoted parameter values would be silently ignored.
        params = {}
        for match in re.finditer(r'(\w+)="([^"]*)"', www_auth):
            params[match.group(1)] = match.group(2)

        realm = params.get("realm")
        service = params.get("service")

        if not realm:
            _LOGGER.warning("No realm in WWW-Authenticate from %s", registry)
            return None

        # Build token request URL
        token_url = f"{realm}?scope=repository:{repository}:pull"
        if service:
            token_url += f"&service={service}"

        # Check for credentials; basic auth on the token endpoint upgrades
        # the returned token to an authenticated scope when configured.
        auth = None
        credentials = self._get_credentials(registry)
        if credentials:
            username, password = credentials
            if username and password:
                auth = aiohttp.BasicAuth(username, password)
                _LOGGER.debug("Using credentials for %s", registry)

        try:
            async with session.get(token_url, auth=auth) as resp:
                if resp.status != 200:
                    _LOGGER.warning(
                        "Failed to get token from %s: %d", realm, resp.status
                    )
                    return None

                data = await resp.json()
                # Registries return the token under either key.
                return data.get("token") or data.get("access_token")
        except aiohttp.ClientError as err:
            _LOGGER.warning("Failed to get auth token: %s", err)
            return None

    async def _fetch_manifest(
        self,
        session: aiohttp.ClientSession,
        registry: str,
        repository: str,
        reference: str,
        token: str | None,
        platform: str | None = None,
    ) -> dict | None:
        """Fetch manifest from registry.

        If the manifest is a manifest list (multi-arch), fetches the
        platform-specific manifest.

        `reference` may be a tag or a digest. Returns the raw manifest
        JSON dict, or None on any failure.
        """
        manifest_url = f"https://{registry}/v2/{repository}/manifests/{reference}"

        # Advertise all supported manifest media types so the registry can
        # return either a single manifest or a multi-arch index.
        headers = {"Accept": ", ".join(MANIFEST_MEDIA_TYPES)}
        if token:
            headers["Authorization"] = f"Bearer {token}"

        try:
            async with session.get(manifest_url, headers=headers) as resp:
                if resp.status != 200:
                    _LOGGER.warning(
                        "Failed to fetch manifest for %s/%s:%s - %d",
                        registry,
                        repository,
                        reference,
                        resp.status,
                    )
                    return None

                manifest = await resp.json()
        except aiohttp.ClientError as err:
            _LOGGER.warning("Failed to fetch manifest: %s", err)
            return None

        media_type = manifest.get("mediaType", "")

        # Check if this is a manifest list (multi-arch image)
        if "list" in media_type or "index" in media_type:
            manifests = manifest.get("manifests", [])
            if not manifests:
                _LOGGER.warning("Empty manifest list for %s/%s", registry, repository)
                return None

            # Find matching platform
            target_os = "linux"
            target_arch = "amd64"  # Default

            if platform:
                # Platform format is "linux/amd64", "linux/arm64", etc.
                parts = platform.split("/")
                if len(parts) >= 2:
                    target_os, target_arch = parts[0], parts[1]

            platform_manifest = None
            for m in manifests:
                plat = m.get("platform", {})
                if (
                    plat.get("os") == target_os
                    and plat.get("architecture") == target_arch
                ):
                    platform_manifest = m
                    break

            if not platform_manifest:
                # Fall back to first manifest
                _LOGGER.debug(
                    "Platform %s/%s not found, using first manifest",
                    target_os,
                    target_arch,
                )
                platform_manifest = manifests[0]

            # Fetch the platform-specific manifest (recursive call; the
            # child is addressed by digest, which is never a list again
            # for well-formed registries).
            return await self._fetch_manifest(
                session,
                registry,
                repository,
                platform_manifest["digest"],
                token,
                platform,
            )

        return manifest

    async def get_manifest(
        self,
        image: str,
        tag: str,
        platform: str | None = None,
    ) -> ImageManifest | None:
        """Fetch manifest and extract layer sizes.

        Args:
            image: Image name (e.g., "ghcr.io/home-assistant/home-assistant")
            tag: Image tag (e.g., "2025.1.0")
            platform: Target platform (e.g., "linux/amd64")

        Returns:
            ImageManifest with layer sizes, or None if fetch failed.

        """
        registry, repository, tag = parse_image_reference(image, tag)

        _LOGGER.debug(
            "Fetching manifest for %s/%s:%s (platform=%s)",
            registry,
            repository,
            tag,
            platform,
        )

        session = await self._get_session()

        # Get auth token
        token = await self._get_auth_token(session, registry, repository)

        # Fetch manifest
        manifest = await self._fetch_manifest(
            session, registry, repository, tag, token, platform
        )

        if not manifest:
            return None

        # Extract layer information
        layers = manifest.get("layers", [])
        if not layers:
            _LOGGER.warning(
                "No layers in manifest for %s/%s:%s", registry, repository, tag
            )
            return None

        layer_sizes: dict[str, int] = {}
        total_size = 0

        # Layers missing a digest or reporting size 0 are skipped entirely.
        for layer in layers:
            digest = layer.get("digest", "")
            size = layer.get("size", 0)
            if digest and size:
                # Store by short digest (first 12 chars after sha256:)
                short_digest = (
                    digest.split(":")[1][:12] if ":" in digest else digest[:12]
                )
                layer_sizes[short_digest] = size
                total_size += size

        # Reuse of `digest` here is intentional: it now holds the config
        # object digest used as the manifest's identifying digest.
        digest = manifest.get("config", {}).get("digest", "")

        _LOGGER.debug(
            "Manifest for %s/%s:%s - %d layers, %d bytes total",
            registry,
            repository,
            tag,
            len(layer_sizes),
            total_size,
        )

        return ImageManifest(
            digest=digest,
            total_size=total_size,
            layers=layer_sizes,
        )
|
||||
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .manager import PullLogEntry
|
||||
from .manifest import ImageManifest
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -109,23 +110,43 @@ class LayerProgress:
|
||||
class ImagePullProgress:
|
||||
"""Track overall progress of pulling an image.
|
||||
|
||||
Uses count-based progress where each layer contributes equally regardless of size.
|
||||
This avoids progress regression when large layers are discovered late due to
|
||||
Docker's rate-limiting of concurrent downloads.
|
||||
When manifest layer sizes are provided, uses size-weighted progress where
|
||||
each layer contributes proportionally to its size. This gives accurate
|
||||
progress based on actual bytes to download.
|
||||
|
||||
Progress is only reported after the first "Downloading" event, since Docker
|
||||
sends "Already exists" and "Pulling fs layer" events before we know the full
|
||||
layer count.
|
||||
When manifest is not available, falls back to count-based progress where
|
||||
each layer contributes equally.
|
||||
|
||||
Layers that already exist locally are excluded from the progress calculation.
|
||||
"""
|
||||
|
||||
layers: dict[str, LayerProgress] = field(default_factory=dict)
|
||||
_last_reported_progress: float = field(default=0.0, repr=False)
|
||||
_seen_downloading: bool = field(default=False, repr=False)
|
||||
_manifest_layer_sizes: dict[str, int] = field(default_factory=dict, repr=False)
|
||||
_total_manifest_size: int = field(default=0, repr=False)
|
||||
|
||||
def set_manifest(self, manifest: ImageManifest) -> None:
    """Set manifest layer sizes for accurate size-based progress.

    Should be called before processing pull events.
    """
    # Copy the mapping so later mutations of the manifest don't leak in.
    layer_sizes = dict(manifest.layers)
    self._manifest_layer_sizes = layer_sizes
    self._total_manifest_size = manifest.total_size
    _LOGGER.debug(
        "Manifest set: %d layers, %d bytes total",
        len(layer_sizes),
        self._total_manifest_size,
    )
|
||||
|
||||
def get_or_create_layer(self, layer_id: str) -> LayerProgress:
    """Get existing layer or create new one.

    New layers are pre-populated with the total_size from the registry
    manifest when available (0 otherwise), enabling size-weighted progress.
    """
    if layer_id not in self.layers:
        # Pre-populate total_size from manifest sizes when available.
        # Defaults to 0 so count-based progress still works without a
        # manifest. (Fix: drop the redundant LayerProgress(layer_id=...)
        # instantiation that was immediately overwritten by this one.)
        manifest_size = self._manifest_layer_sizes.get(layer_id, 0)
        self.layers[layer_id] = LayerProgress(
            layer_id=layer_id, total_size=manifest_size
        )
    return self.layers[layer_id]
|
||||
|
||||
def process_event(self, entry: PullLogEntry) -> None:
|
||||
@@ -237,8 +258,13 @@ class ImagePullProgress:
|
||||
def calculate_progress(self) -> float:
|
||||
"""Calculate overall progress 0-100.
|
||||
|
||||
Uses count-based progress where each layer that needs pulling contributes
|
||||
equally. Layers that already exist locally are excluded from the calculation.
|
||||
When manifest layer sizes are available, uses size-weighted progress
|
||||
where each layer contributes proportionally to its size.
|
||||
|
||||
When manifest is not available, falls back to count-based progress
|
||||
where each layer contributes equally.
|
||||
|
||||
Layers that already exist locally are excluded from the calculation.
|
||||
|
||||
Returns 0 until we've seen the first "Downloading" event, since Docker
|
||||
reports "Already exists" and "Pulling fs layer" events before we know
|
||||
@@ -258,10 +284,39 @@ class ImagePullProgress:
|
||||
# All layers already exist, nothing to download
|
||||
return 100.0
|
||||
|
||||
# Each layer contributes equally: sum of layer progresses / total layers
|
||||
# Use size-weighted progress if manifest sizes are available
|
||||
if self._manifest_layer_sizes:
|
||||
return self._calculate_size_weighted_progress(layers_to_pull)
|
||||
|
||||
# Fall back to count-based progress
|
||||
total_progress = sum(layer.calculate_progress() for layer in layers_to_pull)
|
||||
return total_progress / len(layers_to_pull)
|
||||
|
||||
def _calculate_size_weighted_progress(
    self, layers_to_pull: list[LayerProgress]
) -> float:
    """Calculate size-weighted progress.

    Each layer contributes to progress proportionally to its size.
    Progress = sum(layer_progress * layer_size) / total_size
    """
    # Total bytes across the layers that still need pulling.
    total_size = sum(layer.total_size for layer in layers_to_pull)

    if total_size == 0:
        # No size info available, fall back to count-based progress
        # where every layer contributes equally.
        return sum(
            layer.calculate_progress() for layer in layers_to_pull
        ) / len(layers_to_pull)

    # Weight each layer's progress by its share of the total byte count;
    # zero-sized layers contribute nothing.
    return sum(
        layer.calculate_progress() * (layer.total_size / total_size)
        for layer in layers_to_pull
        if layer.total_size > 0
    )
|
||||
|
||||
def get_stage(self) -> str | None:
|
||||
"""Get current stage based on layer states."""
|
||||
if not self.layers:
|
||||
|
||||
@@ -154,6 +154,9 @@ async def docker() -> DockerAPI:
|
||||
docker_obj.info.storage = "overlay2"
|
||||
docker_obj.info.version = AwesomeVersion("1.0.0")
|
||||
|
||||
# Mock manifest fetcher to return None (falls back to count-based progress)
|
||||
docker_obj._manifest_fetcher.get_manifest = AsyncMock(return_value=None)
|
||||
|
||||
yield docker_obj
|
||||
|
||||
|
||||
|
||||
164
tests/docker/test_manifest.py
Normal file
164
tests/docker/test_manifest.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Tests for registry manifest fetcher."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from supervisor.docker.manifest import (
|
||||
DEFAULT_REGISTRY,
|
||||
ImageManifest,
|
||||
RegistryManifestFetcher,
|
||||
parse_image_reference,
|
||||
)
|
||||
|
||||
|
||||
class TestParseImageReference:
    """Tests for parse_image_reference function."""

    def test_ghcr_io_image(self):
        """Test parsing ghcr.io image."""
        parsed = parse_image_reference(
            "ghcr.io/home-assistant/home-assistant", "2025.1.0"
        )
        assert parsed == ("ghcr.io", "home-assistant/home-assistant", "2025.1.0")

    def test_docker_hub_with_org(self):
        """Test parsing Docker Hub image with organization."""
        parsed = parse_image_reference("homeassistant/home-assistant", "latest")
        assert parsed == (DEFAULT_REGISTRY, "homeassistant/home-assistant", "latest")

    def test_docker_hub_official_image(self):
        """Test parsing Docker Hub official image (no org)."""
        parsed = parse_image_reference("alpine", "3.18")
        assert parsed == (DEFAULT_REGISTRY, "library/alpine", "3.18")

    def test_gcr_io_image(self):
        """Test parsing gcr.io image."""
        parsed = parse_image_reference("gcr.io/project/image", "v1")
        assert parsed == ("gcr.io", "project/image", "v1")
|
||||
|
||||
|
||||
class TestImageManifest:
    """Tests for ImageManifest dataclass."""

    def test_layer_count(self):
        """Test layer_count property."""
        layers = {"layer1": 500, "layer2": 500}
        manifest = ImageManifest(
            digest="sha256:abc",
            total_size=1000,
            layers=layers,
        )
        # layer_count mirrors the number of entries in the layers mapping.
        assert manifest.layer_count == 2
|
||||
|
||||
|
||||
class TestRegistryManifestFetcher:
    """Tests for RegistryManifestFetcher class."""

    @pytest.fixture
    def mock_coresys(self):
        """Create mock coresys."""
        coresys = MagicMock()
        # Only the registries mapping is consulted by _get_credentials.
        coresys.docker.config.registries = {}
        return coresys

    @pytest.fixture
    def fetcher(self, mock_coresys):
        """Create fetcher instance."""
        return RegistryManifestFetcher(mock_coresys)

    async def test_get_manifest_success(self, fetcher):
        """Test successful manifest fetch by mocking internal methods."""
        manifest_data = {
            "mediaType": "application/vnd.docker.distribution.manifest.v2+json",
            "config": {"digest": "sha256:abc123"},
            "layers": [
                {"digest": "sha256:layer1abc123def456789012", "size": 1000},
                {"digest": "sha256:layer2def456abc789012345", "size": 2000},
            ],
        }

        # Mock the internal methods instead of the session
        with (
            patch.object(
                fetcher, "_get_auth_token", new=AsyncMock(return_value="test-token")
            ),
            patch.object(
                fetcher, "_fetch_manifest", new=AsyncMock(return_value=manifest_data)
            ),
            patch.object(fetcher, "_get_session", new=AsyncMock()),
        ):
            result = await fetcher.get_manifest(
                "test.io/org/image", "v1.0", platform="linux/amd64"
            )

        assert result is not None
        assert result.total_size == 3000
        assert result.layer_count == 2
        # First 12 chars after sha256:
        assert "layer1abc123" in result.layers
        assert result.layers["layer1abc123"] == 1000

    async def test_get_manifest_returns_none_on_failure(self, fetcher):
        """Test that get_manifest returns None on failure."""
        # _fetch_manifest returning None models any registry-side failure.
        with (
            patch.object(
                fetcher, "_get_auth_token", new=AsyncMock(return_value="test-token")
            ),
            patch.object(fetcher, "_fetch_manifest", new=AsyncMock(return_value=None)),
            patch.object(fetcher, "_get_session", new=AsyncMock()),
        ):
            result = await fetcher.get_manifest(
                "test.io/org/image", "v1.0", platform="linux/amd64"
            )

        assert result is None

    async def test_close_session(self, fetcher):
        """Test session cleanup."""
        mock_session = AsyncMock()
        mock_session.closed = False
        mock_session.close = AsyncMock()
        fetcher._session = mock_session

        await fetcher.close()

        # close() must await session.close() and drop the reference.
        mock_session.close.assert_called_once()
        assert fetcher._session is None

    def test_get_credentials_docker_hub(self, mock_coresys, fetcher):
        """Test getting Docker Hub credentials."""
        # Docker Hub credentials are keyed by "hub.docker.com" in config
        # but looked up via the DEFAULT_REGISTRY hostname.
        mock_coresys.docker.config.registries = {
            "hub.docker.com": {"username": "user", "password": "pass"}
        }

        creds = fetcher._get_credentials(DEFAULT_REGISTRY)

        assert creds == ("user", "pass")

    def test_get_credentials_custom_registry(self, mock_coresys, fetcher):
        """Test getting credentials for custom registry."""
        mock_coresys.docker.config.registries = {
            "ghcr.io": {"username": "user", "password": "token"}
        }

        creds = fetcher._get_credentials("ghcr.io")

        assert creds == ("user", "token")

    def test_get_credentials_not_found(self, mock_coresys, fetcher):
        """Test no credentials found."""
        mock_coresys.docker.config.registries = {}

        creds = fetcher._get_credentials("unknown.io")

        assert creds is None
|
||||
@@ -3,6 +3,7 @@
|
||||
import pytest
|
||||
|
||||
from supervisor.docker.manager import PullLogEntry, PullProgressDetail
|
||||
from supervisor.docker.manifest import ImageManifest
|
||||
from supervisor.docker.pull_progress import (
|
||||
DOWNLOAD_WEIGHT,
|
||||
EXTRACT_WEIGHT,
|
||||
@@ -784,3 +785,218 @@ class TestImagePullProgress:
|
||||
)
|
||||
|
||||
assert progress.calculate_progress() == 100.0
|
||||
|
||||
def test_size_weighted_progress_with_manifest(self):
    """Test size-weighted progress when manifest layer sizes are known."""

    def make_event(layer_id, status, detail=None):
        # Build a pull event for a layer; detail defaults to empty.
        return PullLogEntry(
            job_id="test",
            id=layer_id,
            status=status,
            progress_detail=detail if detail is not None else PullProgressDetail(),
        )

    # Create manifest with known layer sizes
    # Small layer: 1KB, Large layer: 100KB
    manifest = ImageManifest(
        digest="sha256:test",
        total_size=101000,
        layers={
            "small123456": 1000,  # 1KB - ~1% of total
            "large123456": 100000,  # 100KB - ~99% of total
        },
    )

    progress = ImagePullProgress()
    progress.set_manifest(manifest)

    # Layer events - small layer first
    progress.process_event(make_event("small123456", "Pulling fs layer"))
    progress.process_event(make_event("large123456", "Pulling fs layer"))

    # Small layer downloads completely
    progress.process_event(
        make_event(
            "small123456",
            "Downloading",
            PullProgressDetail(current=1000, total=1000),
        )
    )

    # Size-weighted: small layer is ~1% of total size
    # Small layer at 70% (download done) = contributes ~0.7% to overall
    assert progress.calculate_progress() == pytest.approx(0.69, rel=0.1)

    # Large layer starts downloading (1% of its size)
    progress.process_event(
        make_event(
            "large123456",
            "Downloading",
            PullProgressDetail(current=1000, total=100000),
        )
    )

    # Large layer at 1% download = contributes ~0.7% (1% * 70% * 99% weight)
    # Total: ~0.7% + ~0.7% = ~1.4%
    current = progress.calculate_progress()
    assert current > 0.7  # More than just small layer
    assert current < 5.0  # But not much more

    # Complete both layers
    progress.process_event(make_event("small123456", "Pull complete"))
    progress.process_event(make_event("large123456", "Pull complete"))

    assert progress.calculate_progress() == 100.0
|
||||
|
||||
def test_size_weighted_excludes_already_exists(self):
    """Test that already existing layers are excluded from size-weighted progress."""

    def make_event(layer_id, status, detail=None):
        # Build a pull event for a layer; detail defaults to empty.
        return PullLogEntry(
            job_id="test",
            id=layer_id,
            status=status,
            progress_detail=detail if detail is not None else PullProgressDetail(),
        )

    # Manifest has 3 layers, but one will already exist locally
    manifest = ImageManifest(
        digest="sha256:test",
        total_size=200000,
        layers={
            "cached12345": 100000,  # Will be cached - shouldn't count
            "layer1_1234": 50000,  # Needs pulling
            "layer2_1234": 50000,  # Needs pulling
        },
    )

    progress = ImagePullProgress()
    progress.set_manifest(manifest)

    # Cached layer already exists
    progress.process_event(make_event("cached12345", "Already exists"))

    # Other layers need pulling
    progress.process_event(make_event("layer1_1234", "Pulling fs layer"))
    progress.process_event(make_event("layer2_1234", "Pulling fs layer"))

    # Start downloading layer1 (50% of its size)
    progress.process_event(
        make_event(
            "layer1_1234",
            "Downloading",
            PullProgressDetail(current=25000, total=50000),
        )
    )

    # layer1 is 50% of total that needs pulling (50KB out of 100KB)
    # At 50% download = 35% layer progress (70% * 50%)
    # Size-weighted: 50% * 35% = 17.5%
    assert progress.calculate_progress() == pytest.approx(17.5)

    # Complete layer1
    progress.process_event(make_event("layer1_1234", "Pull complete"))

    # layer1 at 100%, layer2 at 0%
    # Size-weighted: 50% * 100% + 50% * 0% = 50%
    assert progress.calculate_progress() == pytest.approx(50.0)
|
||||
|
||||
def test_fallback_to_count_based_without_manifest(self):
    """Test that without manifest, count-based progress is used."""

    def make_event(layer_id, status, detail=None):
        # Build a pull event for a layer; detail defaults to empty.
        return PullLogEntry(
            job_id="test",
            id=layer_id,
            status=status,
            progress_detail=detail if detail is not None else PullProgressDetail(),
        )

    progress = ImagePullProgress()

    # No manifest set - should use count-based progress

    # Two layers of different sizes
    progress.process_event(make_event("small", "Pulling fs layer"))
    progress.process_event(make_event("large", "Pulling fs layer"))

    # Small layer (1KB) completes
    progress.process_event(
        make_event(
            "small",
            "Downloading",
            PullProgressDetail(current=1000, total=1000),
        )
    )
    progress.process_event(make_event("small", "Pull complete"))

    # Large layer (100MB) at 1%
    progress.process_event(
        make_event(
            "large",
            "Downloading",
            PullProgressDetail(current=1000000, total=100000000),
        )
    )

    # Count-based: each layer is 50% weight
    # small: 100% * 50% = 50%
    # large: 0.7% (1% * 70%) * 50% = 0.35%
    # Total: ~50.35%
    assert progress.calculate_progress() == pytest.approx(50.35, rel=0.01)
|
||||
|
||||
Reference in New Issue
Block a user