diff --git a/auth/auth.go b/auth/auth.go index 026b2a2c7..0301f8327 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -3,49 +3,75 @@ package auth import ( "bytes" "context" + "crypto/ed25519" "crypto/rand" "encoding/base64" + "encoding/pem" "fmt" "io" "log/slog" "os" "path/filepath" - "strings" "golang.org/x/crypto/ssh" ) const defaultPrivateKey = "id_ed25519" -func keyPath() (string, error) { +func privateKey() (ssh.Signer, error) { home, err := os.UserHomeDir() if err != nil { - return "", err - } - - return filepath.Join(home, ".ollama", defaultPrivateKey), nil -} - -func GetPublicKey() (string, error) { - keyPath, err := keyPath() - if err != nil { - return "", err + return nil, err } + keyPath := filepath.Join(home, ".ollama", defaultPrivateKey) privateKeyFile, err := os.ReadFile(keyPath) if err != nil { slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) - return "", err + return nil, err } - privateKey, err := ssh.ParsePrivateKey(privateKeyFile) + return ssh.ParsePrivateKey(privateKeyFile) +} + +func GetPublicKey() (ssh.PublicKey, error) { + // try to read pubkey first + pubkey, err := readPubkey() + if err == nil { + return pubkey, nil + } + + privateKey, err := privateKey() + if err == nil { + return privateKey.PublicKey(), nil + } + + err = initializeKeypair() if err != nil { - return "", err + return nil, err } - publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey()) + return readPubkey() +} - return strings.TrimSpace(string(publicKey)), nil +func readPubkey() (ssh.PublicKey, error) { + home, err := os.UserHomeDir() + if err != nil { + return nil, err + } + + pubkeyPath := filepath.Join(home, ".ollama", defaultPrivateKey+".pub") + _, err = os.Stat(pubkeyPath) + if os.IsNotExist(err) { + return nil, fmt.Errorf("public key not found") + } + + pubKeyFile, err := os.ReadFile(pubkeyPath) + if err != nil { + return nil, fmt.Errorf("failed to read public key: %w", err) + } + + return ssh.ParsePublicKey(pubKeyFile) } func NewNonce(r io.Reader, length int) (string, error) { @@ -58,25 +84,20 @@ func NewNonce(r io.Reader, length int) (string, error) { } func Sign(ctx context.Context, bts []byte) (string, error) { - keyPath, err := keyPath() - if err != nil { - return "", err - } - - privateKeyFile, err := os.ReadFile(keyPath) - if err != nil { - slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) - return "", err - } - - privateKey, err := ssh.ParsePrivateKey(privateKeyFile) + privateKey, err := privateKey() if err != nil { return "", err } // get the pubkey, but remove the type - publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey()) - parts := bytes.Split(publicKey, []byte(" ")) + publicKey, err := GetPublicKey() + if err != nil { + return "", err + } + + publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey) + + parts := bytes.Split(publicKeyBytes, []byte(" ")) if len(parts) < 2 { return "", fmt.Errorf("malformed public key") } @@ -89,3 +110,49 @@ func Sign(ctx context.Context, bts []byte) (string, error) { // signature is : return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil } + +func initializeKeypair() error { + home, err := os.UserHomeDir() + if err != nil { + return err + } + + privKeyPath := filepath.Join(home, ".ollama", "id_ed25519") + pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub") + + _, err = os.Stat(privKeyPath) + if os.IsNotExist(err) { + fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath) + cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return err + } + + privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "") + if err != nil { + return err + } + + if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil { + return fmt.Errorf("could not create directory %w", err) + } + + if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil { + return err + } + + sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey) + if err != nil { + return err + } + + publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey) + + if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil { + return err + } + + fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes) + } + return nil +} \ No newline at end of file diff --git a/cmd/cmd.go b/cmd/cmd.go index b761d018f..c7cf0581a 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -4,10 +4,7 @@ import ( "archive/zip" "bytes" "context" - "crypto/ed25519" - "crypto/rand" "crypto/sha256" - "encoding/pem" "errors" "fmt" "io" @@ -379,11 +376,12 @@ func errFromUnknownKey(unknownKeyErr error) error { if len(matches) > 0 { serverPubKey := matches[0] - localPubKey, err := auth.GetPublicKey() + publicKey, err := auth.GetPublicKey() if err != nil { return unknownKeyErr } + localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(publicKey))) if runtime.GOOS == "linux" && serverPubKey != localPubKey { // try the ollama service public key svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub") @@ -1072,7 +1070,7 @@ func generate(cmd *cobra.Command, opts runOptions) error { } func RunServer(cmd *cobra.Command, _ []string) error { - if err := initializeKeypair(); err != nil { + if _, err := auth.GetPublicKey(); err != nil { return err } @@ -1089,52 +1087,6 @@ func RunServer(cmd *cobra.Command, _ []string) error { return err } -func initializeKeypair() error { - home, err := os.UserHomeDir() - if err != nil { - return err - } - - privKeyPath := filepath.Join(home, ".ollama", "id_ed25519") - pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub") - - _, err = os.Stat(privKeyPath) - if os.IsNotExist(err) { - fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath) - cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - return err - } - - privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "") - if err != nil { - return err - } - - if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil { - return fmt.Errorf("could not create directory %w", err) - } - - if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil { - return err - } - - sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey) - if err != nil { - return err - } - - publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey) - - if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil { - return err - } - - fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes) - } - return nil -} - func checkServerHeartbeat(cmd *cobra.Command, _ []string) error { client, err := api.ClientFromEnvironment() if err != nil {