isLocal firstdraft
This commit is contained in:
parent
413d368a6a
commit
10ea0987e9
@ -17,14 +17,20 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/auth"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
@ -394,3 +400,28 @@ func (c *Client) IsLocal() bool {
|
|||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Authorization(ctx context.Context, request *http.Request) (string, error) {
|
||||||
|
|
||||||
|
data := []byte(fmt.Sprintf("%s,%s,%d", request.Method, request.URL.RequestURI(), time.Now().Unix()))
|
||||||
|
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
knownHostsFile, err := os.OpenFile(filepath.Join(home, ".ollama", "known_hosts"), os.O_CREATE|os.O_RDWR|os.O_APPEND, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer knownHostsFile.Close()
|
||||||
|
|
||||||
|
token, err := auth.Sign(ctx, data)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// interleave request data into the token
|
||||||
|
key, sig, _ := strings.Cut(token, ":")
|
||||||
|
return fmt.Sprintf("%s:%s:%s", key, base64.StdEncoding.EncodeToString(data), sig), nil
|
||||||
|
}
|
||||||
|
54
auth/auth.go
54
auth/auth.go
@ -10,42 +10,37 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultPrivateKey = "id_ed25519"
|
const defaultPrivateKey = "id_ed25519"
|
||||||
|
|
||||||
func keyPath() (string, error) {
|
func keyPath() (ssh.Signer, error) {
|
||||||
home, err := os.UserHomeDir()
|
home, err := os.UserHomeDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
|
||||||
|
|
||||||
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetPublicKey() (string, error) {
|
|
||||||
keyPath, err := keyPath()
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||||
privateKeyFile, err := os.ReadFile(keyPath)
|
privateKeyFile, err := os.ReadFile(keyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
return ssh.ParsePrivateKey(privateKeyFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetPublicKey() (ssh.PublicKey, error) {
|
||||||
|
privateKey, err := keyPath()
|
||||||
|
// if privateKey, try public key directly
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
return privateKey.PublicKey(), nil
|
||||||
|
|
||||||
return strings.TrimSpace(string(publicKey)), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNonce(r io.Reader, length int) (string, error) {
|
func NewNonce(r io.Reader, length int) (string, error) {
|
||||||
@ -58,25 +53,20 @@ func NewNonce(r io.Reader, length int) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Sign(ctx context.Context, bts []byte) (string, error) {
|
func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||||
keyPath, err := keyPath()
|
privateKey, err := keyPath()
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
privateKeyFile, err := os.ReadFile(keyPath)
|
|
||||||
if err != nil {
|
|
||||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
// get the pubkey, but remove the type
|
// get the pubkey, but remove the type
|
||||||
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
publicKey, err := GetPublicKey()
|
||||||
parts := bytes.Split(publicKey, []byte(" "))
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
|
||||||
|
|
||||||
|
parts := bytes.Split(publicKeyBytes, []byte(" "))
|
||||||
if len(parts) < 2 {
|
if len(parts) < 2 {
|
||||||
return "", fmt.Errorf("malformed public key")
|
return "", fmt.Errorf("malformed public key")
|
||||||
}
|
}
|
||||||
|
10
cmd/cmd.go
10
cmd/cmd.go
@ -354,6 +354,12 @@ func getLocalPath(ctx context.Context, digest string) (string, error) {
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
authz, err := api.Authorization(ctx, request)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
request.Header.Set("Authorization", authz)
|
||||||
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||||
request.Header.Set("X-Redirect-Create", "1")
|
request.Header.Set("X-Redirect-Create", "1")
|
||||||
|
|
||||||
@ -499,11 +505,13 @@ func errFromUnknownKey(unknownKeyErr error) error {
|
|||||||
if len(matches) > 0 {
|
if len(matches) > 0 {
|
||||||
serverPubKey := matches[0]
|
serverPubKey := matches[0]
|
||||||
|
|
||||||
localPubKey, err := auth.GetPublicKey()
|
publicKey, err := auth.GetPublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return unknownKeyErr
|
return unknownKeyErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(publicKey)))
|
||||||
|
|
||||||
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
||||||
// try the ollama service public key
|
// try the ollama service public key
|
||||||
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
||||||
|
@ -32,6 +32,7 @@ import (
|
|||||||
"github.com/ollama/ollama/types/errtypes"
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errCapabilityCompletion = errors.New("completion")
|
var errCapabilityCompletion = errors.New("completion")
|
||||||
@ -1064,11 +1065,12 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
|||||||
if anonymous {
|
if anonymous {
|
||||||
// no user is associated with the public key, and the request requires non-anonymous access
|
// no user is associated with the public key, and the request requires non-anonymous access
|
||||||
pubKey, nestedErr := auth.GetPublicKey()
|
pubKey, nestedErr := auth.GetPublicKey()
|
||||||
|
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubKey)))
|
||||||
if nestedErr != nil {
|
if nestedErr != nil {
|
||||||
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
||||||
return nil, errUnauthorized
|
return nil, errUnauthorized
|
||||||
}
|
}
|
||||||
return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
|
return nil, &errtypes.UnknownOllamaKey{Key: localPubKey}
|
||||||
}
|
}
|
||||||
// user is associated with the public key, but is not authorized to make the request
|
// user is associated with the public key, but is not authorized to make the request
|
||||||
return nil, errUnauthorized
|
return nil, errUnauthorized
|
||||||
|
@ -4,10 +4,12 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"cmp"
|
"cmp"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -22,8 +24,10 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-contrib/cors"
|
"github.com/gin-contrib/cors"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/auth"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/gpu"
|
"github.com/ollama/ollama/gpu"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
@ -783,7 +787,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.GetHeader("X-Redirect-Create") == "1" {
|
if c.GetHeader("X-Redirect-Create") == "1" && s.IsLocal(c) {
|
||||||
c.Header("LocalLocation", path)
|
c.Header("LocalLocation", path)
|
||||||
c.Status(http.StatusTemporaryRedirect)
|
c.Status(http.StatusTemporaryRedirect)
|
||||||
return
|
return
|
||||||
@ -803,6 +807,74 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
|||||||
c.Status(http.StatusCreated)
|
c.Status(http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) IsLocal(c *gin.Context) bool {
|
||||||
|
if authz := c.GetHeader("Authorization"); authz != "" {
|
||||||
|
parts := strings.Split(authz, ":")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
|
||||||
|
partialRequestData, err := base64.StdEncoding.DecodeString(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
partialRequestDataParts := strings.Split(string(partialRequestData), ",")
|
||||||
|
if len(partialRequestDataParts) != 4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
/* timestamp, err := strconv.ParseInt(partialRequestDataParts[2], 10, 0)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
t := time.Unix(timestamp, 0)
|
||||||
|
if time.Since(t) > 5*time.Minute || time.Until(t) > 5*time.Minute {
|
||||||
|
// token is invalid if timestamp +/- 5 minutes from current time
|
||||||
|
return false
|
||||||
|
} */
|
||||||
|
|
||||||
|
/* nonce := partialRequestDataParts[3]
|
||||||
|
if nonceCache.has(nonce) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
nonceCache.add(nonce, 5*time.Minute) */
|
||||||
|
|
||||||
|
signature, err := base64.StdEncoding.DecodeString(parts[2])
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
serverPublicKey, err := auth.GetPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, key, _ := bytes.Cut(bytes.TrimSpace(ssh.MarshalAuthorizedKey(serverPublicKey)), []byte(" "))
|
||||||
|
requestData := fmt.Sprintf("%s,%s", key, partialRequestData)
|
||||||
|
|
||||||
|
if err := clientPublicKey.Verify([]byte(requestData), &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func isLocalIP(ip netip.Addr) bool {
|
func isLocalIP(ip netip.Addr) bool {
|
||||||
if interfaces, err := net.Interfaces(); err == nil {
|
if interfaces, err := net.Interfaces(); err == nil {
|
||||||
for _, iface := range interfaces {
|
for _, iface := range interfaces {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user