diff --git a/api/client.go b/api/client.go index c4f213463..0b3424f14 100644 --- a/api/client.go +++ b/api/client.go @@ -17,14 +17,20 @@ import ( "bufio" "bytes" "context" + "encoding/base64" "encoding/json" "fmt" "io" "net" "net/http" "net/url" + "os" + "path/filepath" "runtime" + "strings" + "time" + "github.com/ollama/ollama/auth" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/version" @@ -394,3 +400,28 @@ func (c *Client) IsLocal() bool { return false } + +func Authorization(ctx context.Context, request *http.Request) (string, error) { + + data := []byte(fmt.Sprintf("%s,%s,%d", request.Method, request.URL.RequestURI(), time.Now().Unix())) + + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + + knownHostsFile, err := os.OpenFile(filepath.Join(home, ".ollama", "known_hosts"), os.O_CREATE|os.O_RDWR|os.O_APPEND, 0600) + if err != nil { + return "", err + } + defer knownHostsFile.Close() + + token, err := auth.Sign(ctx, data) + if err != nil { + return "", err + } + + // interleave request data into the token + key, sig, _ := strings.Cut(token, ":") + return fmt.Sprintf("%s:%s:%s", key, base64.StdEncoding.EncodeToString(data), sig), nil +} diff --git a/auth/auth.go b/auth/auth.go index 026b2a2c7..09c8c529a 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -10,42 +10,37 @@ import ( "log/slog" "os" "path/filepath" - "strings" "golang.org/x/crypto/ssh" ) const defaultPrivateKey = "id_ed25519" -func keyPath() (string, error) { +func keyPath() (ssh.Signer, error) { home, err := os.UserHomeDir() if err != nil { - return "", err - } - - return filepath.Join(home, ".ollama", defaultPrivateKey), nil -} - -func GetPublicKey() (string, error) { - keyPath, err := keyPath() - if err != nil { - return "", err + return nil, err } + keyPath := filepath.Join(home, ".ollama", defaultPrivateKey) privateKeyFile, err := os.ReadFile(keyPath) if err != nil { slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) - return "", err + return nil, err } - privateKey, err := ssh.ParsePrivateKey(privateKeyFile) + return ssh.ParsePrivateKey(privateKeyFile) +} + +func GetPublicKey() (ssh.PublicKey, error) { + privateKey, err := keyPath() + // if privateKey, try public key directly + if err != nil { - return "", err + return nil, err } - publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey()) - - return strings.TrimSpace(string(publicKey)), nil + return privateKey.PublicKey(), nil } func NewNonce(r io.Reader, length int) (string, error) { @@ -58,25 +53,20 @@ func NewNonce(r io.Reader, length int) (string, error) { } func Sign(ctx context.Context, bts []byte) (string, error) { - keyPath, err := keyPath() - if err != nil { - return "", err - } - - privateKeyFile, err := os.ReadFile(keyPath) - if err != nil { - slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) - return "", err - } - - privateKey, err := ssh.ParsePrivateKey(privateKeyFile) + privateKey, err := keyPath() if err != nil { return "", err } // get the pubkey, but remove the type - publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey()) - parts := bytes.Split(publicKey, []byte(" ")) + publicKey, err := GetPublicKey() + if err != nil { + return "", err + } + + publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey) + + parts := bytes.Split(publicKeyBytes, []byte(" ")) if len(parts) < 2 { return "", fmt.Errorf("malformed public key") } diff --git a/cmd/cmd.go b/cmd/cmd.go index a2ecfad7e..acc68b3a4 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -354,6 +354,12 @@ func getLocalPath(ctx context.Context, digest string) (string, error) { return "", err } + authz, err := api.Authorization(ctx, request) + if err != nil { + return "", err + } + + request.Header.Set("Authorization", authz) request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) request.Header.Set("X-Redirect-Create", "1") @@ -499,11 +505,13 @@ func errFromUnknownKey(unknownKeyErr error) error { if len(matches) > 0 { serverPubKey := matches[0] - localPubKey, err := auth.GetPublicKey() + publicKey, err := auth.GetPublicKey() if err != nil { return unknownKeyErr } + localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(publicKey))) + if runtime.GOOS == "linux" && serverPubKey != localPubKey { // try the ollama service public key svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub") diff --git a/server/images.go b/server/images.go index 688d5dcae..791a81a15 100644 --- a/server/images.go +++ b/server/images.go @@ -32,6 +32,7 @@ import ( "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" + "golang.org/x/crypto/ssh" ) var errCapabilityCompletion = errors.New("completion") @@ -1064,11 +1065,12 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR if anonymous { // no user is associated with the public key, and the request requires non-anonymous access pubKey, nestedErr := auth.GetPublicKey() + localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubKey))) if nestedErr != nil { slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr)) return nil, errUnauthorized } - return nil, &errtypes.UnknownOllamaKey{Key: pubKey} + return nil, &errtypes.UnknownOllamaKey{Key: localPubKey} } // user is associated with the public key, but is not authorized to make the request return nil, errUnauthorized diff --git a/server/routes.go b/server/routes.go index 588155753..5e2cbca08 100644 --- a/server/routes.go +++ b/server/routes.go @@ -4,10 +4,12 @@ import ( "bytes" "cmp" "context" + "encoding/base64" "encoding/json" "errors" "fmt" "io" + "log" "log/slog" "net" "net/http" @@ -22,8 +24,10 @@ import ( "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" + "golang.org/x/crypto/ssh" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/auth" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/gpu" "github.com/ollama/ollama/llm" @@ -783,7 +787,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) { return } - if c.GetHeader("X-Redirect-Create") == "1" { + if c.GetHeader("X-Redirect-Create") == "1" && s.IsLocal(c) { c.Header("LocalLocation", path) c.Status(http.StatusTemporaryRedirect) return @@ -803,6 +807,74 @@ func (s *Server) CreateBlobHandler(c *gin.Context) { c.Status(http.StatusCreated) } +func (s *Server) IsLocal(c *gin.Context) bool { + if authz := c.GetHeader("Authorization"); authz != "" { + parts := strings.Split(authz, ":") + if len(parts) != 3 { + return false + } + + clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0]))) + if err != nil { + return false + } + + // partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce + partialRequestData, err := base64.StdEncoding.DecodeString(parts[1]) + if err != nil { + return false + } + + partialRequestDataParts := strings.Split(string(partialRequestData), ",") + if len(partialRequestDataParts) != 4 { + return false + } + + /* timestamp, err := strconv.ParseInt(partialRequestDataParts[2], 10, 0) + if err != nil { + return false + } + + t := time.Unix(timestamp, 0) + if time.Since(t) > 5*time.Minute || time.Until(t) > 5*time.Minute { + // token is invalid if timestamp +/- 5 minutes from current time + return false + } */ + + /* nonce := partialRequestDataParts[3] + if nonceCache.has(nonce) { + return false + } + nonceCache.add(nonce, 5*time.Minute) */ + + signature, err := base64.StdEncoding.DecodeString(parts[2]) + if err != nil { + return false + } + + serverPublicKey, err := auth.GetPublicKey() + if err != nil { + log.Fatal(err) + } + + _, key, _ := bytes.Cut(bytes.TrimSpace(ssh.MarshalAuthorizedKey(serverPublicKey)), []byte(" ")) + requestData := fmt.Sprintf("%s,%s", key, partialRequestData) + + if err := clientPublicKey.Verify([]byte(requestData), &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil { + return false + } + + if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) { + return true + } + + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"}) + return false + } + + return false +} + func isLocalIP(ip netip.Addr) bool { if interfaces, err := net.Interfaces(); err == nil { for _, iface := range interfaces {