isLocal firstdraft

This commit is contained in:
Josh Yan 2024-07-05 14:18:25 -07:00
parent 413d368a6a
commit 10ea0987e9
5 changed files with 138 additions and 35 deletions

View File

@ -17,14 +17,20 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"os"
"path/filepath"
"runtime" "runtime"
"strings"
"time"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
@ -394,3 +400,28 @@ func (c *Client) IsLocal() bool {
return false return false
} }
func Authorization(ctx context.Context, request *http.Request) (string, error) {
data := []byte(fmt.Sprintf("%s,%s,%d", request.Method, request.URL.RequestURI(), time.Now().Unix()))
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
knownHostsFile, err := os.OpenFile(filepath.Join(home, ".ollama", "known_hosts"), os.O_CREATE|os.O_RDWR|os.O_APPEND, 0600)
if err != nil {
return "", err
}
defer knownHostsFile.Close()
token, err := auth.Sign(ctx, data)
if err != nil {
return "", err
}
// interleave request data into the token
key, sig, _ := strings.Cut(token, ":")
return fmt.Sprintf("%s:%s:%s", key, base64.StdEncoding.EncodeToString(data), sig), nil
}

View File

@ -10,42 +10,37 @@ import (
"log/slog" "log/slog"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
const defaultPrivateKey = "id_ed25519" const defaultPrivateKey = "id_ed25519"
func keyPath() (string, error) { func keyPath() (ssh.Signer, error) {
home, err := os.UserHomeDir() home, err := os.UserHomeDir()
if err != nil { if err != nil {
return "", err return nil, err
}
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
}
func GetPublicKey() (string, error) {
keyPath, err := keyPath()
if err != nil {
return "", err
} }
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
privateKeyFile, err := os.ReadFile(keyPath) privateKeyFile, err := os.ReadFile(keyPath)
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return "", err return nil, err
} }
privateKey, err := ssh.ParsePrivateKey(privateKeyFile) return ssh.ParsePrivateKey(privateKeyFile)
}
func GetPublicKey() (ssh.PublicKey, error) {
privateKey, err := keyPath()
// if privateKey, try public key directly
if err != nil { if err != nil {
return "", err return nil, err
} }
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey()) return privateKey.PublicKey(), nil
return strings.TrimSpace(string(publicKey)), nil
} }
func NewNonce(r io.Reader, length int) (string, error) { func NewNonce(r io.Reader, length int) (string, error) {
@ -58,25 +53,20 @@ func NewNonce(r io.Reader, length int) (string, error) {
} }
func Sign(ctx context.Context, bts []byte) (string, error) { func Sign(ctx context.Context, bts []byte) (string, error) {
keyPath, err := keyPath() privateKey, err := keyPath()
if err != nil {
return "", err
}
privateKeyFile, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return "", err
}
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
if err != nil { if err != nil {
return "", err return "", err
} }
// get the pubkey, but remove the type // get the pubkey, but remove the type
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey()) publicKey, err := GetPublicKey()
parts := bytes.Split(publicKey, []byte(" ")) if err != nil {
return "", err
}
publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
parts := bytes.Split(publicKeyBytes, []byte(" "))
if len(parts) < 2 { if len(parts) < 2 {
return "", fmt.Errorf("malformed public key") return "", fmt.Errorf("malformed public key")
} }

View File

@ -354,6 +354,12 @@ func getLocalPath(ctx context.Context, digest string) (string, error) {
return "", err return "", err
} }
authz, err := api.Authorization(ctx, request)
if err != nil {
return "", err
}
request.Header.Set("Authorization", authz)
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
request.Header.Set("X-Redirect-Create", "1") request.Header.Set("X-Redirect-Create", "1")
@ -499,11 +505,13 @@ func errFromUnknownKey(unknownKeyErr error) error {
if len(matches) > 0 { if len(matches) > 0 {
serverPubKey := matches[0] serverPubKey := matches[0]
localPubKey, err := auth.GetPublicKey() publicKey, err := auth.GetPublicKey()
if err != nil { if err != nil {
return unknownKeyErr return unknownKeyErr
} }
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(publicKey)))
if runtime.GOOS == "linux" && serverPubKey != localPubKey { if runtime.GOOS == "linux" && serverPubKey != localPubKey {
// try the ollama service public key // try the ollama service public key
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub") svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")

View File

@ -32,6 +32,7 @@ import (
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
"golang.org/x/crypto/ssh"
) )
var errCapabilityCompletion = errors.New("completion") var errCapabilityCompletion = errors.New("completion")
@ -1064,11 +1065,12 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
if anonymous { if anonymous {
// no user is associated with the public key, and the request requires non-anonymous access // no user is associated with the public key, and the request requires non-anonymous access
pubKey, nestedErr := auth.GetPublicKey() pubKey, nestedErr := auth.GetPublicKey()
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubKey)))
if nestedErr != nil { if nestedErr != nil {
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr)) slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
return nil, errUnauthorized return nil, errUnauthorized
} }
return nil, &errtypes.UnknownOllamaKey{Key: pubKey} return nil, &errtypes.UnknownOllamaKey{Key: localPubKey}
} }
// user is associated with the public key, but is not authorized to make the request // user is associated with the public key, but is not authorized to make the request
return nil, errUnauthorized return nil, errUnauthorized

View File

@ -4,10 +4,12 @@ import (
"bytes" "bytes"
"cmp" "cmp"
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"log/slog" "log/slog"
"net" "net"
"net/http" "net/http"
@ -22,8 +24,10 @@ import (
"github.com/gin-contrib/cors" "github.com/gin-contrib/cors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"golang.org/x/crypto/ssh"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/gpu" "github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
@ -783,7 +787,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
return return
} }
if c.GetHeader("X-Redirect-Create") == "1" { if c.GetHeader("X-Redirect-Create") == "1" && s.IsLocal(c) {
c.Header("LocalLocation", path) c.Header("LocalLocation", path)
c.Status(http.StatusTemporaryRedirect) c.Status(http.StatusTemporaryRedirect)
return return
@ -803,6 +807,74 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
c.Status(http.StatusCreated) c.Status(http.StatusCreated)
} }
func (s *Server) IsLocal(c *gin.Context) bool {
if authz := c.GetHeader("Authorization"); authz != "" {
parts := strings.Split(authz, ":")
if len(parts) != 3 {
return false
}
clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
if err != nil {
return false
}
// partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
partialRequestData, err := base64.StdEncoding.DecodeString(parts[1])
if err != nil {
return false
}
partialRequestDataParts := strings.Split(string(partialRequestData), ",")
if len(partialRequestDataParts) != 4 {
return false
}
/* timestamp, err := strconv.ParseInt(partialRequestDataParts[2], 10, 0)
if err != nil {
return false
}
t := time.Unix(timestamp, 0)
if time.Since(t) > 5*time.Minute || time.Until(t) > 5*time.Minute {
// token is invalid if timestamp +/- 5 minutes from current time
return false
} */
/* nonce := partialRequestDataParts[3]
if nonceCache.has(nonce) {
return false
}
nonceCache.add(nonce, 5*time.Minute) */
signature, err := base64.StdEncoding.DecodeString(parts[2])
if err != nil {
return false
}
serverPublicKey, err := auth.GetPublicKey()
if err != nil {
log.Fatal(err)
}
_, key, _ := bytes.Cut(bytes.TrimSpace(ssh.MarshalAuthorizedKey(serverPublicKey)), []byte(" "))
requestData := fmt.Sprintf("%s,%s", key, partialRequestData)
if err := clientPublicKey.Verify([]byte(requestData), &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
return false
}
if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
return true
}
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return false
}
return false
}
func isLocalIP(ip netip.Addr) bool { func isLocalIP(ip netip.Addr) bool {
if interfaces, err := net.Interfaces(); err == nil { if interfaces, err := net.Interfaces(); err == nil {
for _, iface := range interfaces { for _, iface := range interfaces {