Compare commits
2 Commits
brucemacd/
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e19c64e047 | ||
|
|
9e190ac4d9 |
@@ -163,24 +163,29 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
scanBuf := make([]byte, 0, maxBufferSize)
|
||||
scanner.Buffer(scanBuf, maxBufferSize)
|
||||
for scanner.Scan() {
|
||||
var errorResponse struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
bts := scanner.Bytes()
|
||||
|
||||
var errorResponse ErrorResponse
|
||||
if err := json.Unmarshal(bts, &errorResponse); err != nil {
|
||||
return fmt.Errorf("unmarshal: %w", err)
|
||||
}
|
||||
|
||||
if errorResponse.Error != "" {
|
||||
return errors.New(errorResponse.Error)
|
||||
switch errorResponse.Code {
|
||||
case ErrCodeUnknownKey:
|
||||
return ErrUnknownOllamaKey{
|
||||
Message: errorResponse.Message,
|
||||
Key: errorResponse.Data["key"].(string),
|
||||
}
|
||||
}
|
||||
if errorResponse.Message != "" {
|
||||
return errors.New(errorResponse.Message)
|
||||
}
|
||||
|
||||
if response.StatusCode >= http.StatusBadRequest {
|
||||
return StatusError{
|
||||
StatusCode: response.StatusCode,
|
||||
Status: response.Status,
|
||||
ErrorMessage: errorResponse.Error,
|
||||
ErrorMessage: errorResponse.Message,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -43,3 +49,117 @@ func TestClientFromEnvironment(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStream(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverResponse []string
|
||||
statusCode int
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "unknown key error",
|
||||
serverResponse: []string{
|
||||
`{"error":"unauthorized access","code":"unknown_key","data":{"key":"test-key"}}`,
|
||||
},
|
||||
statusCode: http.StatusUnauthorized,
|
||||
expectedError: &ErrUnknownOllamaKey{
|
||||
Message: "unauthorized access",
|
||||
Key: "test-key",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "general error message",
|
||||
serverResponse: []string{
|
||||
`{"error":"something went wrong"}`,
|
||||
},
|
||||
statusCode: http.StatusInternalServerError,
|
||||
expectedError: fmt.Errorf("something went wrong"),
|
||||
},
|
||||
{
|
||||
name: "malformed json response",
|
||||
serverResponse: []string{
|
||||
`{invalid-json`,
|
||||
},
|
||||
statusCode: http.StatusOK,
|
||||
expectedError: fmt.Errorf("unmarshal: invalid character 'i' looking for beginning of object key string"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||
w.WriteHeader(tt.statusCode)
|
||||
for _, resp := range tt.serverResponse {
|
||||
fmt.Fprintln(w, resp)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
baseURL, err := url.Parse(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse server URL: %v", err)
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
http: server.Client(),
|
||||
base: baseURL,
|
||||
}
|
||||
|
||||
var responses [][]byte
|
||||
err = client.stream(context.Background(), "POST", "/test", "test", func(bts []byte) error {
|
||||
responses = append(responses, bts)
|
||||
return nil
|
||||
})
|
||||
|
||||
// Error checking
|
||||
if tt.expectedError == nil {
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
t.Fatalf("expected error %v, got nil", tt.expectedError)
|
||||
}
|
||||
|
||||
// Check for specific error types
|
||||
var unknownKeyErr ErrUnknownOllamaKey
|
||||
if errors.As(tt.expectedError, &unknownKeyErr) {
|
||||
var gotErr ErrUnknownOllamaKey
|
||||
if !errors.As(err, &gotErr) {
|
||||
t.Fatalf("expected ErrUnknownOllamaKey, got %T", err)
|
||||
}
|
||||
if unknownKeyErr.Key != gotErr.Key {
|
||||
t.Errorf("expected key %q, got %q", unknownKeyErr.Key, gotErr.Key)
|
||||
}
|
||||
if unknownKeyErr.Message != gotErr.Message {
|
||||
t.Errorf("expected message %q, got %q", unknownKeyErr.Message, gotErr.Message)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var statusErr StatusError
|
||||
if errors.As(tt.expectedError, &statusErr) {
|
||||
var gotErr StatusError
|
||||
if !errors.As(err, &gotErr) {
|
||||
t.Fatalf("expected StatusError, got %T", err)
|
||||
}
|
||||
if statusErr.StatusCode != gotErr.StatusCode {
|
||||
t.Errorf("expected status code %d, got %d", statusErr.StatusCode, gotErr.StatusCode)
|
||||
}
|
||||
if statusErr.ErrorMessage != gotErr.ErrorMessage {
|
||||
t.Errorf("expected error message %q, got %q", statusErr.ErrorMessage, gotErr.ErrorMessage)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// For other errors, compare error strings
|
||||
if err.Error() != tt.expectedError.Error() {
|
||||
t.Errorf("expected error %q, got %q", tt.expectedError, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
74
api/errors.go
Normal file
74
api/errors.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const InvalidModelNameErrMsg = "invalid model name"
|
||||
|
||||
// API error responses
|
||||
// ErrorCode represents a standardized error code identifier
|
||||
type ErrorCode string
|
||||
|
||||
const (
|
||||
ErrCodeUnknownKey ErrorCode = "unknown_key"
|
||||
ErrCodeGeneral ErrorCode = "general" // Generic fallback error code
|
||||
)
|
||||
|
||||
// ErrorResponse implements a structured error interface
|
||||
type ErrorResponse struct {
|
||||
Message string `json:"error"` // Human-readable error message, uses 'error' field name for backwards compatibility
|
||||
Code ErrorCode `json:"code"` // Machine-readable error code for programmatic handling, not response code
|
||||
Data map[string]any `json:"data"` // Additional error specific data, if any
|
||||
}
|
||||
|
||||
func (e ErrorResponse) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
type ErrUnknownOllamaKey struct {
|
||||
Message string
|
||||
Key string
|
||||
}
|
||||
|
||||
func (e ErrUnknownOllamaKey) Error() string {
|
||||
return fmt.Sprintf("unauthorized: unknown ollama key %q", strings.TrimSpace(e.Key))
|
||||
}
|
||||
|
||||
func (e *ErrUnknownOllamaKey) FormatUserMessage(localKeys []string) string {
|
||||
// The user should only be told to add the key if it is the same one that exists locally
|
||||
if slices.Index(localKeys, e.Key) == -1 {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`%s
|
||||
|
||||
Your ollama key is:
|
||||
%s
|
||||
Add your key at:
|
||||
https://ollama.com/settings/keys`, e.Message, e.Key)
|
||||
}
|
||||
|
||||
// StatusError is an error with an HTTP status code and message,
|
||||
// it is parsed on the client-side and not returned from the API
|
||||
type StatusError struct {
|
||||
StatusCode int // e.g. 200
|
||||
Status string // e.g. "200 OK"
|
||||
ErrorMessage string `json:"error"`
|
||||
}
|
||||
|
||||
func (e StatusError) Error() string {
|
||||
switch {
|
||||
case e.Status != "" && e.ErrorMessage != "":
|
||||
return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage)
|
||||
case e.Status != "":
|
||||
return e.Status
|
||||
case e.ErrorMessage != "":
|
||||
return e.ErrorMessage
|
||||
default:
|
||||
// this should not happen
|
||||
return "something went wrong, please see the ollama server logs for details"
|
||||
}
|
||||
}
|
||||
21
api/types.go
21
api/types.go
@@ -12,27 +12,6 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// StatusError is an error with an HTTP status code and message.
|
||||
type StatusError struct {
|
||||
StatusCode int
|
||||
Status string
|
||||
ErrorMessage string `json:"error"`
|
||||
}
|
||||
|
||||
func (e StatusError) Error() string {
|
||||
switch {
|
||||
case e.Status != "" && e.ErrorMessage != "":
|
||||
return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage)
|
||||
case e.Status != "":
|
||||
return e.Status
|
||||
case e.ErrorMessage != "":
|
||||
return e.ErrorMessage
|
||||
default:
|
||||
// this should not happen
|
||||
return "something went wrong, please see the ollama server logs for details"
|
||||
}
|
||||
}
|
||||
|
||||
// ImageData represents the raw binary data of an image file.
|
||||
type ImageData []byte
|
||||
|
||||
|
||||
68
cmd/cmd.go
68
cmd/cmd.go
@@ -19,7 +19,6 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -41,7 +40,6 @@ import (
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
@@ -516,46 +514,22 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return generate(cmd, opts)
|
||||
}
|
||||
|
||||
func errFromUnknownKey(unknownKeyErr error) error {
|
||||
// find SSH public key in the error message
|
||||
// TODO (brucemacd): the API should return structured errors so that this message parsing isn't needed
|
||||
sshKeyPattern := `ssh-\w+ [^\s"]+`
|
||||
re := regexp.MustCompile(sshKeyPattern)
|
||||
matches := re.FindStringSubmatch(unknownKeyErr.Error())
|
||||
|
||||
if len(matches) > 0 {
|
||||
serverPubKey := matches[0]
|
||||
|
||||
localPubKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
||||
// try the ollama service public key
|
||||
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
||||
if err != nil {
|
||||
return unknownKeyErr
|
||||
}
|
||||
localPubKey = strings.TrimSpace(string(svcPubKey))
|
||||
}
|
||||
|
||||
// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
|
||||
if serverPubKey != localPubKey {
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
var msg strings.Builder
|
||||
msg.WriteString(unknownKeyErr.Error())
|
||||
msg.WriteString("\n\nYour ollama key is:\n")
|
||||
msg.WriteString(localPubKey)
|
||||
msg.WriteString("\nAdd your key at:\n")
|
||||
msg.WriteString("https://ollama.com/settings/keys")
|
||||
|
||||
return errors.New(msg.String())
|
||||
func localPubKeys() ([]string, error) {
|
||||
usrKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return unknownKeyErr
|
||||
keys := []string{usrKey}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
// try the ollama service public key if on Linux
|
||||
if svcKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub"); err == nil {
|
||||
keys = append(keys, strings.TrimSpace(string(svcKey)))
|
||||
}
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
@@ -611,15 +585,17 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
var ke api.ErrUnknownOllamaKey
|
||||
if errors.As(err, &ke) && isOllamaHost {
|
||||
// the user has not added their ollama key to ollama.com
|
||||
// return an error with a more user-friendly message
|
||||
locals, _ := localPubKeys()
|
||||
return errors.New(ke.FormatUserMessage(locals))
|
||||
}
|
||||
if strings.Contains(err.Error(), "access denied") {
|
||||
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
|
||||
}
|
||||
if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
|
||||
// the user has not added their ollama key to ollama.com
|
||||
// return an error with a more user-friendly message
|
||||
return errFromUnknownKey(err)
|
||||
}
|
||||
return err
|
||||
return fmt.Errorf("yoyoyo: %w", err)
|
||||
}
|
||||
|
||||
p.Stop()
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
)
|
||||
|
||||
func TestShowInfo(t *testing.T) {
|
||||
@@ -437,7 +436,7 @@ func TestPushHandler(t *testing.T) {
|
||||
"/api/push": func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
uerr := errtypes.UnknownOllamaKey{
|
||||
uerr := api.ErrUnknownOllamaKey{
|
||||
Key: "aaa",
|
||||
}
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/readline"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
)
|
||||
|
||||
type MultilineState int
|
||||
@@ -220,7 +219,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fn := func(resp api.ProgressResponse) error { return nil }
|
||||
err = client.Create(cmd.Context(), req, fn)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), errtypes.InvalidModelNameErrMsg) {
|
||||
if strings.Contains(err.Error(), api.InvalidModelNameErrMsg) {
|
||||
fmt.Printf("error: The model name '%s' is invalid\n", args[1])
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -30,7 +30,6 @@ import (
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/types/registry"
|
||||
"github.com/ollama/ollama/version"
|
||||
@@ -1031,7 +1030,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
||||
return nil, re
|
||||
}
|
||||
return nil, errtypes.UnknownOllamaKey{
|
||||
return nil, api.ErrUnknownOllamaKey{
|
||||
Key: pubKey,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,7 +36,6 @@ import (
|
||||
"github.com/ollama/ollama/runners"
|
||||
"github.com/ollama/ollama/server/imageproc"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
@@ -610,7 +609,7 @@ func (s *Server) PushHandler(c *gin.Context) {
|
||||
defer cancel()
|
||||
|
||||
if err := PushModel(ctx, model, regOpts, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
ch <- newErr(err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -650,7 +649,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
|
||||
name := model.ParseName(cmp.Or(r.Model, r.Name))
|
||||
if !name.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": api.InvalidModelNameErrMsg})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1550,3 +1549,24 @@ func handleScheduleError(c *gin.Context, name string, err error) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
|
||||
// newErr creates a structured API ErrorResponse from an existing error
|
||||
func newErr(err error) api.ErrorResponse {
|
||||
if err == nil {
|
||||
return api.ErrorResponse{}
|
||||
}
|
||||
// Default to just returning the generic error message
|
||||
resp := api.ErrorResponse{
|
||||
Code: api.ErrCodeGeneral,
|
||||
Message: err.Error(),
|
||||
}
|
||||
// Add additional error specific data, if any
|
||||
var keyErr api.ErrUnknownOllamaKey
|
||||
if errors.As(err, &keyErr) {
|
||||
resp.Code = api.ErrCodeUnknownKey
|
||||
resp.Data = map[string]any{
|
||||
"key": keyErr.Key,
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
// Package errtypes contains custom error types
|
||||
package errtypes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
UnknownOllamaKeyErrMsg = "unknown ollama key"
|
||||
InvalidModelNameErrMsg = "invalid model name"
|
||||
)
|
||||
|
||||
// TODO: This should have a structured response from the API
|
||||
type UnknownOllamaKey struct {
|
||||
Key string
|
||||
}
|
||||
|
||||
func (e UnknownOllamaKey) Error() string {
|
||||
return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key))
|
||||
}
|
||||
Reference in New Issue
Block a user