diff --git a/api/types.go b/api/types.go index 637ca2042..fef836bd6 100644 --- a/api/types.go +++ b/api/types.go @@ -361,9 +361,9 @@ type CopyRequest struct { // PullRequest is the request passed to [Client.Pull]. type PullRequest struct { Model string `json:"model"` - Insecure bool `json:"insecure,omitempty"` - Username string `json:"username"` - Password string `json:"password"` + Insecure bool `json:"insecure,omitempty"` // Deprecated: ignored + Username string `json:"username"` // Deprecated: ignored + Password string `json:"password"` // Deprecated: ignored Stream *bool `json:"stream,omitempty"` // Deprecated: set the model name with Model instead diff --git a/go.mod b/go.mod index af0cedc86..c45c9892c 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/nlpodyssey/gopickle v0.3.0 github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c golang.org/x/image v0.22.0 + golang.org/x/tools v0.30.0 gonum.org/v1/gonum v0.15.0 ) diff --git a/go.sum b/go.sum index 013a7db71..0ab97b909 100644 --- a/go.sum +++ b/go.sum @@ -309,6 +309,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= +golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go index 007de5e8a..423a6ad23 100644 --- a/server/internal/client/ollama/registry.go +++ b/server/internal/client/ollama/registry.go @@ -45,9 +45,9 @@ import ( // Errors var ( - // ErrManifestNotFound is returned when a manifest is not found in the + // ErrModelNotFound is returned when a manifest is not found in the // cache or registry. - ErrManifestNotFound = errors.New("manifest not found") + ErrModelNotFound = errors.New("model not found") // ErrManifestInvalid is returned when a manifest found in a local or // remote cache is invalid. @@ -114,7 +114,18 @@ type Error struct { } func (e *Error) Error() string { - return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message) + var b strings.Builder + b.WriteString("registry responded with status ") + b.WriteString(strconv.Itoa(e.Status)) + if e.Code != "" { + b.WriteString(": code ") + b.WriteString(e.Code) + } + if e.Message != "" { + b.WriteString(": ") + b.WriteString(e.Message) + } + return b.String() } func (e *Error) LogValue() slog.Value { @@ -355,7 +366,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error { n.Model(), l.Digest, ) - res, err := r.doOK(ctx, "POST", startURL, nil) + res, err := r.send(ctx, "POST", startURL, nil) if err != nil { return err } @@ -379,7 +390,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error { } req.ContentLength = l.Size - res, err = doOK(r.client(), req) + res, err = sendRequest(r.client(), req) if err == nil { res.Body.Close() } @@ -399,7 +410,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error { n.Model(), n.Tag(), ) - res, err := r.doOK(ctx, "PUT", path, bytes.NewReader(m.Data)) + res, err := r.send(ctx, "PUT", path, bytes.NewReader(m.Data)) if err == nil { res.Body.Close() } @@ -448,10 +459,15 @@ func (r *Registry) Pull(ctx context.Context, name string) error { t := traceFromContext(ctx) - var g errgroup.Group + g, ctx := errgroup.WithContext(ctx) g.SetLimit(r.maxStreams()) - for _, l := range m.Layers { + layers := m.Layers + if m.Config != nil && m.Config.Digest.IsValid() { + layers = append(layers, m.Config) + } + + for _, l := range layers { if exists(l) { t.update(l, l.Size, ErrCached) continue @@ -468,7 +484,9 @@ func (r *Registry) Pull(ctx context.Context, name string) error { if l.Size <= r.maxChunkingThreshold() { g.Go(func() error { - res, err := doOK(r.client(), req) + // TODO(bmizerany): retry/backoff like below in + // the chunking case + res, err := sendRequest(r.client(), req) if err != nil { return err } @@ -494,19 +512,21 @@ func (r *Registry) Pull(ctx context.Context, name string) error { // fire an initial request to get the final URL and // then use that URL for the chunk requests. req.Header.Set("Range", "bytes=0-0") - res, err := doOK(r.client(), req) + res, err := sendRequest(r.client(), req) if err != nil { return err } res.Body.Close() req = res.Request.WithContext(req.Context()) - streamNo := 0 - tws := make([]*bufio.Writer, r.maxStreams()-1) + wp := writerPool{size: r.maxChunkSize()} + for chunk := range chunks.Of(l.Size, r.maxChunkSize()) { + if ctx.Err() != nil { + break + } + ticket := q.Take() - bufIdx := streamNo % len(tws) - streamNo++ g.Go(func() (err error) { defer func() { if err != nil { @@ -520,23 +540,18 @@ func (r *Registry) Pull(ctx context.Context, name string) error { if err != nil { return err } - err := func() error { req := req.Clone(req.Context()) req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk)) - res, err := doOK(r.client(), req) + res, err := sendRequest(r.client(), req) if err != nil { return err } defer res.Body.Close() - tw := tws[bufIdx] - if tw == nil { - tw = bufio.NewWriterSize(nil, int(r.maxChunkSize())) - tws[bufIdx] = tw - } + tw := wp.get() tw.Reset(ticket) - defer tw.Reset(nil) // release ticket + defer wp.put(tw) _, err = io.CopyN(tw, res.Body, chunk.Size()) if err != nil { @@ -595,6 +610,9 @@ type Manifest struct { Name string `json:"-"` // the canonical name of the model Data []byte `json:"-"` // the raw data of the manifest Layers []*Layer `json:"layers"` + + // For legacy reasons, we still have to download the config layer. + Config *Layer `json:"config"` } var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000") @@ -678,7 +696,7 @@ func (r *Registry) ResolveLocal(name string) (*Manifest, error) { data, err := os.ReadFile(c.GetFile(d)) if err != nil { if errors.Is(err, fs.ErrNotExist) { - return nil, fmt.Errorf("%w: %s", ErrManifestNotFound, name) + return nil, fmt.Errorf("%w: %s", ErrModelNotFound, name) } return nil, err } @@ -701,7 +719,7 @@ func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d) } - res, err := r.doOK(ctx, "GET", manifestURL, nil) + res, err := r.send(ctx, "GET", manifestURL, nil) if err != nil { return nil, err } @@ -726,7 +744,7 @@ func (r *Registry) client() *http.Client { } // newRequest constructs a new request, ready to use, with the given method, -// url, and body, presigned with client Key and UserAgent. +// url, and body, pre-signed with client [Key] and [UserAgent]. func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) { req, err := http.NewRequestWithContext(ctx, method, url, body) if err != nil { @@ -745,11 +763,17 @@ func (r *Registry) newRequest(ctx context.Context, method, url string, body io.R return req, nil } -// doOK makes a request with the given client and request, and returns the +// sendRequest makes a request with the given client and request, and returns the // response if the status code is 200. If the status code is not 200, an Error // is parsed from the response body and returned. If any other error occurs, it // is returned. -func doOK(c *http.Client, r *http.Request) (*http.Response, error) { +func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("request error %s: %w", r.URL, err) + } + }() + if r.URL.Scheme == "https+insecure" { // TODO(bmizerany): clone client.Transport, set // InsecureSkipVerify, etc. @@ -792,20 +816,26 @@ func doOK(c *http.Client, r *http.Request) (*http.Response, error) { // Use the raw body if we can't parse it as an error object. re.Message = string(out) } + + // coerce MANIFEST_UNKNOWN to ErrManifestNotFound + if strings.EqualFold(re.Code, "MANIFEST_UNKNOWN") { + return nil, ErrModelNotFound + } + re.Status = res.StatusCode return nil, &re } return res, nil } -// doOK is a convenience method for making a request with newRequest and -// passing it to doOK with r.client(). -func (r *Registry) doOK(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { +// send is a convenience method for making a request with newRequest and +// passing it to send with r.client(). +func (r *Registry) send(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { req, err := r.newRequest(ctx, method, path, body) if err != nil { return nil, err } - return doOK(r.client(), req) + return sendRequest(r.client(), req) } // makeAuthToken creates an Ollama auth token for the given private key. @@ -960,3 +990,28 @@ func splitExtended(s string) (scheme, name, digest string) { } return scheme, s, digest } + +type writerPool struct { + size int64 // set by the caller + + mu sync.Mutex + ws []*bufio.Writer +} + +func (p *writerPool) get() *bufio.Writer { + p.mu.Lock() + defer p.mu.Unlock() + if len(p.ws) == 0 { + return bufio.NewWriterSize(nil, int(p.size)) + } + w := p.ws[len(p.ws)-1] + p.ws = p.ws[:len(p.ws)-1] + return w +} + +func (p *writerPool) put(w *bufio.Writer) { + p.mu.Lock() + defer p.mu.Unlock() + w.Reset(nil) + p.ws = append(p.ws, w) +} diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go index b9b4271b9..8f4e1604f 100644 --- a/server/internal/client/ollama/registry_test.go +++ b/server/internal/client/ollama/registry_test.go @@ -608,7 +608,7 @@ func TestInsecureSkipVerify(t *testing.T) { url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name) _, err := rc.Resolve(t.Context(), url) if err == nil || !strings.Contains(err.Error(), "failed to verify") { - t.Errorf("err = %v; want cert verifiction failure", err) + t.Errorf("err = %v; want cert verification failure", err) } url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name) diff --git a/server/internal/client/ollama/trace.go b/server/internal/client/ollama/trace.go index e300870bb..69435c406 100644 --- a/server/internal/client/ollama/trace.go +++ b/server/internal/client/ollama/trace.go @@ -13,9 +13,13 @@ type Trace struct { // Update is called during [Registry.Push] and [Registry.Pull] to // report the progress of blob uploads and downloads. // - // It is called once at the beginning of the download with a zero n and - // then once per read operation with the number of bytes read so far, - // and an error if any. + // The n argument is the number of bytes transferred so far, and err is + // any error that has occurred. If n == 0, and err is nil, the download + // or upload has just started. If err is [ErrCached], the download or + // upload has been skipped because the blob is already present in the + // local cache or remote registry, respectively. Otherwise, if err is + // non-nil, the download or upload has failed. When l.Size == n, and + // err is nil, the download or upload has completed. // // A function assigned must be safe for concurrent use. The function is // called synchronously and so should not block or take long to run. diff --git a/server/internal/registry/server.go b/server/internal/registry/server.go index 4d44aa8d0..62fefb4c7 100644 --- a/server/internal/registry/server.go +++ b/server/internal/registry/server.go @@ -7,10 +7,14 @@ import ( "cmp" "encoding/json" "errors" + "fmt" "io" "log/slog" "net/http" + "sync" + "time" + "github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/client/ollama" ) @@ -109,6 +113,8 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) { switch r.URL.Path { case "/api/delete": return false, s.handleDelete(rec, r) + case "/api/pull": + return false, s.handlePull(rec, r) default: if s.Fallback != nil { s.Fallback.ServeHTTP(rec, r) @@ -214,6 +220,97 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error { return s.Prune() } +type progressUpdateJSON struct { + Status string `json:"status"` + Digest blob.Digest `json:"digest,omitempty,omitzero"` + Total int64 `json:"total,omitempty,omitzero"` + Completed int64 `json:"completed,omitempty,omitzero"` +} + +func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error { + if r.Method != "POST" { + return errMethodNotAllowed + } + + p, err := decodeUserJSON[*params](r.Body) + if err != nil { + return err + } + + maybeFlush := func() { + fl, _ := w.(http.Flusher) + if fl != nil { + fl.Flush() + } + } + defer maybeFlush() + + var mu sync.Mutex + enc := json.NewEncoder(w) + enc.Encode(progressUpdateJSON{Status: "pulling manifest"}) + + ctx := ollama.WithTrace(r.Context(), &ollama.Trace{ + Update: func(l *ollama.Layer, n int64, err error) { + mu.Lock() + defer mu.Unlock() + + // TODO(bmizerany): coalesce these updates; writing per + // update is expensive + enc.Encode(progressUpdateJSON{ + Digest: l.Digest, + Status: "pulling", + Total: l.Size, + Completed: n, + }) + }, + }) + + done := make(chan error, 1) + go func() { + // TODO(bmizerany): continue to support non-streaming responses + done <- s.Client.Pull(ctx, p.model()) + }() + + func() { + t := time.NewTicker(100 * time.Millisecond) + defer t.Stop() + for { + select { + case <-t.C: + mu.Lock() + maybeFlush() + mu.Unlock() + case err := <-done: + if err != nil { + var status string + if errors.Is(err, ollama.ErrModelNotFound) { + status = fmt.Sprintf("error: model %q not found", p.model()) + enc.Encode(progressUpdateJSON{Status: status}) + } else { + status = fmt.Sprintf("error: %v", err) + enc.Encode(progressUpdateJSON{Status: status}) + } + return + } + + // These final updates are not strictly necessary, because they have + // already happened at this point. Our pull handler code used to do + // these steps after, not during, the pull, and they were slow, so we + // wanted to provide feedback to users what was happening. For now, we + // keep them to not jar users who are used to seeing them. We can phase + // them out with a new and nicer UX later. One without progress bars + // and digests that no one cares about. + enc.Encode(progressUpdateJSON{Status: "verifying layers"}) + enc.Encode(progressUpdateJSON{Status: "writing manifest"}) + enc.Encode(progressUpdateJSON{Status: "success"}) + return + } + } + }() + + return nil +} + func decodeUserJSON[T any](r io.Reader) (T, error) { var v T err := json.NewDecoder(r).Decode(&v) diff --git a/server/internal/registry/server_test.go b/server/internal/registry/server_test.go index e44d88c0f..597e9bd63 100644 --- a/server/internal/registry/server_test.go +++ b/server/internal/registry/server_test.go @@ -1,17 +1,27 @@ package registry import ( + "bytes" + "context" "encoding/json" + "fmt" + "io" + "io/fs" + "net" "net/http" "net/http/httptest" "os" "regexp" "strings" + "sync" "testing" "github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/testutil" + "golang.org/x/tools/txtar" + + _ "embed" ) type panicTransport struct{} @@ -30,7 +40,7 @@ type bytesResetter interface { Reset() } -func newTestServer(t *testing.T) *Local { +func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local { t.Helper() dir := t.TempDir() err := os.CopyFS(dir, os.DirFS("testdata/models")) @@ -41,10 +51,25 @@ func newTestServer(t *testing.T) *Local { if err != nil { t.Fatal(err) } + + client := panicOnRoundTrip + if upstreamRegistry != nil { + s := httptest.NewTLSServer(upstreamRegistry) + t.Cleanup(s.Close) + tr := s.Client().Transport.(*http.Transport).Clone() + tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "tcp", s.Listener.Addr().String()) + } + client = &http.Client{Transport: tr} + } + rc := &ollama.Registry{ Cache: c, - HTTPClient: panicOnRoundTrip, + HTTPClient: client, + Mask: "example.com/library/_:latest", } + l := &Local{ Client: rc, Logger: testutil.Slogger(t), @@ -85,7 +110,7 @@ func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) { func TestServerDelete(t *testing.T) { check := testutil.Checker(t) - s := newTestServer(t) + s := newTestServer(t, nil) _, err := s.Client.ResolveLocal("smol") check(err) @@ -127,8 +152,105 @@ func TestServerDelete(t *testing.T) { } } +//go:embed testdata/registry.txt +var registryTXT []byte + +var registryFS = sync.OnceValue(func() fs.FS { + // Txtar gets hung up on \r\n line endings, so we need to convert them + // to \n when parsing the txtar on Windows. + data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n")) + a := txtar.Parse(data) + fmt.Printf("%q\n", a.Comment) + fsys, err := txtar.FS(a) + if err != nil { + panic(err) + } + return fsys +}) + +func TestServerPull(t *testing.T) { + modelsHandler := http.FileServerFS(registryFS()) + s := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/v2/library/BOOM/manifests/latest": + w.WriteHeader(999) + io.WriteString(w, `{"error": "boom"}`) + case "/v2/library/unknown/manifests/latest": + w.WriteHeader(404) + io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`) + default: + t.Logf("serving file: %s", r.URL.Path) + modelsHandler.ServeHTTP(w, r) + } + }) + + checkResponse := func(got *httptest.ResponseRecorder, wantlines string) { + t.Helper() + + if got.Code != 200 { + t.Fatalf("Code = %d; want 200", got.Code) + } + gotlines := got.Body.String() + t.Logf("got:\n%s", gotlines) + for want := range strings.Lines(wantlines) { + want = strings.TrimSpace(want) + want, unwanted := strings.CutPrefix(want, "!") + want = strings.TrimSpace(want) + if !unwanted && !strings.Contains(gotlines, want) { + t.Fatalf("! missing %q in body", want) + } + if unwanted && strings.Contains(gotlines, want) { + t.Fatalf("! unexpected %q in body", want) + } + } + } + + got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`) + checkResponse(got, ` + {"status":"pulling manifest"} + {"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"} + `) + + got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`) + checkResponse(got, ` + {"status":"pulling manifest"} + {"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5} + {"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3} + {"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5} + {"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3} + {"status":"verifying layers"} + {"status":"writing manifest"} + {"status":"success"} + `) + + got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`) + checkResponse(got, ` + {"status":"pulling manifest"} + {"status":"error: model \"unknown\" not found"} + `) + + got = s.send(t, "DELETE", "/api/pull", `{"model": "smol"}`) + checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed") + + got = s.send(t, "POST", "/api/pull", `!`) + checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value") + + got = s.send(t, "POST", "/api/pull", ``) + checkErrorResponse(t, got, 400, "bad_request", "empty request body") + + got = s.send(t, "POST", "/api/pull", `{"model": "://"}`) + checkResponse(got, ` + {"status":"pulling manifest"} + {"status":"error: invalid or missing name: \"\""} + + !verifying + !writing + !success + `) +} + func TestServerUnknownPath(t *testing.T) { - s := newTestServer(t) + s := newTestServer(t, nil) got := s.send(t, "DELETE", "/api/unknown", `{}`) checkErrorResponse(t, got, 404, "not_found", "not found") } diff --git a/server/internal/registry/testdata/models/manifests/registry.ollama.ai/library/smol/latest b/server/internal/registry/testdata/models/manifests/example.com/library/smol/latest similarity index 100% rename from server/internal/registry/testdata/models/manifests/registry.ollama.ai/library/smol/latest rename to server/internal/registry/testdata/models/manifests/example.com/library/smol/latest diff --git a/server/internal/registry/testdata/registry.txt b/server/internal/registry/testdata/registry.txt new file mode 100644 index 000000000..2fc363fcb --- /dev/null +++ b/server/internal/registry/testdata/registry.txt @@ -0,0 +1,22 @@ +-- v2/library/smol/manifests/latest -- +{ + "schemaVersion": 2, + "mediaType": "application/vnd.docker.distribution.manifest.v2+json", + "config": { + "mediaType": "application/vnd.docker.container.image.v1+json", + "digest": "sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356", + "size": 3 + }, + "layers": [ + { + "mediaType": "application/vnd.ollama.image.model", + "digest": "sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312", + "size": 5 + } + ] +} + +-- v2/library/smol/blobs/sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312 -- +GGUF +-- v2/library/smol/blobs/sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356 -- +{} diff --git a/server/routes.go b/server/routes.go index 73e94dc65..3efa12e43 100644 --- a/server/routes.go +++ b/server/routes.go @@ -42,6 +42,12 @@ import ( "github.com/ollama/ollama/version" ) +func experimentEnabled(name string) bool { + return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name) +} + +var useClient2 = experimentEnabled("client2") + var mode string = gin.DebugMode type Server struct { @@ -1173,6 +1179,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.HEAD("/api/tags", s.ListHandler) r.GET("/api/tags", s.ListHandler) r.POST("/api/show", s.ShowHandler) + r.DELETE("/api/delete", s.DeleteHandler) // Create r.POST("/api/create", s.CreateHandler) @@ -1194,16 +1201,19 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler) r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler) - // wrap old with new - rs := ®istry.Local{ - Client: rc, - Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default() - Fallback: r, + if rc != nil { + // wrap old with new + rs := ®istry.Local{ + Client: rc, + Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default() + Fallback: r, - Prune: PruneLayers, + Prune: PruneLayers, + } + return rs, nil } - return rs, nil + return r, nil } func Serve(ln net.Listener) error { @@ -1258,15 +1268,20 @@ func Serve(ln net.Listener) error { s := &Server{addr: ln.Addr()} - rc, err := ollama.DefaultRegistry() - if err != nil { - return err + var rc *ollama.Registry + if useClient2 { + var err error + rc, err = ollama.DefaultRegistry() + if err != nil { + return err + } } h, err := s.GenerateRoutes(rc) if err != nil { return err } + http.Handle("/", h) ctx, done := context.WithCancel(context.Background())