diff --git a/cmd/cmd.go b/cmd/cmd.go index 01eb66f9b..57797e931 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -19,6 +19,7 @@ import ( "os" "os/signal" "path/filepath" + "regexp" "runtime" "strconv" "strings" @@ -34,11 +35,13 @@ import ( "golang.org/x/term" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/auth" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/progress" "github.com/ollama/ollama/server" + "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) @@ -513,6 +516,48 @@ func RunHandler(cmd *cobra.Command, args []string) error { return generate(cmd, opts) } +func errFromUnknownKey(unknownKeyErr error) error { + // find SSH public key in the error message + // TODO (brucemacd): the API should return structured errors so that this message parsing isn't needed + sshKeyPattern := `ssh-\w+ [^\s"]+` + re := regexp.MustCompile(sshKeyPattern) + matches := re.FindStringSubmatch(unknownKeyErr.Error()) + + if len(matches) > 0 { + serverPubKey := matches[0] + + localPubKey, err := auth.GetPublicKey() + if err != nil { + return unknownKeyErr + } + + if runtime.GOOS == "linux" && serverPubKey != localPubKey { + // try the ollama service public key + svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub") + if err != nil { + return unknownKeyErr + } + localPubKey = strings.TrimSpace(string(svcPubKey)) + } + + // check if the returned public key matches the local public key, this prevents adding a remote key to the user's account + if serverPubKey != localPubKey { + return unknownKeyErr + } + + var msg strings.Builder + msg.WriteString(unknownKeyErr.Error()) + msg.WriteString("\n\nYour ollama key is:\n") + msg.WriteString(localPubKey) + msg.WriteString("\nAdd your key at:\n") + msg.WriteString("https://ollama.com/settings/keys") + + return errors.New(msg.String()) + } + + return unknownKeyErr +} + func PushHandler(cmd *cobra.Command, args []string) error { client, err := api.ClientFromEnvironment() if err != nil { @@ -561,6 +606,7 @@ func PushHandler(cmd *cobra.Command, args []string) error { request := api.PushRequest{Name: args[0], Insecure: insecure} n := model.ParseName(args[0]) + isOllamaHost := strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") if err := client.Push(cmd.Context(), &request, fn); err != nil { if spinner != nil { spinner.Stop() @@ -568,6 +614,11 @@ func PushHandler(cmd *cobra.Command, args []string) error { if strings.Contains(err.Error(), "access denied") { return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own") } + if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost { + // the user has not added their ollama key to ollama.com + // return an error with a more user-friendly message + return errFromUnknownKey(err) + } return err } diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index 2e6428cfa..3a8e44a7e 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -16,6 +16,7 @@ import ( "github.com/spf13/cobra" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/types/errtypes" ) func TestShowInfo(t *testing.T) { @@ -373,15 +374,13 @@ func TestGetModelfileName(t *testing.T) { func TestPushHandler(t *testing.T) { tests := []struct { - name string modelName string serverResponse map[string]func(w http.ResponseWriter, r *http.Request) expectedError string expectedOutput string }{ { - name: "successful push", - modelName: "test-model", + modelName: "successful-push", serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){ "/api/push": func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { @@ -394,8 +393,8 @@ func TestPushHandler(t *testing.T) { return } - if req.Name != "test-model" { - t.Errorf("expected model name 'test-model', got %s", req.Name) + if req.Name != "successful-push" { + t.Errorf("expected model name 'successful-push', got %s", req.Name) } // Simulate progress updates @@ -414,11 +413,10 @@ func TestPushHandler(t *testing.T) { } }, }, - expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n", + expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/successful-push\n", }, { - name: "unauthorized push", - modelName: "unauthorized-model", + modelName: "unauthorized-push", serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){ "/api/push": func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -433,10 +431,29 @@ func TestPushHandler(t *testing.T) { }, expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own", }, + { + modelName: "unknown-key-err", + serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){ + "/api/push": func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + uerr := errtypes.UnknownOllamaKey{ + Key: "aaa", + } + err := json.NewEncoder(w).Encode(map[string]string{ + "error": uerr.Error(), + }) + if err != nil { + t.Fatal(err) + } + }, + }, + expectedError: "unauthorized: unknown ollama key \"aaa\"", + }, } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + t.Run(tt.modelName, func(t *testing.T) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if handler, ok := tt.serverResponse[r.URL.Path]; ok { handler(w, r) diff --git a/server/images.go b/server/images.go index 29877db33..cda8eb317 100644 --- a/server/images.go +++ b/server/images.go @@ -23,13 +23,16 @@ import ( "strings" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/auth" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/llama" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/template" + "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" + "github.com/ollama/ollama/types/registry" "github.com/ollama/ollama/version" ) @@ -980,8 +983,6 @@ func GetSHA256Digest(r io.Reader) (string, int64) { return fmt.Sprintf("sha256:%x", h.Sum(nil)), n } -var errUnauthorized = errors.New("unauthorized: access denied") - func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) { for range 2 { resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts) @@ -1019,13 +1020,33 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR if err != nil { return nil, fmt.Errorf("%d: %s", resp.StatusCode, err) } + + var re registry.Errs + if err := json.Unmarshal(responseBody, &re); err == nil && len(re.Errors) > 0 { + if re.HasCode(registry.ErrCodeAnonymous) { + // if the error is due to anonymous access return a custom error + // this error is used by the CLI to direct a user to add their key to an account + pubKey, nestedErr := auth.GetPublicKey() + if nestedErr != nil { + slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr)) + return nil, re + } + return nil, errtypes.UnknownOllamaKey{ + Key: pubKey, + } + } + return nil, re + } + + // Fallback to returning the raw response if parsing fails return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody) default: return resp, nil } } - return nil, errUnauthorized + // should never be reached + return nil, fmt.Errorf("failed to make upload request") } // testMakeRequestDialContext specifies the dial function for the http client in diff --git a/server/images_test.go b/server/images_test.go new file mode 100644 index 000000000..9f75cba5b --- /dev/null +++ b/server/images_test.go @@ -0,0 +1,107 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "testing" +) + +func TestMakeRequestWithRetry(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "token": "test-token", + }) + })) + defer authServer.Close() + + tests := []struct { + name string + serverHandler http.HandlerFunc + method string + body string + wantErr error + wantStatus int + }{ + { + name: "successful request", + serverHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + }, + method: http.MethodGet, + wantStatus: http.StatusOK, + }, + { + name: "not found error", + serverHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + }, + method: http.MethodGet, + wantErr: os.ErrNotExist, + }, + { + name: "request with body retry", + serverHandler: func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") == "" { + w.Header().Set("WWW-Authenticate", `Bearer realm="`+authServer.URL+`"`) + w.WriteHeader(http.StatusUnauthorized) + return + } + buf := new(bytes.Buffer) + buf.ReadFrom(r.Body) + if buf.String() != `{"key": "value"}` { + t.Errorf("body not preserved on retry, got %s", buf.String()) + } + w.WriteHeader(http.StatusOK) + }, + method: http.MethodPost, + body: `{"key": "value"}`, + wantStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(tt.serverHandler) + defer server.Close() + + requestURL, _ := url.Parse(server.URL) + var body io.ReadSeeker + if tt.body != "" { + body = strings.NewReader(tt.body) + } + + regOpts := ®istryOptions{ + Insecure: true, + } + + resp, err := makeRequestWithRetry(context.Background(), tt.method, requestURL, nil, body, regOpts) + + if tt.wantErr != nil { + if !errors.Is(err, tt.wantErr) { + t.Errorf("got error %v, want %v", err, tt.wantErr) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resp.StatusCode != tt.wantStatus { + t.Errorf("got status %d, want %d", resp.StatusCode, tt.wantStatus) + } + + resp.Body.Close() + }) + } +} diff --git a/types/errtypes/errtypes.go b/types/errtypes/errtypes.go index 27c3f913e..814b58b03 100644 --- a/types/errtypes/errtypes.go +++ b/types/errtypes/errtypes.go @@ -16,6 +16,6 @@ type UnknownOllamaKey struct { Key string } -func (e *UnknownOllamaKey) Error() string { +func (e UnknownOllamaKey) Error() string { return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key)) } diff --git a/types/registry/error.go b/types/registry/error.go new file mode 100644 index 000000000..69392afb1 --- /dev/null +++ b/types/registry/error.go @@ -0,0 +1,38 @@ +package registry + +import ( + "fmt" + "slices" + "strings" +) + +const ErrCodeAnonymous = "ANONYMOUS_ACCESS_DENIED" + +type Err struct { + Code string `json:"code"` + Message string `json:"message"` +} + +// Errs represents the structure of error responses from the registry +// TODO (brucemacd): this struct should be imported from some shared package that is used between the registry and ollama +type Errs struct { + Errors []Err `json:"errors"` +} + +// Error implements the error interface for RegistryError +func (e Errs) Error() string { + if len(e.Errors) == 0 { + return "unknown registry error" + } + var msgs []string + for _, err := range e.Errors { + msgs = append(msgs, fmt.Sprintf("%s: %s", err.Code, err.Message)) + } + return strings.Join(msgs, "; ") +} + +func (e Errs) HasCode(code string) bool { + return slices.ContainsFunc(e.Errors, func(err Err) bool { + return err.Code == code + }) +}