diff --git a/auth/auth.go b/auth/auth.go index 0301f8327..69c5a39fe 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -26,7 +26,14 @@ func privateKey() (ssh.Signer, error) { keyPath := filepath.Join(home, ".ollama", defaultPrivateKey) privateKeyFile, err := os.ReadFile(keyPath) - if err != nil { + if os.IsNotExist(err) { + err := initializeKeypair() + if err != nil { + return nil, err + } + + return privateKey() + } else if err != nil { slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) return nil, err } @@ -36,42 +43,27 @@ func privateKey() (ssh.Signer, error) { func GetPublicKey() (ssh.PublicKey, error) { // try to read pubkey first - pubkey, err := readPubkey() - if err == nil { - return pubkey, nil - } - - privateKey, err := privateKey() - if err == nil { - return privateKey.PublicKey(), nil - } - - err = initializeKeypair() - if err != nil { - return nil, err - } - - return readPubkey() -} - -func readPubkey() (ssh.PublicKey, error) { home, err := os.UserHomeDir() if err != nil { return nil, err } pubkeyPath := filepath.Join(home, ".ollama", defaultPrivateKey+".pub") - _, err = os.Stat(pubkeyPath) - if os.IsNotExist(err) { - return nil, fmt.Errorf("public key not found") - } - pubKeyFile, err := os.ReadFile(pubkeyPath) - if err != nil { + if os.IsNotExist(err) { + // try from privateKey + privateKey, err := privateKey() + if err != nil { + return nil, fmt.Errorf("failed to read public key: %w", err) + } + + return privateKey.PublicKey(), nil + } else if err != nil { return nil, fmt.Errorf("failed to read public key: %w", err) } - return ssh.ParsePublicKey(pubKeyFile) + pubKey, _, _, _, err := ssh.ParseAuthorizedKey(pubKeyFile) + return pubKey, err } func NewNonce(r io.Reader, length int) (string, error) {