isLocal firstdraft

This commit is contained in:
Josh Yan 2024-07-05 14:18:25 -07:00
parent 8ee1ada22a
commit 154b59c0b6
4 changed files with 119 additions and 1 deletions

View File

@ -17,14 +17,20 @@ import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/version"
@ -403,3 +409,28 @@ func (c *Client) IsLocal() bool {
return false
}
func Authorization(ctx context.Context, request *http.Request) (string, error) {
data := []byte(fmt.Sprintf("%s,%s,%d", request.Method, request.URL.RequestURI(), time.Now().Unix()))
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
knownHostsFile, err := os.OpenFile(filepath.Join(home, ".ollama", "known_hosts"), os.O_CREATE|os.O_RDWR|os.O_APPEND, 0600)
if err != nil {
return "", err
}
defer knownHostsFile.Close()
token, err := auth.Sign(ctx, data)
if err != nil {
return "", err
}
// interleave request data into the token
key, sig, _ := strings.Cut(token, ":")
return fmt.Sprintf("%s:%s:%s", key, base64.StdEncoding.EncodeToString(data), sig), nil
}

View File

@ -24,6 +24,7 @@ func privateKey() (ssh.Signer, error) {
return nil, err
}
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
privateKeyFile, err := os.ReadFile(keyPath)
if os.IsNotExist(err) {
@ -36,11 +37,19 @@ func privateKey() (ssh.Signer, error) {
} else if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return nil, err
return nil, err
}
return ssh.ParsePrivateKey(privateKeyFile)
}
func GetPublicKey() (ssh.PublicKey, error) {
privateKey, err := keyPath()
// if privateKey, try public key directly
return ssh.ParsePrivateKey(privateKeyFile)
}
func GetPublicKey() (ssh.PublicKey, error) {
// try to read pubkey first
home, err := os.UserHomeDir()

View File

@ -351,6 +351,12 @@ func getLocalPath(ctx context.Context, digest string) (string, error) {
return "", err
}
authz, err := api.Authorization(ctx, request)
if err != nil {
return "", err
}
request.Header.Set("Authorization", authz)
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
request.Header.Set("X-Redirect-Create", "1")

View File

@ -4,10 +4,12 @@ import (
"bytes"
"cmp"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"log/slog"
"math"
"net"
@ -23,8 +25,10 @@ import (
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"golang.org/x/crypto/ssh"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
@ -941,7 +945,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
return
}
if c.GetHeader("X-Redirect-Create") == "1" {
if c.GetHeader("X-Redirect-Create") == "1" && s.IsLocal(c) {
c.Header("LocalLocation", path)
c.Status(http.StatusTemporaryRedirect)
return
@ -961,6 +965,74 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
c.Status(http.StatusCreated)
}
func (s *Server) IsLocal(c *gin.Context) bool {
if authz := c.GetHeader("Authorization"); authz != "" {
parts := strings.Split(authz, ":")
if len(parts) != 3 {
return false
}
clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
if err != nil {
return false
}
// partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
partialRequestData, err := base64.StdEncoding.DecodeString(parts[1])
if err != nil {
return false
}
partialRequestDataParts := strings.Split(string(partialRequestData), ",")
if len(partialRequestDataParts) != 4 {
return false
}
/* timestamp, err := strconv.ParseInt(partialRequestDataParts[2], 10, 0)
if err != nil {
return false
}
t := time.Unix(timestamp, 0)
if time.Since(t) > 5*time.Minute || time.Until(t) > 5*time.Minute {
// token is invalid if timestamp +/- 5 minutes from current time
return false
} */
/* nonce := partialRequestDataParts[3]
if nonceCache.has(nonce) {
return false
}
nonceCache.add(nonce, 5*time.Minute) */
signature, err := base64.StdEncoding.DecodeString(parts[2])
if err != nil {
return false
}
serverPublicKey, err := auth.GetPublicKey()
if err != nil {
log.Fatal(err)
}
_, key, _ := bytes.Cut(bytes.TrimSpace(ssh.MarshalAuthorizedKey(serverPublicKey)), []byte(" "))
requestData := fmt.Sprintf("%s,%s", key, partialRequestData)
if err := clientPublicKey.Verify([]byte(requestData), &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
return false
}
if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
return true
}
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return false
}
return false
}
func isLocalIP(ip netip.Addr) bool {
if interfaces, err := net.Interfaces(); err == nil {
for _, iface := range interfaces {