From e2252d0fc6ea5c410b1ac4fa0a722beda78b3431 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Wed, 5 Mar 2025 14:48:18 -0800 Subject: [PATCH] server/internal/registry: take over pulls from server package (#9485) This commit replaces the old pull implementation in the server package with the new, faster, more robust pull implementation in the registry package. The new endpoint, and now the remove endpoint too, are behind the feature gate "client2", enabled only by setting the OLLAMA_EXPERIMENT environment variable to include "client2". Currently, the progress indication is wired to perform the same as the previous implementation to avoid making changes to the CLI, and because the status reports happen at the start of the download, and the end of the write to disk, the progress indication is not as smooth as it could be. This is a known issue and will be addressed in a future change. This implementation may be ~0.5-1.0% slower in rare cases, depending on network and disk speed, but is generally MUCH faster and more robust than its predecessor in all other cases. --- api/types.go | 6 +- go.mod | 1 + go.sum | 2 + server/internal/client/ollama/registry.go | 117 +++++++++++----- .../internal/client/ollama/registry_test.go | 2 +- server/internal/client/ollama/trace.go | 10 +- server/internal/registry/server.go | 97 +++++++++++++ server/internal/registry/server_test.go | 130 +++++++++++++++++- .../library/smol/latest | 0 .../internal/registry/testdata/registry.txt | 22 +++ server/routes.go | 35 +++-- 11 files changed, 370 insertions(+), 52 deletions(-) rename server/internal/registry/testdata/models/manifests/{registry.ollama.ai => example.com}/library/smol/latest (100%) create mode 100644 server/internal/registry/testdata/registry.txt diff --git a/api/types.go b/api/types.go index 637ca2042..fef836bd6 100644 --- a/api/types.go +++ b/api/types.go @@ -361,9 +361,9 @@ type CopyRequest struct { // PullRequest is the request passed to [Client.Pull]. 
type PullRequest struct { Model string `json:"model"` - Insecure bool `json:"insecure,omitempty"` - Username string `json:"username"` - Password string `json:"password"` + Insecure bool `json:"insecure,omitempty"` // Deprecated: ignored + Username string `json:"username"` // Deprecated: ignored + Password string `json:"password"` // Deprecated: ignored Stream *bool `json:"stream,omitempty"` // Deprecated: set the model name with Model instead diff --git a/go.mod b/go.mod index af0cedc86..c45c9892c 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/nlpodyssey/gopickle v0.3.0 github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c golang.org/x/image v0.22.0 + golang.org/x/tools v0.30.0 gonum.org/v1/gonum v0.15.0 ) diff --git a/go.sum b/go.sum index 013a7db71..0ab97b909 100644 --- a/go.sum +++ b/go.sum @@ -309,6 +309,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= +golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go index 007de5e8a..423a6ad23 100644 --- a/server/internal/client/ollama/registry.go +++ b/server/internal/client/ollama/registry.go @@ -45,9 +45,9 @@ import ( // Errors var 
( - // ErrManifestNotFound is returned when a manifest is not found in the + // ErrModelNotFound is returned when a manifest is not found in the // cache or registry. - ErrManifestNotFound = errors.New("manifest not found") + ErrModelNotFound = errors.New("model not found") // ErrManifestInvalid is returned when a manifest found in a local or // remote cache is invalid. @@ -114,7 +114,18 @@ type Error struct { } func (e *Error) Error() string { - return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message) + var b strings.Builder + b.WriteString("registry responded with status ") + b.WriteString(strconv.Itoa(e.Status)) + if e.Code != "" { + b.WriteString(": code ") + b.WriteString(e.Code) + } + if e.Message != "" { + b.WriteString(": ") + b.WriteString(e.Message) + } + return b.String() } func (e *Error) LogValue() slog.Value { @@ -355,7 +366,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error { n.Model(), l.Digest, ) - res, err := r.doOK(ctx, "POST", startURL, nil) + res, err := r.send(ctx, "POST", startURL, nil) if err != nil { return err } @@ -379,7 +390,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error { } req.ContentLength = l.Size - res, err = doOK(r.client(), req) + res, err = sendRequest(r.client(), req) if err == nil { res.Body.Close() } @@ -399,7 +410,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error { n.Model(), n.Tag(), ) - res, err := r.doOK(ctx, "PUT", path, bytes.NewReader(m.Data)) + res, err := r.send(ctx, "PUT", path, bytes.NewReader(m.Data)) if err == nil { res.Body.Close() } @@ -448,10 +459,15 @@ func (r *Registry) Pull(ctx context.Context, name string) error { t := traceFromContext(ctx) - var g errgroup.Group + g, ctx := errgroup.WithContext(ctx) g.SetLimit(r.maxStreams()) - for _, l := range m.Layers { + layers := m.Layers + if m.Config != nil && m.Config.Digest.IsValid() { + layers = append(layers, m.Config) + } + 
+ for _, l := range layers { if exists(l) { t.update(l, l.Size, ErrCached) continue @@ -468,7 +484,9 @@ func (r *Registry) Pull(ctx context.Context, name string) error { if l.Size <= r.maxChunkingThreshold() { g.Go(func() error { - res, err := doOK(r.client(), req) + // TODO(bmizerany): retry/backoff like below in + // the chunking case + res, err := sendRequest(r.client(), req) if err != nil { return err } @@ -494,19 +512,21 @@ func (r *Registry) Pull(ctx context.Context, name string) error { // fire an initial request to get the final URL and // then use that URL for the chunk requests. req.Header.Set("Range", "bytes=0-0") - res, err := doOK(r.client(), req) + res, err := sendRequest(r.client(), req) if err != nil { return err } res.Body.Close() req = res.Request.WithContext(req.Context()) - streamNo := 0 - tws := make([]*bufio.Writer, r.maxStreams()-1) + wp := writerPool{size: r.maxChunkSize()} + for chunk := range chunks.Of(l.Size, r.maxChunkSize()) { + if ctx.Err() != nil { + break + } + ticket := q.Take() - bufIdx := streamNo % len(tws) - streamNo++ g.Go(func() (err error) { defer func() { if err != nil { @@ -520,23 +540,18 @@ func (r *Registry) Pull(ctx context.Context, name string) error { if err != nil { return err } - err := func() error { req := req.Clone(req.Context()) req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk)) - res, err := doOK(r.client(), req) + res, err := sendRequest(r.client(), req) if err != nil { return err } defer res.Body.Close() - tw := tws[bufIdx] - if tw == nil { - tw = bufio.NewWriterSize(nil, int(r.maxChunkSize())) - tws[bufIdx] = tw - } + tw := wp.get() tw.Reset(ticket) - defer tw.Reset(nil) // release ticket + defer wp.put(tw) _, err = io.CopyN(tw, res.Body, chunk.Size()) if err != nil { @@ -595,6 +610,9 @@ type Manifest struct { Name string `json:"-"` // the canonical name of the model Data []byte `json:"-"` // the raw data of the manifest Layers []*Layer `json:"layers"` + + // For legacy reasons, we still have to download 
the config layer. + Config *Layer `json:"config"` } var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000") @@ -678,7 +696,7 @@ func (r *Registry) ResolveLocal(name string) (*Manifest, error) { data, err := os.ReadFile(c.GetFile(d)) if err != nil { if errors.Is(err, fs.ErrNotExist) { - return nil, fmt.Errorf("%w: %s", ErrManifestNotFound, name) + return nil, fmt.Errorf("%w: %s", ErrModelNotFound, name) } return nil, err } @@ -701,7 +719,7 @@ func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d) } - res, err := r.doOK(ctx, "GET", manifestURL, nil) + res, err := r.send(ctx, "GET", manifestURL, nil) if err != nil { return nil, err } @@ -726,7 +744,7 @@ func (r *Registry) client() *http.Client { } // newRequest constructs a new request, ready to use, with the given method, -// url, and body, presigned with client Key and UserAgent. +// url, and body, pre-signed with client [Key] and [UserAgent]. func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) { req, err := http.NewRequestWithContext(ctx, method, url, body) if err != nil { @@ -745,11 +763,17 @@ func (r *Registry) newRequest(ctx context.Context, method, url string, body io.R return req, nil } -// doOK makes a request with the given client and request, and returns the +// sendRequest makes a request with the given client and request, and returns the // response if the status code is 200. If the status code is not 200, an Error // is parsed from the response body and returned. If any other error occurs, it // is returned. 
-func doOK(c *http.Client, r *http.Request) (*http.Response, error) { +func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("request error %s: %w", r.URL, err) + } + }() + if r.URL.Scheme == "https+insecure" { // TODO(bmizerany): clone client.Transport, set // InsecureSkipVerify, etc. @@ -792,20 +816,26 @@ func doOK(c *http.Client, r *http.Request) (*http.Response, error) { // Use the raw body if we can't parse it as an error object. re.Message = string(out) } + + // coerce MANIFEST_UNKNOWN to ErrModelNotFound + if strings.EqualFold(re.Code, "MANIFEST_UNKNOWN") { + return nil, ErrModelNotFound + } + re.Status = res.StatusCode return nil, &re } return res, nil } -// doOK is a convenience method for making a request with newRequest and -// passing it to doOK with r.client(). -func (r *Registry) doOK(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { +// send is a convenience method for making a request with newRequest and +// passing it to sendRequest with r.client(). +func (r *Registry) send(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { req, err := r.newRequest(ctx, method, path, body) if err != nil { return nil, err } - return doOK(r.client(), req) + return sendRequest(r.client(), req) } // makeAuthToken creates an Ollama auth token for the given private key. 
@@ -960,3 +990,28 @@ func splitExtended(s string) (scheme, name, digest string) { } return scheme, s, digest } + +type writerPool struct { + size int64 // set by the caller + + mu sync.Mutex + ws []*bufio.Writer +} + +func (p *writerPool) get() *bufio.Writer { + p.mu.Lock() + defer p.mu.Unlock() + if len(p.ws) == 0 { + return bufio.NewWriterSize(nil, int(p.size)) + } + w := p.ws[len(p.ws)-1] + p.ws = p.ws[:len(p.ws)-1] + return w +} + +func (p *writerPool) put(w *bufio.Writer) { + p.mu.Lock() + defer p.mu.Unlock() + w.Reset(nil) + p.ws = append(p.ws, w) +} diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go index b9b4271b9..8f4e1604f 100644 --- a/server/internal/client/ollama/registry_test.go +++ b/server/internal/client/ollama/registry_test.go @@ -608,7 +608,7 @@ func TestInsecureSkipVerify(t *testing.T) { url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name) _, err := rc.Resolve(t.Context(), url) if err == nil || !strings.Contains(err.Error(), "failed to verify") { - t.Errorf("err = %v; want cert verifiction failure", err) + t.Errorf("err = %v; want cert verification failure", err) } url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name) diff --git a/server/internal/client/ollama/trace.go b/server/internal/client/ollama/trace.go index e300870bb..69435c406 100644 --- a/server/internal/client/ollama/trace.go +++ b/server/internal/client/ollama/trace.go @@ -13,9 +13,13 @@ type Trace struct { // Update is called during [Registry.Push] and [Registry.Pull] to // report the progress of blob uploads and downloads. // - // It is called once at the beginning of the download with a zero n and - // then once per read operation with the number of bytes read so far, - // and an error if any. + // The n argument is the number of bytes transferred so far, and err is + // any error that has occurred. If n == 0, and err is nil, the download + // or upload has just started. 
If err is [ErrCached], the download or + // upload has been skipped because the blob is already present in the + // local cache or remote registry, respectively. Otherwise, if err is + // non-nil, the download or upload has failed. When l.Size == n, and + // err is nil, the download or upload has completed. // // A function assigned must be safe for concurrent use. The function is // called synchronously and so should not block or take long to run. diff --git a/server/internal/registry/server.go b/server/internal/registry/server.go index 4d44aa8d0..62fefb4c7 100644 --- a/server/internal/registry/server.go +++ b/server/internal/registry/server.go @@ -7,10 +7,14 @@ import ( "cmp" "encoding/json" "errors" + "fmt" "io" "log/slog" "net/http" + "sync" + "time" + "github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/client/ollama" ) @@ -109,6 +113,8 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) { switch r.URL.Path { case "/api/delete": return false, s.handleDelete(rec, r) + case "/api/pull": + return false, s.handlePull(rec, r) default: if s.Fallback != nil { s.Fallback.ServeHTTP(rec, r) @@ -214,6 +220,97 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error { return s.Prune() } +type progressUpdateJSON struct { + Status string `json:"status"` + Digest blob.Digest `json:"digest,omitempty,omitzero"` + Total int64 `json:"total,omitempty,omitzero"` + Completed int64 `json:"completed,omitempty,omitzero"` +} + +func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error { + if r.Method != "POST" { + return errMethodNotAllowed + } + + p, err := decodeUserJSON[*params](r.Body) + if err != nil { + return err + } + + maybeFlush := func() { + fl, _ := w.(http.Flusher) + if fl != nil { + fl.Flush() + } + } + defer maybeFlush() + + var mu sync.Mutex + enc := json.NewEncoder(w) + enc.Encode(progressUpdateJSON{Status: "pulling manifest"}) + + ctx := ollama.WithTrace(r.Context(), 
&ollama.Trace{ + Update: func(l *ollama.Layer, n int64, err error) { + mu.Lock() + defer mu.Unlock() + + // TODO(bmizerany): coalesce these updates; writing per + // update is expensive + enc.Encode(progressUpdateJSON{ + Digest: l.Digest, + Status: "pulling", + Total: l.Size, + Completed: n, + }) + }, + }) + + done := make(chan error, 1) + go func() { + // TODO(bmizerany): continue to support non-streaming responses + done <- s.Client.Pull(ctx, p.model()) + }() + + func() { + t := time.NewTicker(100 * time.Millisecond) + defer t.Stop() + for { + select { + case <-t.C: + mu.Lock() + maybeFlush() + mu.Unlock() + case err := <-done: + if err != nil { + var status string + if errors.Is(err, ollama.ErrModelNotFound) { + status = fmt.Sprintf("error: model %q not found", p.model()) + enc.Encode(progressUpdateJSON{Status: status}) + } else { + status = fmt.Sprintf("error: %v", err) + enc.Encode(progressUpdateJSON{Status: status}) + } + return + } + + // These final updates are not strictly necessary, because they have + // already happened at this point. Our pull handler code used to do + // these steps after, not during, the pull, and they were slow, so we + // wanted to provide feedback to users what was happening. For now, we + // keep them to not jar users who are used to seeing them. We can phase + // them out with a new and nicer UX later. One without progress bars + // and digests that no one cares about. 
+ enc.Encode(progressUpdateJSON{Status: "verifying layers"}) + enc.Encode(progressUpdateJSON{Status: "writing manifest"}) + enc.Encode(progressUpdateJSON{Status: "success"}) + return + } + } + }() + + return nil +} + func decodeUserJSON[T any](r io.Reader) (T, error) { var v T err := json.NewDecoder(r).Decode(&v) diff --git a/server/internal/registry/server_test.go b/server/internal/registry/server_test.go index e44d88c0f..597e9bd63 100644 --- a/server/internal/registry/server_test.go +++ b/server/internal/registry/server_test.go @@ -1,17 +1,27 @@ package registry import ( + "bytes" + "context" "encoding/json" + "fmt" + "io" + "io/fs" + "net" "net/http" "net/http/httptest" "os" "regexp" "strings" + "sync" "testing" "github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/testutil" + "golang.org/x/tools/txtar" + + _ "embed" ) type panicTransport struct{} @@ -30,7 +40,7 @@ type bytesResetter interface { Reset() } -func newTestServer(t *testing.T) *Local { +func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local { t.Helper() dir := t.TempDir() err := os.CopyFS(dir, os.DirFS("testdata/models")) @@ -41,10 +51,25 @@ func newTestServer(t *testing.T) *Local { if err != nil { t.Fatal(err) } + + client := panicOnRoundTrip + if upstreamRegistry != nil { + s := httptest.NewTLSServer(upstreamRegistry) + t.Cleanup(s.Close) + tr := s.Client().Transport.(*http.Transport).Clone() + tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "tcp", s.Listener.Addr().String()) + } + client = &http.Client{Transport: tr} + } + rc := &ollama.Registry{ Cache: c, - HTTPClient: panicOnRoundTrip, + HTTPClient: client, + Mask: "example.com/library/_:latest", } + l := &Local{ Client: rc, Logger: testutil.Slogger(t), @@ -85,7 +110,7 @@ func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) { func 
TestServerDelete(t *testing.T) { check := testutil.Checker(t) - s := newTestServer(t) + s := newTestServer(t, nil) _, err := s.Client.ResolveLocal("smol") check(err) @@ -127,8 +152,105 @@ func TestServerDelete(t *testing.T) { } } +//go:embed testdata/registry.txt +var registryTXT []byte + +var registryFS = sync.OnceValue(func() fs.FS { + // Txtar gets hung up on \r\n line endings, so we need to convert them + // to \n when parsing the txtar on Windows. + data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n")) + a := txtar.Parse(data) + fmt.Printf("%q\n", a.Comment) + fsys, err := txtar.FS(a) + if err != nil { + panic(err) + } + return fsys +}) + +func TestServerPull(t *testing.T) { + modelsHandler := http.FileServerFS(registryFS()) + s := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/v2/library/BOOM/manifests/latest": + w.WriteHeader(999) + io.WriteString(w, `{"error": "boom"}`) + case "/v2/library/unknown/manifests/latest": + w.WriteHeader(404) + io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`) + default: + t.Logf("serving file: %s", r.URL.Path) + modelsHandler.ServeHTTP(w, r) + } + }) + + checkResponse := func(got *httptest.ResponseRecorder, wantlines string) { + t.Helper() + + if got.Code != 200 { + t.Fatalf("Code = %d; want 200", got.Code) + } + gotlines := got.Body.String() + t.Logf("got:\n%s", gotlines) + for want := range strings.Lines(wantlines) { + want = strings.TrimSpace(want) + want, unwanted := strings.CutPrefix(want, "!") + want = strings.TrimSpace(want) + if !unwanted && !strings.Contains(gotlines, want) { + t.Fatalf("! missing %q in body", want) + } + if unwanted && strings.Contains(gotlines, want) { + t.Fatalf("! 
unexpected %q in body", want) + } + } + } + + got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`) + checkResponse(got, ` + {"status":"pulling manifest"} + {"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"} + `) + + got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`) + checkResponse(got, ` + {"status":"pulling manifest"} + {"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5} + {"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3} + {"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5} + {"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3} + {"status":"verifying layers"} + {"status":"writing manifest"} + {"status":"success"} + `) + + got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`) + checkResponse(got, ` + {"status":"pulling manifest"} + {"status":"error: model \"unknown\" not found"} + `) + + got = s.send(t, "DELETE", "/api/pull", `{"model": "smol"}`) + checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed") + + got = s.send(t, "POST", "/api/pull", `!`) + checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' 
looking for beginning of value") + + got = s.send(t, "POST", "/api/pull", ``) + checkErrorResponse(t, got, 400, "bad_request", "empty request body") + + got = s.send(t, "POST", "/api/pull", `{"model": "://"}`) + checkResponse(got, ` + {"status":"pulling manifest"} + {"status":"error: invalid or missing name: \"\""} + + !verifying + !writing + !success + `) +} + func TestServerUnknownPath(t *testing.T) { - s := newTestServer(t) + s := newTestServer(t, nil) got := s.send(t, "DELETE", "/api/unknown", `{}`) checkErrorResponse(t, got, 404, "not_found", "not found") } diff --git a/server/internal/registry/testdata/models/manifests/registry.ollama.ai/library/smol/latest b/server/internal/registry/testdata/models/manifests/example.com/library/smol/latest similarity index 100% rename from server/internal/registry/testdata/models/manifests/registry.ollama.ai/library/smol/latest rename to server/internal/registry/testdata/models/manifests/example.com/library/smol/latest diff --git a/server/internal/registry/testdata/registry.txt b/server/internal/registry/testdata/registry.txt new file mode 100644 index 000000000..2fc363fcb --- /dev/null +++ b/server/internal/registry/testdata/registry.txt @@ -0,0 +1,22 @@ +-- v2/library/smol/manifests/latest -- +{ + "schemaVersion": 2, + "mediaType": "application/vnd.docker.distribution.manifest.v2+json", + "config": { + "mediaType": "application/vnd.docker.container.image.v1+json", + "digest": "sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356", + "size": 3 + }, + "layers": [ + { + "mediaType": "application/vnd.ollama.image.model", + "digest": "sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312", + "size": 5 + } + ] +} + +-- v2/library/smol/blobs/sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312 -- +GGUF +-- v2/library/smol/blobs/sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356 -- +{} diff --git a/server/routes.go b/server/routes.go index 
73e94dc65..3efa12e43 100644 --- a/server/routes.go +++ b/server/routes.go @@ -42,6 +42,12 @@ import ( "github.com/ollama/ollama/version" ) +func experimentEnabled(name string) bool { + return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name) +} + +var useClient2 = experimentEnabled("client2") + var mode string = gin.DebugMode type Server struct { @@ -1173,6 +1179,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.HEAD("/api/tags", s.ListHandler) r.GET("/api/tags", s.ListHandler) r.POST("/api/show", s.ShowHandler) + r.DELETE("/api/delete", s.DeleteHandler) // Create r.POST("/api/create", s.CreateHandler) @@ -1194,16 +1201,19 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler) r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler) - // wrap old with new - rs := &registry.Local{ - Client: rc, - Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default() - Fallback: r, + if rc != nil { + // wrap old with new + rs := &registry.Local{ + Client: rc, + Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default() + Fallback: r, - Prune: PruneLayers, + Prune: PruneLayers, + } + return rs, nil } - return rs, nil + return r, nil } func Serve(ln net.Listener) error { @@ -1258,15 +1268,20 @@ func Serve(ln net.Listener) error { s := &Server{addr: ln.Addr()} - rc, err := ollama.DefaultRegistry() - if err != nil { - return err + var rc *ollama.Registry + if useClient2 { + var err error + rc, err = ollama.DefaultRegistry() + if err != nil { + return err + } } h, err := s.GenerateRoutes(rc) if err != nil { return err } + http.Handle("/", h) ctx, done := context.WithCancel(context.Background())