Compare commits
7 Commits
main
...
brucemacd/
Author | SHA1 | Date | |
---|---|---|---|
![]() |
e19c64e047 | ||
![]() |
9e190ac4d9 | ||
![]() |
ae9165d661 | ||
![]() |
a262b86a5e | ||
![]() |
4d5d3c3276 | ||
![]() |
ea90ee7415 | ||
![]() |
40134c6587 |
@ -163,24 +163,29 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
|||||||
scanBuf := make([]byte, 0, maxBufferSize)
|
scanBuf := make([]byte, 0, maxBufferSize)
|
||||||
scanner.Buffer(scanBuf, maxBufferSize)
|
scanner.Buffer(scanBuf, maxBufferSize)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
var errorResponse struct {
|
|
||||||
Error string `json:"error,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
bts := scanner.Bytes()
|
bts := scanner.Bytes()
|
||||||
|
|
||||||
|
var errorResponse ErrorResponse
|
||||||
if err := json.Unmarshal(bts, &errorResponse); err != nil {
|
if err := json.Unmarshal(bts, &errorResponse); err != nil {
|
||||||
return fmt.Errorf("unmarshal: %w", err)
|
return fmt.Errorf("unmarshal: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if errorResponse.Error != "" {
|
switch errorResponse.Code {
|
||||||
return errors.New(errorResponse.Error)
|
case ErrCodeUnknownKey:
|
||||||
|
return ErrUnknownOllamaKey{
|
||||||
|
Message: errorResponse.Message,
|
||||||
|
Key: errorResponse.Data["key"].(string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if errorResponse.Message != "" {
|
||||||
|
return errors.New(errorResponse.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
if response.StatusCode >= http.StatusBadRequest {
|
if response.StatusCode >= http.StatusBadRequest {
|
||||||
return StatusError{
|
return StatusError{
|
||||||
StatusCode: response.StatusCode,
|
StatusCode: response.StatusCode,
|
||||||
Status: response.Status,
|
Status: response.Status,
|
||||||
ErrorMessage: errorResponse.Error,
|
ErrorMessage: errorResponse.Message,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,6 +1,12 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -43,3 +49,117 @@ func TestClientFromEnvironment(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStream(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
serverResponse []string
|
||||||
|
statusCode int
|
||||||
|
expectedError error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "unknown key error",
|
||||||
|
serverResponse: []string{
|
||||||
|
`{"error":"unauthorized access","code":"unknown_key","data":{"key":"test-key"}}`,
|
||||||
|
},
|
||||||
|
statusCode: http.StatusUnauthorized,
|
||||||
|
expectedError: &ErrUnknownOllamaKey{
|
||||||
|
Message: "unauthorized access",
|
||||||
|
Key: "test-key",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "general error message",
|
||||||
|
serverResponse: []string{
|
||||||
|
`{"error":"something went wrong"}`,
|
||||||
|
},
|
||||||
|
statusCode: http.StatusInternalServerError,
|
||||||
|
expectedError: fmt.Errorf("something went wrong"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "malformed json response",
|
||||||
|
serverResponse: []string{
|
||||||
|
`{invalid-json`,
|
||||||
|
},
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
expectedError: fmt.Errorf("unmarshal: invalid character 'i' looking for beginning of object key string"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||||
|
w.WriteHeader(tt.statusCode)
|
||||||
|
for _, resp := range tt.serverResponse {
|
||||||
|
fmt.Fprintln(w, resp)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
baseURL, err := url.Parse(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse server URL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &Client{
|
||||||
|
http: server.Client(),
|
||||||
|
base: baseURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
var responses [][]byte
|
||||||
|
err = client.stream(context.Background(), "POST", "/test", "test", func(bts []byte) error {
|
||||||
|
responses = append(responses, bts)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Error checking
|
||||||
|
if tt.expectedError == nil {
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error %v, got nil", tt.expectedError)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for specific error types
|
||||||
|
var unknownKeyErr ErrUnknownOllamaKey
|
||||||
|
if errors.As(tt.expectedError, &unknownKeyErr) {
|
||||||
|
var gotErr ErrUnknownOllamaKey
|
||||||
|
if !errors.As(err, &gotErr) {
|
||||||
|
t.Fatalf("expected ErrUnknownOllamaKey, got %T", err)
|
||||||
|
}
|
||||||
|
if unknownKeyErr.Key != gotErr.Key {
|
||||||
|
t.Errorf("expected key %q, got %q", unknownKeyErr.Key, gotErr.Key)
|
||||||
|
}
|
||||||
|
if unknownKeyErr.Message != gotErr.Message {
|
||||||
|
t.Errorf("expected message %q, got %q", unknownKeyErr.Message, gotErr.Message)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var statusErr StatusError
|
||||||
|
if errors.As(tt.expectedError, &statusErr) {
|
||||||
|
var gotErr StatusError
|
||||||
|
if !errors.As(err, &gotErr) {
|
||||||
|
t.Fatalf("expected StatusError, got %T", err)
|
||||||
|
}
|
||||||
|
if statusErr.StatusCode != gotErr.StatusCode {
|
||||||
|
t.Errorf("expected status code %d, got %d", statusErr.StatusCode, gotErr.StatusCode)
|
||||||
|
}
|
||||||
|
if statusErr.ErrorMessage != gotErr.ErrorMessage {
|
||||||
|
t.Errorf("expected error message %q, got %q", statusErr.ErrorMessage, gotErr.ErrorMessage)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// For other errors, compare error strings
|
||||||
|
if err.Error() != tt.expectedError.Error() {
|
||||||
|
t.Errorf("expected error %q, got %q", tt.expectedError, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
74
api/errors.go
Normal file
74
api/errors.go
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const InvalidModelNameErrMsg = "invalid model name"
|
||||||
|
|
||||||
|
// API error responses
|
||||||
|
// ErrorCode represents a standardized error code identifier
|
||||||
|
type ErrorCode string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ErrCodeUnknownKey ErrorCode = "unknown_key"
|
||||||
|
ErrCodeGeneral ErrorCode = "general" // Generic fallback error code
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrorResponse implements a structured error interface
|
||||||
|
type ErrorResponse struct {
|
||||||
|
Message string `json:"error"` // Human-readable error message, uses 'error' field name for backwards compatibility
|
||||||
|
Code ErrorCode `json:"code"` // Machine-readable error code for programmatic handling, not response code
|
||||||
|
Data map[string]any `json:"data"` // Additional error specific data, if any
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ErrorResponse) Error() string {
|
||||||
|
return e.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
type ErrUnknownOllamaKey struct {
|
||||||
|
Message string
|
||||||
|
Key string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ErrUnknownOllamaKey) Error() string {
|
||||||
|
return fmt.Sprintf("unauthorized: unknown ollama key %q", strings.TrimSpace(e.Key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ErrUnknownOllamaKey) FormatUserMessage(localKeys []string) string {
|
||||||
|
// The user should only be told to add the key if it is the same one that exists locally
|
||||||
|
if slices.Index(localKeys, e.Key) == -1 {
|
||||||
|
return e.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf(`%s
|
||||||
|
|
||||||
|
Your ollama key is:
|
||||||
|
%s
|
||||||
|
Add your key at:
|
||||||
|
https://ollama.com/settings/keys`, e.Message, e.Key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusError is an error with an HTTP status code and message,
|
||||||
|
// it is parsed on the client-side and not returned from the API
|
||||||
|
type StatusError struct {
|
||||||
|
StatusCode int // e.g. 200
|
||||||
|
Status string // e.g. "200 OK"
|
||||||
|
ErrorMessage string `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e StatusError) Error() string {
|
||||||
|
switch {
|
||||||
|
case e.Status != "" && e.ErrorMessage != "":
|
||||||
|
return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage)
|
||||||
|
case e.Status != "":
|
||||||
|
return e.Status
|
||||||
|
case e.ErrorMessage != "":
|
||||||
|
return e.ErrorMessage
|
||||||
|
default:
|
||||||
|
// this should not happen
|
||||||
|
return "something went wrong, please see the ollama server logs for details"
|
||||||
|
}
|
||||||
|
}
|
21
api/types.go
21
api/types.go
@ -12,27 +12,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StatusError is an error with an HTTP status code and message.
|
|
||||||
type StatusError struct {
|
|
||||||
StatusCode int
|
|
||||||
Status string
|
|
||||||
ErrorMessage string `json:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e StatusError) Error() string {
|
|
||||||
switch {
|
|
||||||
case e.Status != "" && e.ErrorMessage != "":
|
|
||||||
return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage)
|
|
||||||
case e.Status != "":
|
|
||||||
return e.Status
|
|
||||||
case e.ErrorMessage != "":
|
|
||||||
return e.ErrorMessage
|
|
||||||
default:
|
|
||||||
// this should not happen
|
|
||||||
return "something went wrong, please see the ollama server logs for details"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ImageData represents the raw binary data of an image file.
|
// ImageData represents the raw binary data of an image file.
|
||||||
type ImageData []byte
|
type ImageData []byte
|
||||||
|
|
||||||
|
31
cmd/cmd.go
31
cmd/cmd.go
@ -34,6 +34,7 @@ import (
|
|||||||
"golang.org/x/term"
|
"golang.org/x/term"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/auth"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
@ -513,6 +514,24 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return generate(cmd, opts)
|
return generate(cmd, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func localPubKeys() ([]string, error) {
|
||||||
|
usrKey, err := auth.GetPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := []string{usrKey}
|
||||||
|
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
// try the ollama service public key if on Linux
|
||||||
|
if svcKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub"); err == nil {
|
||||||
|
keys = append(keys, strings.TrimSpace(string(svcKey)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
func PushHandler(cmd *cobra.Command, args []string) error {
|
func PushHandler(cmd *cobra.Command, args []string) error {
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -561,21 +580,29 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
|||||||
request := api.PushRequest{Name: args[0], Insecure: insecure}
|
request := api.PushRequest{Name: args[0], Insecure: insecure}
|
||||||
|
|
||||||
n := model.ParseName(args[0])
|
n := model.ParseName(args[0])
|
||||||
|
isOllamaHost := strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com")
|
||||||
if err := client.Push(cmd.Context(), &request, fn); err != nil {
|
if err := client.Push(cmd.Context(), &request, fn); err != nil {
|
||||||
if spinner != nil {
|
if spinner != nil {
|
||||||
spinner.Stop()
|
spinner.Stop()
|
||||||
}
|
}
|
||||||
|
var ke api.ErrUnknownOllamaKey
|
||||||
|
if errors.As(err, &ke) && isOllamaHost {
|
||||||
|
// the user has not added their ollama key to ollama.com
|
||||||
|
// return an error with a more user-friendly message
|
||||||
|
locals, _ := localPubKeys()
|
||||||
|
return errors.New(ke.FormatUserMessage(locals))
|
||||||
|
}
|
||||||
if strings.Contains(err.Error(), "access denied") {
|
if strings.Contains(err.Error(), "access denied") {
|
||||||
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
|
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
|
||||||
}
|
}
|
||||||
return err
|
return fmt.Errorf("yoyoyo: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
p.Stop()
|
p.Stop()
|
||||||
spinner.Stop()
|
spinner.Stop()
|
||||||
|
|
||||||
destination := n.String()
|
destination := n.String()
|
||||||
if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") {
|
if isOllamaHost {
|
||||||
destination = "https://ollama.com/" + strings.TrimSuffix(n.DisplayShortest(), ":latest")
|
destination = "https://ollama.com/" + strings.TrimSuffix(n.DisplayShortest(), ":latest")
|
||||||
}
|
}
|
||||||
fmt.Printf("\nYou can find your model at:\n\n")
|
fmt.Printf("\nYou can find your model at:\n\n")
|
||||||
|
@ -373,15 +373,13 @@ func TestGetModelfileName(t *testing.T) {
|
|||||||
|
|
||||||
func TestPushHandler(t *testing.T) {
|
func TestPushHandler(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
|
||||||
modelName string
|
modelName string
|
||||||
serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
|
serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
|
||||||
expectedError string
|
expectedError string
|
||||||
expectedOutput string
|
expectedOutput string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "successful push",
|
modelName: "successful-push",
|
||||||
modelName: "test-model",
|
|
||||||
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||||
"/api/push": func(w http.ResponseWriter, r *http.Request) {
|
"/api/push": func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
@ -394,8 +392,8 @@ func TestPushHandler(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Name != "test-model" {
|
if req.Name != "successful-push" {
|
||||||
t.Errorf("expected model name 'test-model', got %s", req.Name)
|
t.Errorf("expected model name 'successful-push', got %s", req.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Simulate progress updates
|
// Simulate progress updates
|
||||||
@ -414,11 +412,10 @@ func TestPushHandler(t *testing.T) {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
|
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/successful-push\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "unauthorized push",
|
modelName: "unauthorized-push",
|
||||||
modelName: "unauthorized-model",
|
|
||||||
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||||
"/api/push": func(w http.ResponseWriter, r *http.Request) {
|
"/api/push": func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
@ -433,10 +430,29 @@ func TestPushHandler(t *testing.T) {
|
|||||||
},
|
},
|
||||||
expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
|
expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
modelName: "unknown-key-err",
|
||||||
|
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||||
|
"/api/push": func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
uerr := api.ErrUnknownOllamaKey{
|
||||||
|
Key: "aaa",
|
||||||
|
}
|
||||||
|
err := json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"error": uerr.Error(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedError: "unauthorized: unknown ollama key \"aaa\"",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.modelName, func(t *testing.T) {
|
||||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if handler, ok := tt.serverResponse[r.URL.Path]; ok {
|
if handler, ok := tt.serverResponse[r.URL.Path]; ok {
|
||||||
handler(w, r)
|
handler(w, r)
|
||||||
|
@ -19,7 +19,6 @@ import (
|
|||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
"github.com/ollama/ollama/readline"
|
"github.com/ollama/ollama/readline"
|
||||||
"github.com/ollama/ollama/types/errtypes"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type MultilineState int
|
type MultilineState int
|
||||||
@ -220,7 +219,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||||||
fn := func(resp api.ProgressResponse) error { return nil }
|
fn := func(resp api.ProgressResponse) error { return nil }
|
||||||
err = client.Create(cmd.Context(), req, fn)
|
err = client.Create(cmd.Context(), req, fn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), errtypes.InvalidModelNameErrMsg) {
|
if strings.Contains(err.Error(), api.InvalidModelNameErrMsg) {
|
||||||
fmt.Printf("error: The model name '%s' is invalid\n", args[1])
|
fmt.Printf("error: The model name '%s' is invalid\n", args[1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -23,6 +23,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/auth"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/llama"
|
"github.com/ollama/ollama/llama"
|
||||||
@ -30,6 +31,7 @@ import (
|
|||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
|
"github.com/ollama/ollama/types/registry"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -980,8 +982,6 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
|
|||||||
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
|
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
|
||||||
}
|
}
|
||||||
|
|
||||||
var errUnauthorized = errors.New("unauthorized: access denied")
|
|
||||||
|
|
||||||
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
|
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
|
||||||
for range 2 {
|
for range 2 {
|
||||||
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
||||||
@ -1019,13 +1019,33 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
|
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var re registry.Errs
|
||||||
|
if err := json.Unmarshal(responseBody, &re); err == nil && len(re.Errors) > 0 {
|
||||||
|
if re.HasCode(registry.ErrCodeAnonymous) {
|
||||||
|
// if the error is due to anonymous access return a custom error
|
||||||
|
// this error is used by the CLI to direct a user to add their key to an account
|
||||||
|
pubKey, nestedErr := auth.GetPublicKey()
|
||||||
|
if nestedErr != nil {
|
||||||
|
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
||||||
|
return nil, re
|
||||||
|
}
|
||||||
|
return nil, api.ErrUnknownOllamaKey{
|
||||||
|
Key: pubKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, re
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to returning the raw response if parsing fails
|
||||||
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
|
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
|
||||||
default:
|
default:
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, errUnauthorized
|
// should never be reached
|
||||||
|
return nil, fmt.Errorf("failed to make upload request")
|
||||||
}
|
}
|
||||||
|
|
||||||
// testMakeRequestDialContext specifies the dial function for the http client in
|
// testMakeRequestDialContext specifies the dial function for the http client in
|
||||||
|
@ -36,7 +36,6 @@ import (
|
|||||||
"github.com/ollama/ollama/runners"
|
"github.com/ollama/ollama/runners"
|
||||||
"github.com/ollama/ollama/server/imageproc"
|
"github.com/ollama/ollama/server/imageproc"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
"github.com/ollama/ollama/types/errtypes"
|
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
@ -610,7 +609,7 @@ func (s *Server) PushHandler(c *gin.Context) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := PushModel(ctx, model, regOpts, fn); err != nil {
|
if err := PushModel(ctx, model, regOpts, fn); err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- newErr(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -650,7 +649,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
|||||||
|
|
||||||
name := model.ParseName(cmp.Or(r.Model, r.Name))
|
name := model.ParseName(cmp.Or(r.Model, r.Name))
|
||||||
if !name.IsValid() {
|
if !name.IsValid() {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": api.InvalidModelNameErrMsg})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1550,3 +1549,24 @@ func handleScheduleError(c *gin.Context, name string, err error) {
|
|||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// newErr creates a structured API ErrorResponse from an existing error
|
||||||
|
func newErr(err error) api.ErrorResponse {
|
||||||
|
if err == nil {
|
||||||
|
return api.ErrorResponse{}
|
||||||
|
}
|
||||||
|
// Default to just returning the generic error message
|
||||||
|
resp := api.ErrorResponse{
|
||||||
|
Code: api.ErrCodeGeneral,
|
||||||
|
Message: err.Error(),
|
||||||
|
}
|
||||||
|
// Add additional error specific data, if any
|
||||||
|
var keyErr api.ErrUnknownOllamaKey
|
||||||
|
if errors.As(err, &keyErr) {
|
||||||
|
resp.Code = api.ErrCodeUnknownKey
|
||||||
|
resp.Data = map[string]any{
|
||||||
|
"key": keyErr.Key,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
@ -1,21 +0,0 @@
|
|||||||
// Package errtypes contains custom error types
|
|
||||||
package errtypes
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
UnknownOllamaKeyErrMsg = "unknown ollama key"
|
|
||||||
InvalidModelNameErrMsg = "invalid model name"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TODO: This should have a structured response from the API
|
|
||||||
type UnknownOllamaKey struct {
|
|
||||||
Key string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *UnknownOllamaKey) Error() string {
|
|
||||||
return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key))
|
|
||||||
}
|
|
37
types/registry/error.go
Normal file
37
types/registry/error.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
package registry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const ErrCodeAnonymous = "ANONYMOUS_ACCESS_DENIED"
|
||||||
|
|
||||||
|
type Err struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Errs represents the structure of error responses from the registry
|
||||||
|
// TODO (brucemacd): this struct should be imported from some shared package that is used between the registry and ollama
|
||||||
|
type Errs struct {
|
||||||
|
Errors []Err `json:"errors"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e Errs) Error() string {
|
||||||
|
if len(e.Errors) == 0 {
|
||||||
|
return "unknown registry error"
|
||||||
|
}
|
||||||
|
var msgs []string
|
||||||
|
for _, err := range e.Errors {
|
||||||
|
msgs = append(msgs, fmt.Sprintf("%s: %s", err.Code, err.Message))
|
||||||
|
}
|
||||||
|
return strings.Join(msgs, "; ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e Errs) HasCode(code string) bool {
|
||||||
|
return slices.ContainsFunc(e.Errors, func(err Err) bool {
|
||||||
|
return err.Code == code
|
||||||
|
})
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user