server/internal/client/ollama: confirm all chunksums were received (#9893)

If the chunksums response is missing a chunk, the client should fail
the download. This changes the client to check that all bytes are
accounted for in the chunksums response.

It is possible for the chunksums response to contain overlaps or gaps,
so the size check alone is not conclusive, but it provides enough
coverage for now. We may want to check that chunks are contiguous
later.
Blake Mizerany 2025-03-19 14:59:57 -07:00 committed by GitHub
parent da0e345200
commit 2ddacd7516
2 changed files with 134 additions and 78 deletions
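At its core, the new validation keeps a running sum of chunk sizes and compares it against the layer size, both while scanning the chunksums response (to fail fast on overshoot) and once the scan ends (to catch missing bytes). A minimal standalone sketch of the idea; the chunk type and checkCoverage helper here are illustrative stand-ins, not the package's real API:

```go
package main

import "fmt"

// chunk is an illustrative stand-in for the client's chunk type:
// an inclusive byte range within a layer.
type chunk struct{ Start, End int64 }

func (c chunk) Size() int64 { return c.End - c.Start + 1 }

// checkCoverage fails unless the chunks sum to exactly layerSize.
// Overlaps and gaps that cancel out still pass; that is the hole the
// commit message defers to a later contiguity check.
func checkCoverage(chunks []chunk, layerSize int64) error {
	var size int64
	for _, c := range chunks {
		size += c.Size()
		if size > layerSize {
			return fmt.Errorf("sum of chunks %d exceeds layer size %d", size, layerSize)
		}
	}
	if size != layerSize {
		return fmt.Errorf("size mismatch: layer size %d != sum of chunks %d", layerSize, size)
	}
	return nil
}

func main() {
	fmt.Println(checkCoverage([]chunk{{0, 2}, {3, 4}}, 5)) // <nil>
	fmt.Println(checkCoverage([]chunk{{0, 2}}, 5))         // size mismatch: missing bytes
}
```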

server/internal/client/ollama/registry.go

@@ -37,7 +37,6 @@ import (
 	"golang.org/x/sync/errgroup"
 
 	"github.com/ollama/ollama/server/internal/cache/blob"
-	"github.com/ollama/ollama/server/internal/internal/backoff"
 	"github.com/ollama/ollama/server/internal/internal/names"
 
 	_ "embed"
@@ -213,12 +212,6 @@ type Registry struct {
 	// request. If zero, [DefaultChunkingThreshold] is used.
 	ChunkingThreshold int64
 
-	// MaxChunkSize is the maximum size of a chunk to download. If zero,
-	// the default is [DefaultMaxChunkSize].
-	//
-	// It is only used when a layer is larger than [MaxChunkingThreshold].
-	MaxChunkSize int64
-
 	// Mask, if set, is the name used to convert non-fully qualified names
 	// to fully qualified names. If empty, [DefaultMask] is used.
 	Mask string
@@ -447,6 +440,11 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 	if err != nil {
 		return err
 	}
+
+	// TODO(bmizerany): decide if this should be considered valid. Maybe
+	// server-side we special case '{}' to have some special meaning? Maybe
+	// "archiving" a tag (which is how we reason about it in the registry
+	// already, just with a different twist).
 	if len(m.Layers) == 0 {
 		return fmt.Errorf("%w: no layers", ErrManifestInvalid)
 	}
@@ -456,11 +454,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 		return err
 	}
 
-	exists := func(l *Layer) bool {
-		info, err := c.Get(l.Digest)
-		return err == nil && info.Size == l.Size
-	}
-
+	// TODO(bmizerany): work to remove the need to do this
 	layers := m.Layers
 	if m.Config != nil && m.Config.Digest.IsValid() {
 		layers = append(layers, m.Config)
@@ -469,19 +463,16 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 	// Send initial layer trace events to allow clients to have an
 	// understanding of work to be done before work starts.
 	t := traceFromContext(ctx)
-	skip := make([]bool, len(layers))
-	for i, l := range layers {
+	for _, l := range layers {
 		t.update(l, 0, nil)
-		if exists(l) {
-			skip[i] = true
-			t.update(l, l.Size, ErrCached)
-		}
 	}
 
-	g, ctx := errgroup.WithContext(ctx)
+	var g errgroup.Group
 	g.SetLimit(r.maxStreams())
-	for i, l := range layers {
-		if skip[i] {
+	for _, l := range layers {
+		info, err := c.Get(l.Digest)
+		if err == nil && info.Size == l.Size {
+			t.update(l, l.Size, ErrCached)
 			continue
 		}
@@ -490,63 +481,50 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 			t.update(l, 0, err)
 			continue
 		}
+		// TODO(bmizerany): fix this unbounded use of defer
 		defer chunked.Close()
 
 		var progress atomic.Int64
 		for cs, err := range r.chunksums(ctx, name, l) {
 			if err != nil {
+				// Bad chunksums response, update tracing
+				// clients and then bail.
 				t.update(l, progress.Load(), err)
-				break
+				return err
 			}
 
 			g.Go(func() (err error) {
-				defer func() { t.update(l, progress.Load(), err) }()
-				for _, err := range backoff.Loop(ctx, 3*time.Second) {
+				defer func() {
 					if err != nil {
-						return err
+						err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
 					}
-					err := func() error {
-						req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
-						if err != nil {
-							return err
-						}
-						req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
-						res, err := sendRequest(r.client(), req)
-						if err != nil {
-							return err
-						}
-						defer res.Body.Close()
-
-						// Count bytes towards
-						// progress, as they arrive, so
-						// that our bytes piggyback
-						// other chunk updates on
-						// completion.
-						//
-						// This tactic is enough to
-						// show "smooth" progress given
-						// the current CLI client. In
-						// the near future, the server
-						// should report download rate
-						// since it knows better than
-						// a client that is measuring
-						// rate based on wall-clock
-						// time-since-last-update.
-						body := &trackingReader{r: res.Body, n: &progress}
-						err = chunked.Put(cs.Chunk, cs.Digest, body)
-						if err != nil {
-							return err
-						}
-						return nil
-					}()
-					if !canRetry(err) {
-						return err
-					}
+					t.update(l, progress.Load(), err)
+				}()
+
+				req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
+				if err != nil {
+					return err
 				}
-				return nil
+				req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
+				res, err := sendRequest(r.client(), req)
+				if err != nil {
+					return err
+				}
+				defer res.Body.Close()
+
+				// Count bytes towards progress, as they
+				// arrive, so that our bytes piggyback other
+				// chunk updates on completion.
+				//
+				// This tactic is enough to show "smooth"
+				// progress given the current CLI client. In
+				// the near future, the server should report
+				// download rate since it knows better than a
+				// client that is measuring rate based on
+				// wall-clock time-since-last-update.
+				body := &trackingReader{r: res.Body, n: &progress}
+				return chunked.Put(cs.Chunk, cs.Digest, body)
 			})
 		}
 	}
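The trackingReader used in the new download path is not shown in this diff. Judging from the &trackingReader{r: res.Body, n: &progress} literal, it is presumably a small io.Reader wrapper that adds each read's byte count to the shared atomic counter; a runnable sketch under that assumption:

```go
package main

import (
	"fmt"
	"io"
	"strings"
	"sync/atomic"
)

// trackingReader, as sketched here, wraps an io.Reader and adds each
// read's byte count to a shared atomic counter. The real definition
// lives elsewhere in the package and may differ.
type trackingReader struct {
	r io.Reader
	n *atomic.Int64
}

func (t *trackingReader) Read(p []byte) (int, error) {
	n, err := t.r.Read(p)
	t.n.Add(int64(n)) // count bytes toward progress as they arrive
	return n, err
}

func main() {
	var progress atomic.Int64
	body := &trackingReader{r: strings.NewReader("hello"), n: &progress}
	io.ReadAll(body)
	fmt.Println(progress.Load()) // 5
}
```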
@@ -554,13 +532,10 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 		return err
 	}
 
-	// store the manifest blob
 	md := blob.DigestFromBytes(m.Data)
 	if err := blob.PutBytes(c, md, m.Data); err != nil {
 		return err
 	}
-
-	// commit the manifest with a link
 	return c.Link(m.Name, md)
 }
@@ -782,12 +757,15 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
 		}
 		blobURL := res.Header.Get("Content-Location")
 
+		var size int64
 		s := bufio.NewScanner(res.Body)
 		s.Split(bufio.ScanWords)
 		for {
 			if !s.Scan() {
 				if s.Err() != nil {
 					yield(chunksum{}, s.Err())
+				} else if size != l.Size {
+					yield(chunksum{}, fmt.Errorf("size mismatch: layer size %d != sum of chunks %d", l.Size, size))
 				}
 				return
 			}
@@ -811,6 +789,12 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
 				return
 			}
 
+			size += chunk.Size()
+			if size > l.Size {
+				yield(chunksum{}, fmt.Errorf("sum of chunks %d exceeds layer size %d", size, l.Size))
+				return
+			}
+
 			cs := chunksum{
 				URL:   blobURL,
 				Chunk: chunk,
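As the commit message notes, a size match does not prove the chunks are well-formed: overlaps and gaps that cancel out still pass. A hypothetical contiguity check, not part of this commit, would close that hole by requiring each chunk to start exactly where the previous one ended (chunk is the same illustrative type as in the sketch above):

```go
package main

import "fmt"

type chunk struct{ Start, End int64 } // inclusive byte range, as above

// contiguous reports whether chunks cover [0, layerSize) in order,
// with no gaps and no overlaps.
func contiguous(chunks []chunk, layerSize int64) bool {
	var next int64
	for _, c := range chunks {
		if c.Start != next {
			return false // gap (Start > next) or overlap (Start < next)
		}
		next = c.End + 1
	}
	return next == layerSize
}

func main() {
	fmt.Println(contiguous([]chunk{{0, 2}, {3, 4}}, 5)) // true
	// Sums to 5 but duplicates bytes 0-1 and never covers 2-3:
	// passes the size check, fails the contiguity check.
	fmt.Println(contiguous([]chunk{{0, 1}, {0, 1}, {4, 4}}, 5)) // false
}
```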

server/internal/client/ollama/registry_test.go

@@ -17,6 +17,7 @@ import (
 	"reflect"
 	"slices"
 	"strings"
+	"sync"
 	"testing"
 	"time"
@@ -56,21 +57,21 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
 // newClient constructs a cache with predefined manifests for testing. The manifests are:
 //
 //	empty: no data
 //	zero: no layers
 //	single: one layer with the contents "exists"
 //	multiple: two layers with the contents "exists" and "here"
 //	notfound: a layer that does not exist in the cache
 //	null: one null layer (e.g. [null])
 //	sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
 //	invalid: a layer with invalid JSON data
 //
 // Tests that want to ensure the client does not communicate with the upstream
 // registry should pass a nil handler, which will cause a panic if
 // communication is attempted.
 //
 // To simulate a network error, pass a handler that returns a 499 status code.
-func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
+func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
 	t.Helper()
 
 	c, err := blob.Open(t.TempDir())
@@ -88,7 +89,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
 	r := &Registry{
 		Cache: c,
 		HTTPClient: &http.Client{
-			Transport: recordRoundTripper(h),
+			Transport: recordRoundTripper(upstreamRegistry),
 		},
 	}
@ -767,3 +768,74 @@ func TestUnlink(t *testing.T) {
} }
}) })
} }
func TestPullChunksums(t *testing.T) {
check := testutil.Checker(t)
content := "hello"
var chunksums string
contentDigest := func() blob.Digest {
return blob.DigestFromBytes(content)
}
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/manifests/latest"):
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
w.Header().Set("Content-Location", loc)
io.WriteString(w, chunksums)
case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
default:
t.Errorf("unexpected request: %v", r)
http.NotFound(w, r)
}
})
rc.MaxStreams = 1 // prevent concurrent chunk downloads
rc.ChunkingThreshold = 1 // for all blobs to be chunked
var mu sync.Mutex
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Logf("Update: %v %d %v", l, n, err)
mu.Lock()
reads = append(reads, n)
mu.Unlock()
},
})
chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
blob.DigestFromBytes("hel"),
blob.DigestFromBytes("lo"),
)
err := rc.Pull(ctx, "test")
check(err)
if !slices.Equal(reads, []int64{0, 3, 5}) {
t.Errorf("reads = %v; want %v", reads, []int64{0, 3, 5})
}
mw, err := rc.Resolve(t.Context(), "test")
check(err)
mg, err := rc.ResolveLocal("test")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
}
for i := range mg.Layers {
_, err = c.Get(mg.Layers[i].Digest)
if err != nil {
t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
}
}
// missing chunks
content = "llama"
chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
err = rc.Pull(ctx, "missingchunks")
if err == nil {
t.Error("expected error because of missing chunks")
}
}
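For reference, the chunksums body this test serves is whitespace-separated digest and start-end pairs, which is the shape the bufio.ScanWords loop in chunksums consumes. A standalone parser sketch of that assumed wire format, with digest verification omitted (the chunk type is again the illustrative stand-in from earlier):

```go
package main

import (
	"bufio"
	"fmt"
	"io"
	"strings"
)

type chunk struct{ Start, End int64 }

// parseChunksums reads whitespace-separated "digest start-end" pairs,
// the shape of the response body the test above serves.
func parseChunksums(r io.Reader) ([]chunk, error) {
	s := bufio.NewScanner(r)
	s.Split(bufio.ScanWords)
	var chunks []chunk
	for s.Scan() {
		digest := s.Text() // e.g. "sha256:..."; not verified in this sketch
		if !s.Scan() {
			return nil, fmt.Errorf("missing range after digest %s", digest)
		}
		var c chunk
		if _, err := fmt.Sscanf(s.Text(), "%d-%d", &c.Start, &c.End); err != nil {
			return nil, fmt.Errorf("bad range %q: %w", s.Text(), err)
		}
		chunks = append(chunks, c)
	}
	return chunks, s.Err()
}

func main() {
	chunks, err := parseChunksums(strings.NewReader("sha256:aaa 0-2\nsha256:bbb 3-4\n"))
	fmt.Println(chunks, err) // [{0 2} {3 4}] <nil>
}
```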