server/internal/client/ollama: handle some network errors gracefully (#10317)

This commit is contained in:
Blake Mizerany 2025-04-17 12:43:09 -07:00 committed by GitHub
parent 09bb2e30f6
commit 1d99451ad7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 115 additions and 16 deletions

View File

@ -223,8 +223,21 @@ type Registry struct {
ChunkingThreshold int64
// Mask, if set, is the name used to convert non-fully qualified names
// to fully qualified names. If empty, [DefaultMask] is used.
// to fully qualified names.
// If empty, [DefaultMask] is used.
Mask string
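// ReadTimeout is the maximum duration to wait for a chunk download to
// make progress. The timer starts before each chunk request is sent and
// is reset each time the download reports progress.
// A zero or negative value means there will be no timeout.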
ReadTimeout time.Duration
}
// readTimeout returns the effective read timeout for r: the configured
// ReadTimeout when it is positive, and otherwise the largest representable
// duration, which stands in for "no timeout".
func (r *Registry) readTimeout() time.Duration {
	// noTimeout is the maximum int64 value expressed as a Duration.
	const noTimeout time.Duration = 1<<63 - 1

	if r.ReadTimeout <= 0 {
		return noTimeout
	}
	return r.ReadTimeout
}
func (r *Registry) cache() (*blob.DiskCache, error) {
@ -248,8 +261,7 @@ func (r *Registry) parseName(name string) (names.Name, error) {
// DefaultRegistry returns a new Registry configured from the environment. The
// key is read from $HOME/.ollama/id_ed25519, MaxStreams is set to the
// value of OLLAMA_REGISTRY_MAXSTREAMS, and ChunkingDirectory is set to the
// system's temporary directory.
// value of OLLAMA_REGISTRY_MAXSTREAMS, and ReadTimeout is set to 30 seconds.
//
// It returns an error if any configuration in the environment is invalid.
func DefaultRegistry() (*Registry, error) {
@ -263,6 +275,7 @@ func DefaultRegistry() (*Registry, error) {
}
var rc Registry
rc.ReadTimeout = 30 * time.Second
rc.UserAgent = UserAgent()
rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
if err != nil {
@ -489,6 +502,12 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
for _, l := range layers {
var received atomic.Int64
update := func(n int64, err error) {
if n == 0 && err == nil {
// Clients expect an update with no progress and no error to mean "starting download".
// This is not the case here,
// so we don't want to send an update in this case.
return
}
completed.Add(n)
t.update(l, received.Add(n), err)
}
@ -562,6 +581,20 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
}
}()
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
timer := time.AfterFunc(r.readTimeout(), func() {
cancel(fmt.Errorf("%w: downloading %s %d-%d/%d",
context.DeadlineExceeded,
cs.Digest.Short(),
cs.Chunk.Start,
cs.Chunk.End,
l.Size,
))
})
defer timer.Stop()
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
if err != nil {
return err
@ -574,8 +607,11 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
defer res.Body.Close()
tr := &trackingReader{
r: res.Body,
update: update,
r: res.Body,
update: func(n int64, err error) {
timer.Reset(r.readTimeout())
update(n, err)
},
}
if err := chunked.Put(cs.Chunk, cs.Digest, tr); err != nil {
return err
@ -930,12 +966,6 @@ func (r *Registry) newRequest(ctx context.Context, method, url string, body io.R
// is parsed from the response body and returned. If any other error occurs, it
// is returned.
func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) {
defer func() {
if err != nil {
err = fmt.Errorf("request error %s: %w", r.URL, err)
}
}()
if r.URL.Scheme == "https+insecure" {
// TODO(bmizerany): clone client.Transport, set
// InsecureSkipVerify, etc.

View File

@ -0,0 +1,51 @@
// TODO: go:build goexperiment.synctest
package ollama
import (
"context"
"errors"
"io"
"net/http"
"strings"
"testing"
"time"
)
// TestPullDownloadTimeout checks that Pull returns an error wrapping
// context.DeadlineExceeded when the upstream sends the response headers for
// a blob and then stalls without sending any body.
func TestPullDownloadTimeout(t *testing.T) {
	rc, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
		defer t.Log("upstream", r.Method, r.URL.Path)
		switch {
		case strings.HasPrefix(r.URL.Path, "/v2/library/smol/manifests/"):
			io.WriteString(w, `{
"layers": [{"digest": "sha256:1111111111111111111111111111111111111111111111111111111111111111", "size": 3}]
}`)
		case strings.HasPrefix(r.URL.Path, "/v2/library/smol/blobs/sha256:1111111111111111111111111111111111111111111111111111111111111111"):
			// Get headers out to client and then hang on the response
			w.WriteHeader(200)
			w.(http.Flusher).Flush()

			// Hang on the response and unblock when the client
			// gives up
			<-r.Context().Done()
		default:
			// This handler runs on a server goroutine, not the test
			// goroutine, so FailNow-style calls (t.Fatalf) are not
			// allowed here. Record the failure and answer the request
			// so the client is not left with an ambiguous response.
			t.Errorf("unexpected request: %s", r.URL.Path)
			http.NotFound(w, r)
		}
	})

	// Keep the timeout short so the stalled blob download is abandoned
	// well before the 3 second guard below fires.
	rc.ReadTimeout = 100 * time.Millisecond

	// Buffered so the goroutine can finish even if the select below has
	// already taken the timeout branch.
	done := make(chan error, 1)
	go func() {
		done <- rc.Pull(ctx, "http://example.com/library/smol")
	}()

	select {
	case err := <-done:
		want := context.DeadlineExceeded
		if !errors.Is(err, want) {
			t.Errorf("err = %v, want %v", err, want)
		}
	case <-time.After(3 * time.Second):
		t.Error("timeout waiting for Pull to finish")
	}
}

View File

@ -605,8 +605,8 @@ func checkRequest(t *testing.T, req *http.Request, method, path string) {
}
}
func newRegistryClient(t *testing.T, h http.HandlerFunc) (*Registry, context.Context) {
s := httptest.NewServer(h)
func newRegistryClient(t *testing.T, upstream http.HandlerFunc) (*Registry, context.Context) {
s := httptest.NewServer(upstream)
t.Cleanup(s.Close)
cache, err := blob.Open(t.TempDir())
if err != nil {

View File

@ -4,6 +4,7 @@ package registry
import (
"cmp"
"context"
"encoding/json"
"errors"
"fmt"
@ -11,6 +12,7 @@ import (
"log/slog"
"net/http"
"slices"
"strings"
"sync"
"time"
@ -330,9 +332,8 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
return err
}
err := s.Client.Pull(ctx, p.model())
var oe *ollama.Error
if errors.As(err, &oe) && oe.Temporary() {
continue // retry
if canRetry(err) {
continue
}
return err
}
@ -390,3 +391,20 @@ func decodeUserJSON[T any](r io.Reader) (T, error) {
}
return zero, err
}
// canRetry reports whether err looks like a transient failure that is worth
// retrying. A nil error is never retryable. When err wraps an *ollama.Error,
// that error's Temporary method decides. Otherwise err is retryable if it
// wraps context.DeadlineExceeded or its message contains one of a few known
// network failure phrases.
func canRetry(err error) bool {
	if err == nil {
		return false
	}
	if oe := (*ollama.Error)(nil); errors.As(err, &oe) {
		return oe.Temporary()
	}

	msg := err.Error()
	retry := errors.Is(err, context.DeadlineExceeded)
	for _, phrase := range []string{
		"unreachable",
		"no route to host",
		"connection reset by peer",
	} {
		// cmp.Or yields its first non-zero operand, which for bools is
		// a logical OR.
		retry = cmp.Or(retry, strings.Contains(msg, phrase))
	}
	return retry
}