isLocal testing
This commit is contained in:
parent
09431f353d
commit
c507325288
@ -24,7 +24,6 @@ func privateKey() (ssh.Signer, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
|
||||||
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||||
privateKeyFile, err := os.ReadFile(keyPath)
|
privateKeyFile, err := os.ReadFile(keyPath)
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
|
@ -343,7 +343,6 @@ func getLocalPath(ctx context.Context, digest string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
request.Header.Set("Authorization", authz)
|
request.Header.Set("Authorization", authz)
|
||||||
request.Header.Set("Timestamp", time.Now().Format(time.RFC3339))
|
|
||||||
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||||
request.Header.Set("X-Redirect-Create", "1")
|
request.Header.Set("X-Redirect-Create", "1")
|
||||||
|
|
||||||
|
@ -942,7 +942,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
|||||||
c.Status(http.StatusOK)
|
c.Status(http.StatusOK)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if c.GetHeader("X-Redirect-Create") == "1" && s.IsServerKeyPublicKey(c) {
|
if c.GetHeader("X-Redirect-Create") == "1" && s.isLocal(c) {
|
||||||
c.Header("LocalLocation", path)
|
c.Header("LocalLocation", path)
|
||||||
c.Status(http.StatusTemporaryRedirect)
|
c.Status(http.StatusTemporaryRedirect)
|
||||||
return
|
return
|
||||||
@ -962,7 +962,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
|||||||
c.Status(http.StatusCreated)
|
c.Status(http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) IsServerKeyPublicKey(c *gin.Context) bool {
|
func (s *Server) isLocal(c *gin.Context) bool {
|
||||||
if authz := c.GetHeader("Authorization"); authz != "" {
|
if authz := c.GetHeader("Authorization"); authz != "" {
|
||||||
parts := strings.Split(authz, ":")
|
parts := strings.Split(authz, ":")
|
||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
@ -1000,15 +1000,6 @@ func (s *Server) IsServerKeyPublicKey(c *gin.Context) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
timestamp, err := time.Parse(time.RFC3339, c.GetHeader("Timestamp"))
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if time.Since(timestamp) > time.Minute {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
|
if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -3,20 +3,27 @@ package server
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
@ -527,3 +534,108 @@ func TestNormalize(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsLocalReal(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
clientPubLoc := t.TempDir()
|
||||||
|
t.Setenv("HOME", clientPubLoc)
|
||||||
|
|
||||||
|
err := initializeKeypair()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(w)
|
||||||
|
ctx.Request = &http.Request{
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
|
||||||
|
requestURL := url.URL{
|
||||||
|
Scheme: "http",
|
||||||
|
Host: "localhost:8080",
|
||||||
|
Path: "/api/blobs",
|
||||||
|
}
|
||||||
|
request := &http.Request{
|
||||||
|
Method: http.MethodPost,
|
||||||
|
URL: &requestURL,
|
||||||
|
}
|
||||||
|
s := &Server{}
|
||||||
|
|
||||||
|
authz, err := api.Authorization(ctx, request)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set client authorization header
|
||||||
|
ctx.Request.Header.Set("Authorization", authz)
|
||||||
|
if !s.isLocal(ctx) {
|
||||||
|
t.Fatal("Expected isLocal to return true")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("different server pubkey", func(t *testing.T) {
|
||||||
|
serverPubLoc := t.TempDir()
|
||||||
|
t.Setenv("HOME", serverPubLoc)
|
||||||
|
err := initializeKeypair()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.isLocal(ctx) {
|
||||||
|
t.Fatal("Expected isLocal to return false")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid pubkey", func(t *testing.T) {
|
||||||
|
ctx.Request.Header.Set("Authorization", "sha-25616:invalid")
|
||||||
|
if s.isLocal(ctx) {
|
||||||
|
t.Fatal("Expected isLocal to return false")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func initializeKeypair() error {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
|
||||||
|
pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
|
||||||
|
|
||||||
|
_, err = os.Stat(privKeyPath)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
|
||||||
|
cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
|
||||||
|
return fmt.Errorf("could not create directory %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
|
||||||
|
|
||||||
|
if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user