diff --git a/auth/auth.go b/auth/auth.go index 8056fddac..b3f34927a 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -24,7 +24,6 @@ func privateKey() (ssh.Signer, error) { return nil, err } - keyPath := filepath.Join(home, ".ollama", defaultPrivateKey) keyPath := filepath.Join(home, ".ollama", defaultPrivateKey) privateKeyFile, err := os.ReadFile(keyPath) if os.IsNotExist(err) { diff --git a/cmd/cmd.go b/cmd/cmd.go index 768b5c841..06bc13332 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -343,7 +343,6 @@ func getLocalPath(ctx context.Context, digest string) (string, error) { } request.Header.Set("Authorization", authz) - request.Header.Set("Timestamp", time.Now().Format(time.RFC3339)) request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) request.Header.Set("X-Redirect-Create", "1") diff --git a/server/routes.go b/server/routes.go index 31328dc46..e601710f4 100644 --- a/server/routes.go +++ b/server/routes.go @@ -942,7 +942,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) { c.Status(http.StatusOK) return } - if c.GetHeader("X-Redirect-Create") == "1" && s.IsServerKeyPublicKey(c) { + if c.GetHeader("X-Redirect-Create") == "1" && s.isLocal(c) { c.Header("LocalLocation", path) c.Status(http.StatusTemporaryRedirect) return @@ -962,7 +962,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) { c.Status(http.StatusCreated) } -func (s *Server) IsServerKeyPublicKey(c *gin.Context) bool { +func (s *Server) isLocal(c *gin.Context) bool { if authz := c.GetHeader("Authorization"); authz != "" { parts := strings.Split(authz, ":") if len(parts) != 3 { @@ -999,16 +999,7 @@ func (s *Server) IsServerKeyPublicKey(c *gin.Context) bool { slog.Error(fmt.Sprintf("failed to get server public key: %v", err)) return false } - - timestamp, err := time.Parse(time.RFC3339, c.GetHeader("Timestamp")) - if err != nil { - return false - } - - if time.Since(timestamp) > time.Minute { - return false - } - + if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) { return true } diff --git a/server/routes_test.go b/server/routes_test.go index 97786ba2b..b002fa71f 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -3,20 +3,27 @@ package server import ( "bytes" "context" + "crypto/ed25519" + "crypto/rand" "encoding/binary" "encoding/json" + "encoding/pem" "fmt" "io" "math" "net/http" "net/http/httptest" + "net/url" "os" + "path/filepath" "sort" "strings" "testing" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" @@ -527,3 +534,108 @@ func TestNormalize(t *testing.T) { }) } } + +func TestIsLocalReal(t *testing.T) { + gin.SetMode(gin.TestMode) + clientPubLoc := t.TempDir() + t.Setenv("HOME", clientPubLoc) + + err := initializeKeypair() + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(w) + ctx.Request = &http.Request{ + Header: make(http.Header), + } + + requestURL := url.URL{ + Scheme: "http", + Host: "localhost:8080", + Path: "/api/blobs", + } + request := &http.Request{ + Method: http.MethodPost, + URL: &requestURL, + } + s := &Server{} + + authz, err := api.Authorization(ctx, request) + if err != nil { + t.Fatal(err) + } + + // Set client authorization header + ctx.Request.Header.Set("Authorization", authz) + if !s.isLocal(ctx) { + t.Fatal("Expected isLocal to return true") + } + + t.Run("different server pubkey", func(t *testing.T) { + serverPubLoc := t.TempDir() + t.Setenv("HOME", serverPubLoc) + err := initializeKeypair() + if err != nil { + t.Fatal(err) + } + + if s.isLocal(ctx) { + t.Fatal("Expected isLocal to return false") + } + }) + + t.Run("invalid pubkey", func(t *testing.T) { + ctx.Request.Header.Set("Authorization", "sha-25616:invalid") + if s.isLocal(ctx) { + t.Fatal("Expected isLocal to return false") + } + }) +} + +func initializeKeypair() error { + home, err := os.UserHomeDir() + if err != nil { + return err + } + + privKeyPath := filepath.Join(home, ".ollama", "id_ed25519") + pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub") + + _, err = os.Stat(privKeyPath) + if os.IsNotExist(err) { + fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath) + cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return err + } + + privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "") + if err != nil { + return err + } + + if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil { + return fmt.Errorf("could not create directory %w", err) + } + + if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil { + return err + } + + sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey) + if err != nil { + return err + } + + publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey) + + if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil { + return err + } + + fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes) + } + return nil +} \ No newline at end of file