server/internal/client/ollama: handle some network errors gracefully (#10317)
This commit is contained in:
parent
09bb2e30f6
commit
1d99451ad7
@ -223,8 +223,21 @@ type Registry struct {
|
||||
ChunkingThreshold int64
|
||||
|
||||
// Mask, if set, is the name used to convert non-fully qualified names
|
||||
// to fully qualified names. If empty, [DefaultMask] is used.
|
||||
// to fully qualified names.
|
||||
// If empty, [DefaultMask] is used.
|
||||
Mask string
|
||||
|
||||
// ReadTimeout is the maximum duration for reading the entire request,
|
||||
// including the body.
|
||||
// A zero or negative value means there will be no timeout.
|
||||
ReadTimeout time.Duration
|
||||
}
|
||||
|
||||
func (r *Registry) readTimeout() time.Duration {
|
||||
if r.ReadTimeout > 0 {
|
||||
return r.ReadTimeout
|
||||
}
|
||||
return 1<<63 - 1 // no timeout, max int
|
||||
}
|
||||
|
||||
func (r *Registry) cache() (*blob.DiskCache, error) {
|
||||
@ -248,8 +261,7 @@ func (r *Registry) parseName(name string) (names.Name, error) {
|
||||
|
||||
// DefaultRegistry returns a new Registry configured from the environment. The
|
||||
// key is read from $HOME/.ollama/id_ed25519, MaxStreams is set to the
|
||||
// value of OLLAMA_REGISTRY_MAXSTREAMS, and ChunkingDirectory is set to the
|
||||
// system's temporary directory.
|
||||
// value of OLLAMA_REGISTRY_MAXSTREAMS, and ReadTimeout is set to 30 seconds.
|
||||
//
|
||||
// It returns an error if any configuration in the environment is invalid.
|
||||
func DefaultRegistry() (*Registry, error) {
|
||||
@ -263,6 +275,7 @@ func DefaultRegistry() (*Registry, error) {
|
||||
}
|
||||
|
||||
var rc Registry
|
||||
rc.ReadTimeout = 30 * time.Second
|
||||
rc.UserAgent = UserAgent()
|
||||
rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
|
||||
if err != nil {
|
||||
@ -489,6 +502,12 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
for _, l := range layers {
|
||||
var received atomic.Int64
|
||||
update := func(n int64, err error) {
|
||||
if n == 0 && err == nil {
|
||||
// Clients expect an update with no progress and no error to mean "starting download".
|
||||
// This is not the case here,
|
||||
// so we don't want to send an update in this case.
|
||||
return
|
||||
}
|
||||
completed.Add(n)
|
||||
t.update(l, received.Add(n), err)
|
||||
}
|
||||
@ -562,6 +581,20 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancelCause(ctx)
|
||||
defer cancel(nil)
|
||||
|
||||
timer := time.AfterFunc(r.readTimeout(), func() {
|
||||
cancel(fmt.Errorf("%w: downloading %s %d-%d/%d",
|
||||
context.DeadlineExceeded,
|
||||
cs.Digest.Short(),
|
||||
cs.Chunk.Start,
|
||||
cs.Chunk.End,
|
||||
l.Size,
|
||||
))
|
||||
})
|
||||
defer timer.Stop()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -574,8 +607,11 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
defer res.Body.Close()
|
||||
|
||||
tr := &trackingReader{
|
||||
r: res.Body,
|
||||
update: update,
|
||||
r: res.Body,
|
||||
update: func(n int64, err error) {
|
||||
timer.Reset(r.readTimeout())
|
||||
update(n, err)
|
||||
},
|
||||
}
|
||||
if err := chunked.Put(cs.Chunk, cs.Digest, tr); err != nil {
|
||||
return err
|
||||
@ -930,12 +966,6 @@ func (r *Registry) newRequest(ctx context.Context, method, url string, body io.R
|
||||
// is parsed from the response body and returned. If any other error occurs, it
|
||||
// is returned.
|
||||
func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = fmt.Errorf("request error %s: %w", r.URL, err)
|
||||
}
|
||||
}()
|
||||
|
||||
if r.URL.Scheme == "https+insecure" {
|
||||
// TODO(bmizerany): clone client.Transport, set
|
||||
// InsecureSkipVerify, etc.
|
||||
|
51
server/internal/client/ollama/registry_synctest_test.go
Normal file
51
server/internal/client/ollama/registry_synctest_test.go
Normal file
@ -0,0 +1,51 @@
|
||||
// TODO: go:build goexperiment.synctest
|
||||
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPullDownloadTimeout(t *testing.T) {
|
||||
rc, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
defer t.Log("upstream", r.Method, r.URL.Path)
|
||||
switch {
|
||||
case strings.HasPrefix(r.URL.Path, "/v2/library/smol/manifests/"):
|
||||
io.WriteString(w, `{
|
||||
"layers": [{"digest": "sha256:1111111111111111111111111111111111111111111111111111111111111111", "size": 3}]
|
||||
}`)
|
||||
case strings.HasPrefix(r.URL.Path, "/v2/library/smol/blobs/sha256:1111111111111111111111111111111111111111111111111111111111111111"):
|
||||
// Get headers out to client and then hang on the response
|
||||
w.WriteHeader(200)
|
||||
w.(http.Flusher).Flush()
|
||||
|
||||
// Hang on the response and unblock when the client
|
||||
// gives up
|
||||
<-r.Context().Done()
|
||||
default:
|
||||
t.Fatalf("unexpected request: %s", r.URL.Path)
|
||||
}
|
||||
})
|
||||
rc.ReadTimeout = 100 * time.Millisecond
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- rc.Pull(ctx, "http://example.com/library/smol")
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
want := context.DeadlineExceeded
|
||||
if !errors.Is(err, want) {
|
||||
t.Errorf("err = %v, want %v", err, want)
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Error("timeout waiting for Pull to finish")
|
||||
}
|
||||
}
|
@ -605,8 +605,8 @@ func checkRequest(t *testing.T, req *http.Request, method, path string) {
|
||||
}
|
||||
}
|
||||
|
||||
func newRegistryClient(t *testing.T, h http.HandlerFunc) (*Registry, context.Context) {
|
||||
s := httptest.NewServer(h)
|
||||
func newRegistryClient(t *testing.T, upstream http.HandlerFunc) (*Registry, context.Context) {
|
||||
s := httptest.NewServer(upstream)
|
||||
t.Cleanup(s.Close)
|
||||
cache, err := blob.Open(t.TempDir())
|
||||
if err != nil {
|
||||
|
@ -4,6 +4,7 @@ package registry
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -11,6 +12,7 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -330,9 +332,8 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
return err
|
||||
}
|
||||
err := s.Client.Pull(ctx, p.model())
|
||||
var oe *ollama.Error
|
||||
if errors.As(err, &oe) && oe.Temporary() {
|
||||
continue // retry
|
||||
if canRetry(err) {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
@ -390,3 +391,20 @@ func decodeUserJSON[T any](r io.Reader) (T, error) {
|
||||
}
|
||||
return zero, err
|
||||
}
|
||||
|
||||
func canRetry(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var oe *ollama.Error
|
||||
if errors.As(err, &oe) {
|
||||
return oe.Temporary()
|
||||
}
|
||||
s := err.Error()
|
||||
return cmp.Or(
|
||||
errors.Is(err, context.DeadlineExceeded),
|
||||
strings.Contains(s, "unreachable"),
|
||||
strings.Contains(s, "no route to host"),
|
||||
strings.Contains(s, "connection reset by peer"),
|
||||
)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user