Compare commits
6 Commits
main
...
brucemacd/
Author | SHA1 | Date | |
---|---|---|---|
![]() |
375a662775 | ||
![]() |
ae9165d661 | ||
![]() |
a262b86a5e | ||
![]() |
4d5d3c3276 | ||
![]() |
ea90ee7415 | ||
![]() |
40134c6587 |
87
cmd/cmd.go
87
cmd/cmd.go
@ -8,6 +8,7 @@ import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -16,9 +17,11 @@ import (
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -29,16 +32,19 @@ import (
|
||||
"github.com/containerd/console"
|
||||
"github.com/mattn/go-runewidth"
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/pkg/browser"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
@ -513,6 +519,76 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return generate(cmd, opts)
|
||||
}
|
||||
|
||||
func generateFingerprint(key string) string {
|
||||
hash := sha256.Sum256([]byte(key))
|
||||
fingerprint := base64.RawURLEncoding.EncodeToString(hash[:6])
|
||||
|
||||
var formatted strings.Builder
|
||||
for i, char := range fingerprint {
|
||||
if i > 0 && i%2 == 0 {
|
||||
formatted.WriteRune('-')
|
||||
}
|
||||
formatted.WriteRune(char)
|
||||
}
|
||||
|
||||
return formatted.String()
|
||||
}
|
||||
|
||||
// tryConnect handles key validation when a connection fails due to an unknown key.
|
||||
// It attempts to open the browser for interactive sessions to let users connect their key,
|
||||
// falling back to command-line instructions for non-interactive sessions.
|
||||
// Returns nil if browser flow succeeds, or an error with connection instructions otherwise.
|
||||
func tryConnect(unknownKeyErr error) error {
|
||||
// find SSH public key in the error message
|
||||
// TODO (brucemacd): the API should return structured errors so that this message parsing isn't needed
|
||||
sshKeyPattern := `ssh-\w+ [^\s"]+`
|
||||
re := regexp.MustCompile(sshKeyPattern)
|
||||
matches := re.FindStringSubmatch(unknownKeyErr.Error())
|
||||
|
||||
if len(matches) > 0 {
|
||||
serverPubKey := matches[0]
|
||||
|
||||
localPubKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
||||
// try the ollama service public key
|
||||
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
||||
if err != nil {
|
||||
return unknownKeyErr
|
||||
}
|
||||
localPubKey = strings.TrimSpace(string(svcPubKey))
|
||||
}
|
||||
|
||||
// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
|
||||
if serverPubKey != localPubKey {
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
if term.IsTerminal(int(os.Stdout.Fd())) {
|
||||
// URL encode the key and device name for the browser URL
|
||||
encodedKey := base64.RawURLEncoding.EncodeToString([]byte(localPubKey))
|
||||
d, _ := os.Hostname()
|
||||
encodedDevice := url.QueryEscape(d)
|
||||
browserURL := fmt.Sprintf("https://ollama.com/connect?host=%s&key=%s", encodedDevice, encodedKey)
|
||||
|
||||
if err := browser.OpenURL(browserURL); err == nil {
|
||||
fmt.Printf("\nOpening browser to add your key...\n")
|
||||
fmt.Printf("\nCheck that this code matches what is shown in your browser:\n")
|
||||
fmt.Printf("\n %s\n", generateFingerprint(localPubKey))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// only return error for non-interactive terminals or if browser opening failed
|
||||
return fmt.Errorf("%s\nAdd your key at:\nhttps://ollama.com/settings/keys", unknownKeyErr.Error())
|
||||
}
|
||||
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
@ -561,13 +637,22 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
request := api.PushRequest{Name: args[0], Insecure: insecure}
|
||||
|
||||
n := model.ParseName(args[0])
|
||||
isOllamaHost := strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com")
|
||||
if err := client.Push(cmd.Context(), &request, fn); err != nil {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
if p != nil {
|
||||
p.Stop()
|
||||
}
|
||||
if strings.Contains(err.Error(), "access denied") {
|
||||
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
|
||||
}
|
||||
if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
|
||||
// the user has not added their ollama key to ollama.com
|
||||
// return an error with a more user-friendly message
|
||||
return tryConnect(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@ -575,7 +660,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
spinner.Stop()
|
||||
|
||||
destination := n.String()
|
||||
if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") {
|
||||
if isOllamaHost {
|
||||
destination = "https://ollama.com/" + strings.TrimSuffix(n.DisplayShortest(), ":latest")
|
||||
}
|
||||
fmt.Printf("\nYou can find your model at:\n\n")
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
)
|
||||
|
||||
func TestShowInfo(t *testing.T) {
|
||||
@ -373,15 +374,13 @@ func TestGetModelfileName(t *testing.T) {
|
||||
|
||||
func TestPushHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelName string
|
||||
serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
|
||||
expectedError string
|
||||
expectedOutput string
|
||||
}{
|
||||
{
|
||||
name: "successful push",
|
||||
modelName: "test-model",
|
||||
modelName: "successful-push",
|
||||
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||
"/api/push": func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
@ -394,8 +393,8 @@ func TestPushHandler(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name != "test-model" {
|
||||
t.Errorf("expected model name 'test-model', got %s", req.Name)
|
||||
if req.Name != "successful-push" {
|
||||
t.Errorf("expected model name 'successful-push', got %s", req.Name)
|
||||
}
|
||||
|
||||
// Simulate progress updates
|
||||
@ -414,11 +413,10 @@ func TestPushHandler(t *testing.T) {
|
||||
}
|
||||
},
|
||||
},
|
||||
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
|
||||
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/successful-push\n",
|
||||
},
|
||||
{
|
||||
name: "unauthorized push",
|
||||
modelName: "unauthorized-model",
|
||||
modelName: "unauthorized-push",
|
||||
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||
"/api/push": func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@ -433,10 +431,29 @@ func TestPushHandler(t *testing.T) {
|
||||
},
|
||||
expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
|
||||
},
|
||||
{
|
||||
modelName: "unknown-key-err",
|
||||
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||
"/api/push": func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
uerr := errtypes.UnknownOllamaKey{
|
||||
Key: "aaa",
|
||||
}
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": uerr.Error(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
},
|
||||
expectedError: "unauthorized: unknown ollama key \"aaa\"",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Run(tt.modelName, func(t *testing.T) {
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if handler, ok := tt.serverResponse[r.URL.Path]; ok {
|
||||
handler(w, r)
|
||||
|
1
go.mod
1
go.mod
@ -22,6 +22,7 @@ require (
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
|
||||
golang.org/x/image v0.22.0
|
||||
)
|
||||
|
||||
|
2
go.sum
2
go.sum
@ -159,6 +159,8 @@ github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2
|
||||
github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
|
||||
github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4=
|
||||
github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
|
@ -23,13 +23,16 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/types/registry"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
@ -980,8 +983,6 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
|
||||
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
|
||||
}
|
||||
|
||||
var errUnauthorized = errors.New("unauthorized: access denied")
|
||||
|
||||
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
|
||||
for range 2 {
|
||||
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
||||
@ -1019,13 +1020,33 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
|
||||
}
|
||||
|
||||
var re registry.Errs
|
||||
if err := json.Unmarshal(responseBody, &re); err == nil && len(re.Errors) > 0 {
|
||||
if re.HasCode(registry.ErrCodeAnonymous) {
|
||||
// if the error is due to anonymous access return a custom error
|
||||
// this error is used by the CLI to direct a user to add their key to an account
|
||||
pubKey, nestedErr := auth.GetPublicKey()
|
||||
if nestedErr != nil {
|
||||
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
||||
return nil, re
|
||||
}
|
||||
return nil, errtypes.UnknownOllamaKey{
|
||||
Key: pubKey,
|
||||
}
|
||||
}
|
||||
return nil, re
|
||||
}
|
||||
|
||||
// Fallback to returning the raw response if parsing fails
|
||||
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
|
||||
default:
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errUnauthorized
|
||||
// should never be reached
|
||||
return nil, fmt.Errorf("failed to make upload request")
|
||||
}
|
||||
|
||||
// testMakeRequestDialContext specifies the dial function for the http client in
|
||||
|
@ -16,6 +16,6 @@ type UnknownOllamaKey struct {
|
||||
Key string
|
||||
}
|
||||
|
||||
func (e *UnknownOllamaKey) Error() string {
|
||||
func (e UnknownOllamaKey) Error() string {
|
||||
return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key))
|
||||
}
|
||||
|
37
types/registry/error.go
Normal file
37
types/registry/error.go
Normal file
@ -0,0 +1,37 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const ErrCodeAnonymous = "ANONYMOUS_ACCESS_DENIED"
|
||||
|
||||
type Err struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Errs represents the structure of error responses from the registry
|
||||
// TODO (brucemacd): this struct should be imported from some shared package that is used between the registry and ollama
|
||||
type Errs struct {
|
||||
Errors []Err `json:"errors"`
|
||||
}
|
||||
|
||||
func (e Errs) Error() string {
|
||||
if len(e.Errors) == 0 {
|
||||
return "unknown registry error"
|
||||
}
|
||||
var msgs []string
|
||||
for _, err := range e.Errors {
|
||||
msgs = append(msgs, fmt.Sprintf("%s: %s", err.Code, err.Message))
|
||||
}
|
||||
return strings.Join(msgs, "; ")
|
||||
}
|
||||
|
||||
func (e Errs) HasCode(code string) bool {
|
||||
return slices.ContainsFunc(e.Errors, func(err Err) bool {
|
||||
return err.Code == code
|
||||
})
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user