Compare commits
43 Commits
royh-embed
...
jyan/progr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b6c7d01af3 | ||
|
|
9d517cf556 | ||
|
|
6bab0e2368 | ||
|
|
c4cccaf936 | ||
|
|
9fe5c393e4 | ||
|
|
007c988dba | ||
|
|
91d21e7c7b | ||
|
|
3e64284f69 | ||
|
|
39910f2ab2 | ||
|
|
96d0cd92f2 | ||
|
|
3a724a7c80 | ||
|
|
f520f0056e | ||
|
|
d25f85ede4 | ||
|
|
b48420b74b | ||
|
|
784958a1cb | ||
|
|
ae65cc8dea | ||
|
|
a037528bba | ||
|
|
04bf41deb5 | ||
|
|
c23cec9547 | ||
|
|
8377dc48d0 | ||
|
|
3aee405dfa | ||
|
|
9b3f47b674 | ||
|
|
f5441f01a2 | ||
|
|
ab165df43a | ||
|
|
79cc4c9585 | ||
|
|
bc3f59a6ad | ||
|
|
1a85cb904c | ||
|
|
10ea0987e9 | ||
|
|
413d368a6a | ||
|
|
cabf375059 | ||
|
|
ca0ee1d4fe | ||
|
|
1142999aab | ||
|
|
0d5a72aba9 | ||
|
|
ea837412c2 | ||
|
|
736ad6f438 | ||
|
|
64607d16a5 | ||
|
|
a6cfe7f00b | ||
|
|
c3b411a515 | ||
|
|
928f37e3ae | ||
|
|
2d1e3c3229 | ||
|
|
4918fae535 | ||
|
|
0aff67877e | ||
|
|
0bacb30007 |
@@ -17,14 +17,20 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/auth"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
@@ -374,3 +380,27 @@ func (c *Client) Version(ctx context.Context) (string, error) {
|
|||||||
|
|
||||||
return version.Version, nil
|
return version.Version, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Authorization(ctx context.Context, request *http.Request) (string, error) {
|
||||||
|
data := []byte(fmt.Sprintf("%s,%s,%d", request.Method, request.URL.RequestURI(), time.Now().Unix()))
|
||||||
|
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
knownHostsFile, err := os.OpenFile(filepath.Join(home, ".ollama", "known_hosts"), os.O_CREATE|os.O_RDWR|os.O_APPEND, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer knownHostsFile.Close()
|
||||||
|
|
||||||
|
token, err := auth.Sign(ctx, data)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// interleave request data into the token
|
||||||
|
key, sig, _ := strings.Cut(token, ":")
|
||||||
|
return fmt.Sprintf("%s:%s:%s", key, base64.StdEncoding.EncodeToString(data), sig), nil
|
||||||
|
}
|
||||||
|
|||||||
54
auth/auth.go
54
auth/auth.go
@@ -10,42 +10,37 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultPrivateKey = "id_ed25519"
|
const defaultPrivateKey = "id_ed25519"
|
||||||
|
|
||||||
func keyPath() (string, error) {
|
func keyPath() (ssh.Signer, error) {
|
||||||
home, err := os.UserHomeDir()
|
home, err := os.UserHomeDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
|
||||||
|
|
||||||
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetPublicKey() (string, error) {
|
|
||||||
keyPath, err := keyPath()
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||||
privateKeyFile, err := os.ReadFile(keyPath)
|
privateKeyFile, err := os.ReadFile(keyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
return ssh.ParsePrivateKey(privateKeyFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetPublicKey() (ssh.PublicKey, error) {
|
||||||
|
privateKey, err := keyPath()
|
||||||
|
// if privateKey, try public key directly
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
return privateKey.PublicKey(), nil
|
||||||
|
|
||||||
return strings.TrimSpace(string(publicKey)), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNonce(r io.Reader, length int) (string, error) {
|
func NewNonce(r io.Reader, length int) (string, error) {
|
||||||
@@ -58,25 +53,20 @@ func NewNonce(r io.Reader, length int) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Sign(ctx context.Context, bts []byte) (string, error) {
|
func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||||
keyPath, err := keyPath()
|
privateKey, err := keyPath()
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
privateKeyFile, err := os.ReadFile(keyPath)
|
|
||||||
if err != nil {
|
|
||||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
// get the pubkey, but remove the type
|
// get the pubkey, but remove the type
|
||||||
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
publicKey, err := GetPublicKey()
|
||||||
parts := bytes.Split(publicKey, []byte(" "))
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
|
||||||
|
|
||||||
|
parts := bytes.Split(publicKeyBytes, []byte(" "))
|
||||||
if len(parts) < 2 {
|
if len(parts) < 2 {
|
||||||
return "", fmt.Errorf("malformed public key")
|
return "", fmt.Errorf("malformed public key")
|
||||||
}
|
}
|
||||||
|
|||||||
163
cmd/cmd.go
163
cmd/cmd.go
@@ -7,6 +7,7 @@ import (
|
|||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -15,6 +16,7 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -78,6 +80,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
status := "transferring model data"
|
status := "transferring model data"
|
||||||
spinner := progress.NewSpinner(status)
|
spinner := progress.NewSpinner(status)
|
||||||
p.Add(status, spinner)
|
p.Add(status, spinner)
|
||||||
|
defer p.Stop()
|
||||||
|
|
||||||
for i := range modelfile.Commands {
|
for i := range modelfile.Commands {
|
||||||
switch modelfile.Commands[i].Name {
|
switch modelfile.Commands[i].Name {
|
||||||
@@ -112,11 +115,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
path = tempfile
|
path = tempfile
|
||||||
}
|
}
|
||||||
|
|
||||||
digest, err := createBlob(cmd, client, path)
|
digest, err := createBlob(cmd, client, path, spinner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
modelfile.Commands[i].Args = "@" + digest
|
modelfile.Commands[i].Args = "@" + digest
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -138,7 +140,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
spinner.Stop()
|
spinner.Stop()
|
||||||
|
|
||||||
status = resp.Status
|
status = resp.Status
|
||||||
spinner = progress.NewSpinner(status)
|
spinner := progress.NewSpinner(status)
|
||||||
p.Add(status, spinner)
|
p.Add(status, spinner)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -263,13 +265,22 @@ func tempZipFiles(path string) (string, error) {
|
|||||||
return tempfile.Name(), nil
|
return tempfile.Name(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
|
var ErrBlobExists = errors.New("blob exists")
|
||||||
|
|
||||||
|
func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) {
|
||||||
bin, err := os.Open(path)
|
bin, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
defer bin.Close()
|
defer bin.Close()
|
||||||
|
|
||||||
|
// Get file info to retrieve the size
|
||||||
|
fileInfo, err := bin.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
fileSize := fileInfo.Size()
|
||||||
|
|
||||||
hash := sha256.New()
|
hash := sha256.New()
|
||||||
if _, err := io.Copy(hash, bin); err != nil {
|
if _, err := io.Copy(hash, bin); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -279,13 +290,151 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var pw progressWriter
|
||||||
|
status := "transferring model data 0%"
|
||||||
|
spinner.SetMessage(status)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(60 * time.Millisecond)
|
||||||
|
done := make(chan struct{})
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n/fileSize)))
|
||||||
|
case <-done:
|
||||||
|
spinner.SetMessage("transferring model data 100%")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
|
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
|
||||||
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
|
|
||||||
|
// We check if we can find the models directory locally
|
||||||
|
// If we can, we return the path to the directory
|
||||||
|
// If we can't, we return an error
|
||||||
|
// If the blob exists already, we return the digest
|
||||||
|
dest, err := getLocalPath(cmd.Context(), digest)
|
||||||
|
|
||||||
|
if errors.Is(err, ErrBlobExists) {
|
||||||
|
return digest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Successfully found the model directory
|
||||||
|
if err == nil {
|
||||||
|
// Copy blob in via OS specific copy
|
||||||
|
// Linux errors out to use io.copy
|
||||||
|
err = localCopy(path, dest)
|
||||||
|
if err == nil {
|
||||||
|
return digest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default copy using io.copy
|
||||||
|
err = defaultCopy(path, dest)
|
||||||
|
if err == nil {
|
||||||
|
return digest, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If at any point copying the blob over locally fails, we default to the copy through the server
|
||||||
|
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return digest, nil
|
return digest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type progressWriter struct {
|
||||||
|
n int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *progressWriter) Write(p []byte) (n int, err error) {
|
||||||
|
w.n += int64(len(p))
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getLocalPath(ctx context.Context, digest string) (string, error) {
|
||||||
|
ollamaHost := envconfig.Host
|
||||||
|
|
||||||
|
client := http.DefaultClient
|
||||||
|
base := &url.URL{
|
||||||
|
Scheme: ollamaHost.Scheme,
|
||||||
|
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(digest)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody := bytes.NewReader(data)
|
||||||
|
path := fmt.Sprintf("/api/blobs/%s", digest)
|
||||||
|
requestURL := base.JoinPath(path)
|
||||||
|
request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
authz, err := api.Authorization(ctx, request)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
request.Header.Set("Authorization", authz)
|
||||||
|
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||||
|
request.Header.Set("X-Redirect-Create", "1")
|
||||||
|
|
||||||
|
resp, err := client.Do(request)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusTemporaryRedirect {
|
||||||
|
dest := resp.Header.Get("LocalLocation")
|
||||||
|
|
||||||
|
return dest, nil
|
||||||
|
}
|
||||||
|
return "", ErrBlobExists
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultCopy(path string, dest string) error {
|
||||||
|
// This function should be called if the server is local
|
||||||
|
// It should find the model directory, copy the blob over, and return the digest
|
||||||
|
dirPath := filepath.Dir(dest)
|
||||||
|
|
||||||
|
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy blob over
|
||||||
|
sourceFile, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not open source file: %v", err)
|
||||||
|
}
|
||||||
|
defer sourceFile.Close()
|
||||||
|
|
||||||
|
destFile, err := os.Create(dest)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not create destination file: %v", err)
|
||||||
|
}
|
||||||
|
defer destFile.Close()
|
||||||
|
|
||||||
|
_, err = io.CopyBuffer(destFile, sourceFile, make([]byte, 4*1024*1024))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error copying file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = destFile.Sync()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error flushing file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||||
interactive := true
|
interactive := true
|
||||||
|
|
||||||
@@ -379,11 +528,13 @@ func errFromUnknownKey(unknownKeyErr error) error {
|
|||||||
if len(matches) > 0 {
|
if len(matches) > 0 {
|
||||||
serverPubKey := matches[0]
|
serverPubKey := matches[0]
|
||||||
|
|
||||||
localPubKey, err := auth.GetPublicKey()
|
publicKey, err := auth.GetPublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return unknownKeyErr
|
return unknownKeyErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(publicKey)))
|
||||||
|
|
||||||
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
||||||
// try the ollama service public key
|
// try the ollama service public key
|
||||||
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
||||||
|
|||||||
23
cmd/copy_darwin.go
Normal file
23
cmd/copy_darwin.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func localCopy(src, target string) error {
|
||||||
|
dirPath := filepath.Dir(target)
|
||||||
|
|
||||||
|
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err := unix.Clonefile(src, target, 0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
7
cmd/copy_linux.go
Normal file
7
cmd/copy_linux.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
func localCopy(src, target string) error {
|
||||||
|
return errors.New("no local copy implementation for linux")
|
||||||
|
}
|
||||||
67
cmd/copy_windows.go
Normal file
67
cmd/copy_windows.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
func localCopy(src, target string) error {
|
||||||
|
// Create target directory if it doesn't exist
|
||||||
|
dirPath := filepath.Dir(target)
|
||||||
|
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open source file
|
||||||
|
sourceFile, err := os.Open(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer sourceFile.Close()
|
||||||
|
|
||||||
|
// Create target file
|
||||||
|
targetFile, err := os.Create(target)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer targetFile.Close()
|
||||||
|
|
||||||
|
// Use CopyFileExW to copy the file
|
||||||
|
err = copyFileEx(src, target)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyFileEx(src, dst string) error {
|
||||||
|
kernel32 := syscall.NewLazyDLL("kernel32.dll")
|
||||||
|
copyFileEx := kernel32.NewProc("CopyFileExW")
|
||||||
|
|
||||||
|
srcPtr, err := syscall.UTF16PtrFromString(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dstPtr, err := syscall.UTF16PtrFromString(dst)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
r1, _, err := copyFileEx.Call(
|
||||||
|
uintptr(unsafe.Pointer(srcPtr)),
|
||||||
|
uintptr(unsafe.Pointer(dstPtr)),
|
||||||
|
0, 0, 0, 0)
|
||||||
|
|
||||||
|
if r1 == 0 {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -254,7 +254,7 @@ if [ -z "${OLLAMA_SKIP_ROCM_GENERATE}" -a -d "${ROCM_PATH}" ]; then
|
|||||||
ROCM_VARIANT=_v$(ls ${ROCM_PATH}/lib/librocblas.so.*.*.????? | cut -f5 -d. || true)
|
ROCM_VARIANT=_v$(ls ${ROCM_PATH}/lib/librocblas.so.*.*.????? | cut -f5 -d. || true)
|
||||||
fi
|
fi
|
||||||
init_vars
|
init_vars
|
||||||
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
|
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DLLAMA_CUDA_NO_PEER_COPY=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
|
||||||
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
|
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
|
||||||
if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then
|
if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then
|
||||||
echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\""
|
echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\""
|
||||||
|
|||||||
@@ -366,6 +366,7 @@ function build_rocm() {
|
|||||||
"-DCMAKE_C_COMPILER=clang.exe",
|
"-DCMAKE_C_COMPILER=clang.exe",
|
||||||
"-DCMAKE_CXX_COMPILER=clang++.exe",
|
"-DCMAKE_CXX_COMPILER=clang++.exe",
|
||||||
"-DGGML_HIPBLAS=on",
|
"-DGGML_HIPBLAS=on",
|
||||||
|
"-DLLAMA_CUDA_NO_PEER_COPY=on",
|
||||||
"-DHIP_PLATFORM=amd",
|
"-DHIP_PLATFORM=amd",
|
||||||
"-DGGML_AVX=on",
|
"-DGGML_AVX=on",
|
||||||
"-DGGML_AVX2=off",
|
"-DGGML_AVX2=off",
|
||||||
|
|||||||
@@ -338,12 +338,16 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
|||||||
switch stop := r.Stop.(type) {
|
switch stop := r.Stop.(type) {
|
||||||
case string:
|
case string:
|
||||||
options["stop"] = []string{stop}
|
options["stop"] = []string{stop}
|
||||||
case []string:
|
case []any:
|
||||||
options["stop"] = stop
|
var stops []string
|
||||||
default:
|
for _, s := range stop {
|
||||||
if r.Stop != nil {
|
if str, ok := s.(string); ok {
|
||||||
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", r.Stop)
|
stops = append(stops, str)
|
||||||
|
} else {
|
||||||
|
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
options["stop"] = stops
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.MaxTokens != nil {
|
if r.MaxTokens != nil {
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package openai
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -16,7 +15,133 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMiddleware(t *testing.T) {
|
func TestMiddlewareRequests(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
Name string
|
||||||
|
Method string
|
||||||
|
Path string
|
||||||
|
Handler func() gin.HandlerFunc
|
||||||
|
Setup func(t *testing.T, req *http.Request)
|
||||||
|
Expected func(t *testing.T, req *http.Request)
|
||||||
|
}
|
||||||
|
|
||||||
|
var capturedRequest *http.Request
|
||||||
|
|
||||||
|
captureRequestMiddleware := func() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
capturedRequest = c.Request
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []testCase{
|
||||||
|
{
|
||||||
|
Name: "chat handler",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: "/api/chat",
|
||||||
|
Handler: ChatMiddleware,
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := ChatCompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []Message{{Role: "user", Content: "Hello"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, _ := json.Marshal(body)
|
||||||
|
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *http.Request) {
|
||||||
|
var chatReq api.ChatRequest
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if chatReq.Messages[0].Role != "user" {
|
||||||
|
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
|
||||||
|
}
|
||||||
|
|
||||||
|
if chatReq.Messages[0].Content != "Hello" {
|
||||||
|
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "completions handler",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: "/api/generate",
|
||||||
|
Handler: CompletionsMiddleware,
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
temp := float32(0.8)
|
||||||
|
body := CompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "Hello",
|
||||||
|
Temperature: &temp,
|
||||||
|
Stop: []string{"\n", "stop"},
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, _ := json.Marshal(body)
|
||||||
|
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *http.Request) {
|
||||||
|
var genReq api.GenerateRequest
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if genReq.Prompt != "Hello" {
|
||||||
|
t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
if genReq.Options["temperature"] != 1.6 {
|
||||||
|
t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
|
||||||
|
}
|
||||||
|
|
||||||
|
stopTokens, ok := genReq.Options["stop"].([]any)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected stop tokens to be a list")
|
||||||
|
}
|
||||||
|
|
||||||
|
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
|
||||||
|
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
|
||||||
|
endpoint := func(c *gin.Context) {
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
|
router = gin.New()
|
||||||
|
router.Use(captureRequestMiddleware())
|
||||||
|
router.Use(tc.Handler())
|
||||||
|
router.Handle(tc.Method, tc.Path, endpoint)
|
||||||
|
req, _ := http.NewRequest(tc.Method, tc.Path, nil)
|
||||||
|
|
||||||
|
if tc.Setup != nil {
|
||||||
|
tc.Setup(t, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
tc.Expected(t, capturedRequest)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMiddlewareResponses(t *testing.T) {
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
Name string
|
Name string
|
||||||
Method string
|
Method string
|
||||||
@@ -30,159 +155,7 @@ func TestMiddleware(t *testing.T) {
|
|||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{
|
{
|
||||||
Name: "chat handler",
|
Name: "completions handler error forwarding",
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/chat",
|
|
||||||
TestPath: "/api/chat",
|
|
||||||
Handler: ChatMiddleware,
|
|
||||||
Endpoint: func(c *gin.Context) {
|
|
||||||
var chatReq api.ChatRequest
|
|
||||||
if err := c.ShouldBindJSON(&chatReq); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userMessage := chatReq.Messages[0].Content
|
|
||||||
var assistantMessage string
|
|
||||||
|
|
||||||
switch userMessage {
|
|
||||||
case "Hello":
|
|
||||||
assistantMessage = "Hello!"
|
|
||||||
default:
|
|
||||||
assistantMessage = "I'm not sure how to respond to that."
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, api.ChatResponse{
|
|
||||||
Message: api.Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: assistantMessage,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
},
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
|
||||||
body := ChatCompletionRequest{
|
|
||||||
Model: "test-model",
|
|
||||||
Messages: []Message{{Role: "user", Content: "Hello"}},
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyBytes, _ := json.Marshal(body)
|
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
|
||||||
assert.Equal(t, http.StatusOK, resp.Code)
|
|
||||||
|
|
||||||
var chatResp ChatCompletion
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if chatResp.Object != "chat.completion" {
|
|
||||||
t.Fatalf("expected chat.completion, got %s", chatResp.Object)
|
|
||||||
}
|
|
||||||
|
|
||||||
if chatResp.Choices[0].Message.Content != "Hello!" {
|
|
||||||
t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "completions handler",
|
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/generate",
|
|
||||||
TestPath: "/api/generate",
|
|
||||||
Handler: CompletionsMiddleware,
|
|
||||||
Endpoint: func(c *gin.Context) {
|
|
||||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
|
||||||
Response: "Hello!",
|
|
||||||
})
|
|
||||||
},
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
|
||||||
body := CompletionRequest{
|
|
||||||
Model: "test-model",
|
|
||||||
Prompt: "Hello",
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyBytes, _ := json.Marshal(body)
|
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
|
||||||
assert.Equal(t, http.StatusOK, resp.Code)
|
|
||||||
var completionResp Completion
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if completionResp.Object != "text_completion" {
|
|
||||||
t.Fatalf("expected text_completion, got %s", completionResp.Object)
|
|
||||||
}
|
|
||||||
|
|
||||||
if completionResp.Choices[0].Text != "Hello!" {
|
|
||||||
t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "completions handler with params",
|
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/generate",
|
|
||||||
TestPath: "/api/generate",
|
|
||||||
Handler: CompletionsMiddleware,
|
|
||||||
Endpoint: func(c *gin.Context) {
|
|
||||||
var generateReq api.GenerateRequest
|
|
||||||
if err := c.ShouldBindJSON(&generateReq); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
temperature := generateReq.Options["temperature"].(float64)
|
|
||||||
var assistantMessage string
|
|
||||||
|
|
||||||
switch temperature {
|
|
||||||
case 1.6:
|
|
||||||
assistantMessage = "Received temperature of 1.6"
|
|
||||||
default:
|
|
||||||
assistantMessage = fmt.Sprintf("Received temperature of %f", temperature)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
|
||||||
Response: assistantMessage,
|
|
||||||
})
|
|
||||||
},
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
|
||||||
temp := float32(0.8)
|
|
||||||
body := CompletionRequest{
|
|
||||||
Model: "test-model",
|
|
||||||
Prompt: "Hello",
|
|
||||||
Temperature: &temp,
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyBytes, _ := json.Marshal(body)
|
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
|
||||||
assert.Equal(t, http.StatusOK, resp.Code)
|
|
||||||
var completionResp Completion
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if completionResp.Object != "text_completion" {
|
|
||||||
t.Fatalf("expected text_completion, got %s", completionResp.Object)
|
|
||||||
}
|
|
||||||
|
|
||||||
if completionResp.Choices[0].Text != "Received temperature of 1.6" {
|
|
||||||
t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "completions handler with error",
|
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Path: "/api/generate",
|
Path: "/api/generate",
|
||||||
TestPath: "/api/generate",
|
TestPath: "/api/generate",
|
||||||
|
|||||||
@@ -31,6 +31,10 @@ func NewSpinner(message string) *Spinner {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Spinner) SetMessage(message string) {
|
||||||
|
s.message = message
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Spinner) String() string {
|
func (s *Spinner) String() string {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
if len(s.message) > 0 {
|
if len(s.message) > 0 {
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ import (
|
|||||||
"github.com/ollama/ollama/types/errtypes"
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errCapabilityCompletion = errors.New("completion")
|
var errCapabilityCompletion = errors.New("completion")
|
||||||
@@ -1064,11 +1065,12 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
|||||||
if anonymous {
|
if anonymous {
|
||||||
// no user is associated with the public key, and the request requires non-anonymous access
|
// no user is associated with the public key, and the request requires non-anonymous access
|
||||||
pubKey, nestedErr := auth.GetPublicKey()
|
pubKey, nestedErr := auth.GetPublicKey()
|
||||||
|
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubKey)))
|
||||||
if nestedErr != nil {
|
if nestedErr != nil {
|
||||||
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
||||||
return nil, errUnauthorized
|
return nil, errUnauthorized
|
||||||
}
|
}
|
||||||
return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
|
return nil, &errtypes.UnknownOllamaKey{Key: localPubKey}
|
||||||
}
|
}
|
||||||
// user is associated with the public key, but is not authorized to make the request
|
// user is associated with the public key, but is not authorized to make the request
|
||||||
return nil, errUnauthorized
|
return nil, errUnauthorized
|
||||||
|
|||||||
@@ -4,10 +4,12 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"cmp"
|
"cmp"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -22,8 +24,10 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-contrib/cors"
|
"github.com/gin-contrib/cors"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/auth"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/gpu"
|
"github.com/ollama/ollama/gpu"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
@@ -770,7 +774,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
|||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = os.Stat(path)
|
_, err = os.Stat(path)
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, os.ErrNotExist):
|
case errors.Is(err, os.ErrNotExist):
|
||||||
@@ -783,6 +786,12 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.GetHeader("X-Redirect-Create") == "1" && s.IsLocal(c) {
|
||||||
|
c.Header("LocalLocation", path)
|
||||||
|
c.Status(http.StatusTemporaryRedirect)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
layer, err := NewLayer(c.Request.Body, "")
|
layer, err := NewLayer(c.Request.Body, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@@ -797,6 +806,54 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
|||||||
c.Status(http.StatusCreated)
|
c.Status(http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) IsLocal(c *gin.Context) bool {
|
||||||
|
if authz := c.GetHeader("Authorization"); authz != "" {
|
||||||
|
parts := strings.Split(authz, ":")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
|
||||||
|
requestData, err := base64.StdEncoding.DecodeString(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
partialRequestDataParts := strings.Split(string(requestData), ",")
|
||||||
|
if len(partialRequestDataParts) != 3 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := base64.StdEncoding.DecodeString(parts[2])
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := clientPublicKey.Verify(requestData, &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
serverPublicKey, err := auth.GetPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func isLocalIP(ip netip.Addr) bool {
|
func isLocalIP(ip netip.Addr) bool {
|
||||||
if interfaces, err := net.Interfaces(); err == nil {
|
if interfaces, err := net.Interfaces(); err == nil {
|
||||||
for _, iface := range interfaces {
|
for _, iface := range interfaces {
|
||||||
|
|||||||
Reference in New Issue
Block a user