Compare commits
12 Commits
jyan/local
...
royh-opena
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
568416ba17 | ||
|
|
80cba42ab2 | ||
|
|
6477a7aca4 | ||
|
|
51214ddef5 | ||
|
|
b950d749a9 | ||
|
|
3702ed7532 | ||
|
|
6266603b17 | ||
|
|
2644c4e682 | ||
|
|
04cde43b2a | ||
|
|
105e36765d | ||
|
|
fa7be5aab4 | ||
|
|
02169f3e60 |
@@ -295,7 +295,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
|
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
|
||||||
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
|
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
|
||||||
- [AI Studio](https://github.com/MindWorkAI/AI-Studio)
|
- [AI Studio](https://github.com/MindWorkAI/AI-Studio)
|
||||||
- [Sidellama](https://github.com/gyopak/sidellama) (browser-based LLM client)
|
|
||||||
|
|
||||||
### Terminal
|
### Terminal
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -25,10 +24,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/auth"
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
@@ -387,16 +383,3 @@ func (c *Client) Version(ctx context.Context) (string, error) {
|
|||||||
|
|
||||||
return version.Version, nil
|
return version.Version, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Authorization(ctx context.Context, request *http.Request) (string, error) {
|
|
||||||
data := []byte(fmt.Sprintf("%s,%s,%d", request.Method, request.URL.RequestURI(), time.Now().Unix()))
|
|
||||||
|
|
||||||
token, err := auth.Sign(ctx, data)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
// interleave request data into the token
|
|
||||||
key, sig, _ := strings.Cut(token, ":")
|
|
||||||
return fmt.Sprintf("%s:%s:%s", key, base64.StdEncoding.EncodeToString(data), sig), nil
|
|
||||||
}
|
|
||||||
|
|||||||
78
api/types.go
78
api/types.go
@@ -101,29 +101,46 @@ type ChatRequest struct {
|
|||||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||||
|
|
||||||
// Tools is an optional list of tools the model has access to.
|
// Tools is an optional list of tools the model has access to.
|
||||||
Tools `json:"tools,omitempty"`
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
|
|
||||||
// Options lists model-specific options.
|
// Options lists model-specific options.
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Tools []Tool
|
|
||||||
|
|
||||||
func (t Tools) String() string {
|
|
||||||
bts, _ := json.Marshal(t)
|
|
||||||
return string(bts)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Message is a single message in a chat sequence. The message contains the
|
// Message is a single message in a chat sequence. The message contains the
|
||||||
// role ("system", "user", or "assistant"), the content and an optional list
|
// role ("system", "user", or "assistant"), the content and an optional list
|
||||||
// of images.
|
// of images.
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content,omitempty"`
|
||||||
Images []ImageData `json:"images,omitempty"`
|
Images []ImageData `json:"images,omitempty"`
|
||||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ToolCall struct {
|
||||||
|
Function struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments map[string]any `json:"arguments"`
|
||||||
|
} `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tool struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Function struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Parameters struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Required []string `json:"required"`
|
||||||
|
Properties map[string]struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Enum []string `json:"enum,omitempty"`
|
||||||
|
} `json:"properties"`
|
||||||
|
} `json:"parameters"`
|
||||||
|
} `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Message) UnmarshalJSON(b []byte) error {
|
func (m *Message) UnmarshalJSON(b []byte) error {
|
||||||
type Alias Message
|
type Alias Message
|
||||||
var a Alias
|
var a Alias
|
||||||
@@ -136,46 +153,6 @@ func (m *Message) UnmarshalJSON(b []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolCall struct {
|
|
||||||
Function ToolCallFunction `json:"function"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolCallFunction struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Arguments ToolCallFunctionArguments `json:"arguments"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolCallFunctionArguments map[string]any
|
|
||||||
|
|
||||||
func (t *ToolCallFunctionArguments) String() string {
|
|
||||||
bts, _ := json.Marshal(t)
|
|
||||||
return string(bts)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Tool struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Function ToolFunction `json:"function"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolFunction struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Parameters struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Required []string `json:"required"`
|
|
||||||
Properties map[string]struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Enum []string `json:"enum,omitempty"`
|
|
||||||
} `json:"properties"`
|
|
||||||
} `json:"parameters"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *ToolFunction) String() string {
|
|
||||||
bts, _ := json.Marshal(t)
|
|
||||||
return string(bts)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatResponse is the response returned by [Client.Chat]. Its fields are
|
// ChatResponse is the response returned by [Client.Chat]. Its fields are
|
||||||
// similar to [GenerateResponse].
|
// similar to [GenerateResponse].
|
||||||
type ChatResponse struct {
|
type ChatResponse struct {
|
||||||
@@ -428,6 +405,9 @@ type GenerateResponse struct {
|
|||||||
// Response is the textual response itself.
|
// Response is the textual response itself.
|
||||||
Response string `json:"response"`
|
Response string `json:"response"`
|
||||||
|
|
||||||
|
// ToolCalls is the list of tools the model wants to call
|
||||||
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
|
|
||||||
// Done specifies if the response is complete.
|
// Done specifies if the response is complete.
|
||||||
Done bool `json:"done"`
|
Done bool `json:"done"`
|
||||||
|
|
||||||
|
|||||||
120
auth/auth.go
120
auth/auth.go
@@ -3,68 +3,49 @@ package auth
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/ed25519"
|
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/pem"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultPrivateKey = "id_ed25519"
|
const defaultPrivateKey = "id_ed25519"
|
||||||
|
|
||||||
func privateKey() (ssh.Signer, error) {
|
func keyPath() (string, error) {
|
||||||
home, err := os.UserHomeDir()
|
home, err := os.UserHomeDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
||||||
privateKeyFile, err := os.ReadFile(keyPath)
|
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
|
||||||
err := initializeKeypair()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return privateKey()
|
|
||||||
} else if err != nil {
|
|
||||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return ssh.ParsePrivateKey(privateKeyFile)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetPublicKey() (ssh.PublicKey, error) {
|
func GetPublicKey() (string, error) {
|
||||||
// try to read pubkey first
|
keyPath, err := keyPath()
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
pubkeyPath := filepath.Join(home, ".ollama", defaultPrivateKey+".pub")
|
privateKeyFile, err := os.ReadFile(keyPath)
|
||||||
pubKeyFile, err := os.ReadFile(pubkeyPath)
|
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
|
||||||
// try from privateKey
|
|
||||||
privateKey, err := privateKey()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to read public key: %w", err)
|
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||||
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return privateKey.PublicKey(), nil
|
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
||||||
} else if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to read public key: %w", err)
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
pubKey, _, _, _, err := ssh.ParseAuthorizedKey(pubKeyFile)
|
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
||||||
return pubKey, err
|
|
||||||
|
return strings.TrimSpace(string(publicKey)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNonce(r io.Reader, length int) (string, error) {
|
func NewNonce(r io.Reader, length int) (string, error) {
|
||||||
@@ -77,20 +58,25 @@ func NewNonce(r io.Reader, length int) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Sign(ctx context.Context, bts []byte) (string, error) {
|
func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||||
privateKey, err := privateKey()
|
keyPath, err := keyPath()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKeyFile, err := os.ReadFile(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
// get the pubkey, but remove the type
|
// get the pubkey, but remove the type
|
||||||
publicKey, err := GetPublicKey()
|
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
||||||
if err != nil {
|
parts := bytes.Split(publicKey, []byte(" "))
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
|
|
||||||
|
|
||||||
parts := bytes.Split(publicKeyBytes, []byte(" "))
|
|
||||||
if len(parts) < 2 {
|
if len(parts) < 2 {
|
||||||
return "", fmt.Errorf("malformed public key")
|
return "", fmt.Errorf("malformed public key")
|
||||||
}
|
}
|
||||||
@@ -103,49 +89,3 @@ func Sign(ctx context.Context, bts []byte) (string, error) {
|
|||||||
// signature is <pubkey>:<signature>
|
// signature is <pubkey>:<signature>
|
||||||
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
|
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func initializeKeypair() error {
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
|
|
||||||
pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
|
|
||||||
|
|
||||||
_, err = os.Stat(privKeyPath)
|
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
|
||||||
fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
|
|
||||||
cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
|
|
||||||
return fmt.Errorf("could not create directory %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
|
|
||||||
|
|
||||||
if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
117
cmd/cmd.go
117
cmd/cmd.go
@@ -4,7 +4,10 @@ import (
|
|||||||
"archive/zip"
|
"archive/zip"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -12,7 +15,6 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -110,7 +112,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
path = tempfile
|
path = tempfile
|
||||||
}
|
}
|
||||||
|
|
||||||
digest, err := createBlob(cmd, path)
|
digest, err := createBlob(cmd, client, path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -261,9 +263,7 @@ func tempZipFiles(path string) (string, error) {
|
|||||||
return tempfile.Name(), nil
|
return tempfile.Name(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrBlobExists = errors.New("blob exists")
|
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
|
||||||
|
|
||||||
func createBlob(cmd *cobra.Command, path string) (string, error) {
|
|
||||||
bin, err := os.Open(path)
|
bin, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -280,65 +280,12 @@ func createBlob(cmd *cobra.Command, path string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
|
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
|
||||||
|
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
|
||||||
// Use our new CreateBlob request which will include the file path
|
|
||||||
// The server checks for that file and if the server is local, it will copy the file over
|
|
||||||
// If the local copy fails, the server will continue to the default local copy
|
|
||||||
// If that fails, it will continue with the server POST
|
|
||||||
err = CreateBlob(cmd.Context(), path, digest, bin)
|
|
||||||
if errors.Is(err, ErrBlobExists) {
|
|
||||||
return digest, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return digest, nil
|
return digest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateBlob(ctx context.Context, src, digest string, r *os.File) (error) {
|
|
||||||
ollamaHost := envconfig.Host
|
|
||||||
|
|
||||||
client := http.DefaultClient
|
|
||||||
base := &url.URL{
|
|
||||||
Scheme: ollamaHost.Scheme,
|
|
||||||
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
|
|
||||||
}
|
|
||||||
|
|
||||||
path := fmt.Sprintf("/api/blobs/%s", digest)
|
|
||||||
requestURL := base.JoinPath(path)
|
|
||||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), r)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
authz, err := api.Authorization(ctx, request)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
request.Header.Set("Authorization", authz)
|
|
||||||
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
|
||||||
request.Header.Set("X-Ollama-File", src)
|
|
||||||
|
|
||||||
resp, err := client.Do(request)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode == http.StatusCreated {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode == http.StatusOK {
|
|
||||||
return ErrBlobExists
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||||
interactive := true
|
interactive := true
|
||||||
|
|
||||||
@@ -432,12 +379,11 @@ func errFromUnknownKey(unknownKeyErr error) error {
|
|||||||
if len(matches) > 0 {
|
if len(matches) > 0 {
|
||||||
serverPubKey := matches[0]
|
serverPubKey := matches[0]
|
||||||
|
|
||||||
publicKey, err := auth.GetPublicKey()
|
localPubKey, err := auth.GetPublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return unknownKeyErr
|
return unknownKeyErr
|
||||||
}
|
}
|
||||||
|
|
||||||
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(publicKey)))
|
|
||||||
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
||||||
// try the ollama service public key
|
// try the ollama service public key
|
||||||
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
||||||
@@ -1126,7 +1072,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RunServer(cmd *cobra.Command, _ []string) error {
|
func RunServer(cmd *cobra.Command, _ []string) error {
|
||||||
if _, err := auth.GetPublicKey(); err != nil {
|
if err := initializeKeypair(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1143,6 +1089,52 @@ func RunServer(cmd *cobra.Command, _ []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func initializeKeypair() error {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
|
||||||
|
pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
|
||||||
|
|
||||||
|
_, err = os.Stat(privKeyPath)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
|
||||||
|
cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
|
||||||
|
return fmt.Errorf("could not create directory %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
|
||||||
|
|
||||||
|
if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
|
func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1352,6 +1344,7 @@ func NewCLI() *cobra.Command {
|
|||||||
envVars["OLLAMA_TMPDIR"],
|
envVars["OLLAMA_TMPDIR"],
|
||||||
envVars["OLLAMA_FLASH_ATTENTION"],
|
envVars["OLLAMA_FLASH_ATTENTION"],
|
||||||
envVars["OLLAMA_LLM_LIBRARY"],
|
envVars["OLLAMA_LLM_LIBRARY"],
|
||||||
|
envVars["OLLAMA_MAX_VRAM"],
|
||||||
})
|
})
|
||||||
default:
|
default:
|
||||||
appendEnvDocs(cmd, envs)
|
appendEnvDocs(cmd, envs)
|
||||||
|
|||||||
15
docs/gpu.md
15
docs/gpu.md
@@ -46,24 +46,13 @@ sudo modprobe nvidia_uvm`
|
|||||||
|
|
||||||
## AMD Radeon
|
## AMD Radeon
|
||||||
Ollama supports the following AMD GPUs:
|
Ollama supports the following AMD GPUs:
|
||||||
|
|
||||||
### Linux Support
|
|
||||||
| Family | Cards and accelerators |
|
| Family | Cards and accelerators |
|
||||||
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56` |
|
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56` |
|
||||||
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` |
|
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` |
|
||||||
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50` |
|
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50` |
|
||||||
|
|
||||||
### Windows Support
|
### Overrides
|
||||||
With ROCm v6.1, the following GPUs are supported on Windows.
|
|
||||||
|
|
||||||
| Family | Cards and accelerators |
|
|
||||||
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
|
||||||
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` |
|
|
||||||
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
|
|
||||||
|
|
||||||
|
|
||||||
### Overrides on Linux
|
|
||||||
Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In
|
Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In
|
||||||
some cases you can force the system to try to use a similar LLVM target that is
|
some cases you can force the system to try to use a similar LLVM target that is
|
||||||
close. For example The Radeon RX 5400 is `gfx1034` (also known as 10.3.4)
|
close. For example The Radeon RX 5400 is `gfx1034` (also known as 10.3.4)
|
||||||
@@ -74,7 +63,7 @@ would set `HSA_OVERRIDE_GFX_VERSION="10.3.0"` as an environment variable for the
|
|||||||
server. If you have an unsupported AMD GPU you can experiment using the list of
|
server. If you have an unsupported AMD GPU you can experiment using the list of
|
||||||
supported types below.
|
supported types below.
|
||||||
|
|
||||||
At this time, the known supported GPU types on linux are the following LLVM Targets.
|
At this time, the known supported GPU types are the following LLVM Targets.
|
||||||
This table shows some example GPUs that map to these LLVM targets:
|
This table shows some example GPUs that map to these LLVM targets:
|
||||||
| **LLVM Target** | **An Example GPU** |
|
| **LLVM Target** | **An Example GPU** |
|
||||||
|-----------------|---------------------|
|
|-----------------|---------------------|
|
||||||
|
|||||||
@@ -27,6 +27,11 @@ chat_completion = client.chat.completions.create(
|
|||||||
],
|
],
|
||||||
model='llama3',
|
model='llama3',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
completion = client.completions.create(
|
||||||
|
model="llama3",
|
||||||
|
prompt="Say this is a test"
|
||||||
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
### OpenAI JavaScript library
|
### OpenAI JavaScript library
|
||||||
@@ -45,6 +50,11 @@ const chatCompletion = await openai.chat.completions.create({
|
|||||||
messages: [{ role: 'user', content: 'Say this is a test' }],
|
messages: [{ role: 'user', content: 'Say this is a test' }],
|
||||||
model: 'llama3',
|
model: 'llama3',
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const completion = await openai.completions.create({
|
||||||
|
model: "llama3",
|
||||||
|
prompt: "Say this is a test.",
|
||||||
|
})
|
||||||
```
|
```
|
||||||
|
|
||||||
### `curl`
|
### `curl`
|
||||||
@@ -66,6 +76,12 @@ curl http://localhost:11434/v1/chat/completions \
|
|||||||
]
|
]
|
||||||
}'
|
}'
|
||||||
|
|
||||||
|
curl http://localhost:11434/v1/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "llama3",
|
||||||
|
"prompt": "Say this is a test"
|
||||||
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
## Endpoints
|
## Endpoints
|
||||||
@@ -103,6 +119,73 @@ curl http://localhost:11434/v1/chat/completions \
|
|||||||
- [ ] `user`
|
- [ ] `user`
|
||||||
- [ ] `n`
|
- [ ] `n`
|
||||||
|
|
||||||
|
### `/v1/completions`
|
||||||
|
|
||||||
|
#### Supported features
|
||||||
|
|
||||||
|
- [x] Completions
|
||||||
|
- [x] Streaming
|
||||||
|
- [x] JSON mode
|
||||||
|
- [x] Reproducible outputs
|
||||||
|
- [ ] Logprobs
|
||||||
|
|
||||||
|
#### Supported request fields
|
||||||
|
|
||||||
|
- [x] `model`
|
||||||
|
- [x] `prompt`
|
||||||
|
- [x] `frequency_penalty`
|
||||||
|
- [x] `presence_penalty`
|
||||||
|
- [x] `seed`
|
||||||
|
- [x] `stop`
|
||||||
|
- [x] `stream`
|
||||||
|
- [x] `temperature`
|
||||||
|
- [x] `top_p`
|
||||||
|
- [x] `max_tokens`
|
||||||
|
- [x] `suffix`
|
||||||
|
- [ ] `best_of`
|
||||||
|
- [ ] `echo`
|
||||||
|
- [ ] `logit_bias`
|
||||||
|
- [ ] `user`
|
||||||
|
- [ ] `n`
|
||||||
|
|
||||||
|
#### Notes
|
||||||
|
|
||||||
|
- `prompt` currently only accepts a string
|
||||||
|
|
||||||
|
### `/v1/completions`
|
||||||
|
|
||||||
|
#### Supported features
|
||||||
|
|
||||||
|
- [x] Completions
|
||||||
|
- [x] Streaming
|
||||||
|
- [x] JSON mode
|
||||||
|
- [x] Reproducible outputs
|
||||||
|
- [ ] Logprobs
|
||||||
|
|
||||||
|
#### Supported request fields
|
||||||
|
|
||||||
|
- [x] `model`
|
||||||
|
- [x] `prompt`
|
||||||
|
- [x] `frequency_penalty`
|
||||||
|
- [x] `presence_penalty`
|
||||||
|
- [x] `seed`
|
||||||
|
- [x] `stop`
|
||||||
|
- [x] `stream`
|
||||||
|
- [x] `temperature`
|
||||||
|
- [x] `top_p`
|
||||||
|
- [x] `max_tokens`
|
||||||
|
- [ ] `best_of`
|
||||||
|
- [ ] `echo`
|
||||||
|
- [ ] `suffix`
|
||||||
|
- [ ] `logit_bias`
|
||||||
|
- [ ] `user`
|
||||||
|
- [ ] `n`
|
||||||
|
|
||||||
|
#### Notes
|
||||||
|
|
||||||
|
- `prompt` currently only accepts a string
|
||||||
|
- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
|
||||||
|
|
||||||
## Models
|
## Models
|
||||||
|
|
||||||
Before using a model, pull it locally `ollama pull`:
|
Before using a model, pull it locally `ollama pull`:
|
||||||
|
|||||||
@@ -43,6 +43,8 @@ var (
|
|||||||
MaxRunners int
|
MaxRunners int
|
||||||
// Set via OLLAMA_MAX_QUEUE in the environment
|
// Set via OLLAMA_MAX_QUEUE in the environment
|
||||||
MaxQueuedRequests int
|
MaxQueuedRequests int
|
||||||
|
// Set via OLLAMA_MAX_VRAM in the environment
|
||||||
|
MaxVRAM uint64
|
||||||
// Set via OLLAMA_MODELS in the environment
|
// Set via OLLAMA_MODELS in the environment
|
||||||
ModelsDir string
|
ModelsDir string
|
||||||
// Set via OLLAMA_NOHISTORY in the environment
|
// Set via OLLAMA_NOHISTORY in the environment
|
||||||
@@ -87,6 +89,7 @@ func AsMap() map[string]EnvVar {
|
|||||||
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
|
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
|
||||||
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"},
|
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"},
|
||||||
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
|
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
|
||||||
|
"OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"},
|
||||||
"OLLAMA_MODELS": {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"},
|
"OLLAMA_MODELS": {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"},
|
||||||
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
|
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
|
||||||
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
|
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
|
||||||
@@ -191,6 +194,16 @@ func LoadConfig() {
|
|||||||
|
|
||||||
TmpDir = clean("OLLAMA_TMPDIR")
|
TmpDir = clean("OLLAMA_TMPDIR")
|
||||||
|
|
||||||
|
userLimit := clean("OLLAMA_MAX_VRAM")
|
||||||
|
if userLimit != "" {
|
||||||
|
avail, err := strconv.ParseUint(userLimit, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("invalid setting, ignoring", "OLLAMA_MAX_VRAM", userLimit, "error", err)
|
||||||
|
} else {
|
||||||
|
MaxVRAM = avail
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
LLMLibrary = clean("OLLAMA_LLM_LIBRARY")
|
LLMLibrary = clean("OLLAMA_LLM_LIBRARY")
|
||||||
|
|
||||||
if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" {
|
if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" {
|
||||||
|
|||||||
@@ -33,10 +33,9 @@ type HipLib struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewHipLib() (*HipLib, error) {
|
func NewHipLib() (*HipLib, error) {
|
||||||
// At runtime we depend on v6, so discover GPUs with the same library for a consistent set of GPUs
|
h, err := windows.LoadLibrary("amdhip64.dll")
|
||||||
h, err := windows.LoadLibrary("amdhip64_6.dll")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to load amdhip64_6.dll, please make sure to upgrade to the latest amd driver: %w", err)
|
return nil, fmt.Errorf("unable to load amdhip64.dll: %w", err)
|
||||||
}
|
}
|
||||||
hl := &HipLib{}
|
hl := &HipLib{}
|
||||||
hl.dll = h
|
hl.dll = h
|
||||||
|
|||||||
@@ -92,8 +92,7 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if gfxOverride == "" {
|
if gfxOverride == "" {
|
||||||
// Strip off Target Features when comparing
|
if !slices.Contains[[]string, string](supported, gfx) {
|
||||||
if !slices.Contains[[]string, string](supported, strings.Split(gfx, ":")[0]) {
|
|
||||||
slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
|
slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
|
||||||
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
||||||
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
|
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
|
|||||||
reqLimit := len(req)
|
reqLimit := len(req)
|
||||||
iterLimit := 5
|
iterLimit := 5
|
||||||
|
|
||||||
vram := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
|
vram := os.Getenv("OLLAMA_MAX_VRAM")
|
||||||
if vram != "" {
|
if vram != "" {
|
||||||
max, err := strconv.ParseUint(vram, 10, 64)
|
max, err := strconv.ParseUint(vram, 10, 64)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -106,7 +106,7 @@ func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
|
|||||||
|
|
||||||
// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
|
// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
|
||||||
func TestMultiModelStress(t *testing.T) {
|
func TestMultiModelStress(t *testing.T) {
|
||||||
vram := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
|
vram := os.Getenv("OLLAMA_MAX_VRAM")
|
||||||
if vram == "" {
|
if vram == "" {
|
||||||
t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
|
t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
|
|
||||||
func TestContextExhaustion(t *testing.T) {
|
func TestContextExhaustion(t *testing.T) {
|
||||||
// Longer needed for small footprint GPUs
|
// Longer needed for small footprint GPUs
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 6*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
// Set up the test data
|
// Set up the test data
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
@@ -25,10 +25,5 @@ func TestContextExhaustion(t *testing.T) {
|
|||||||
"num_ctx": 128,
|
"num_ctx": 128,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
|
||||||
defer cleanup()
|
|
||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
|
||||||
t.Fatalf("PullIfMissing failed: %v", err)
|
|
||||||
}
|
|
||||||
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ function amdGPUs {
|
|||||||
return $env:AMDGPU_TARGETS
|
return $env:AMDGPU_TARGETS
|
||||||
}
|
}
|
||||||
# Current supported rocblas list from ROCm v6.1.2 on windows
|
# Current supported rocblas list from ROCm v6.1.2 on windows
|
||||||
# https://rocm.docs.amd.com/projects/install-on-windows/en/latest/reference/system-requirements.html#windows-supported-gpus
|
|
||||||
$GPU_LIST = @(
|
$GPU_LIST = @(
|
||||||
|
"gfx906:xnack-"
|
||||||
"gfx1030"
|
"gfx1030"
|
||||||
"gfx1100"
|
"gfx1100"
|
||||||
"gfx1101"
|
"gfx1101"
|
||||||
|
|||||||
Submodule llm/llama.cpp updated: d94c6e0ccb...a8db2a9ce6
@@ -1,8 +1,8 @@
|
|||||||
diff --git a/src/llama.cpp b/src/llama.cpp
|
diff --git a/src/llama.cpp b/src/llama.cpp
|
||||||
index 8fe51971..7113ba64 100644
|
index 2b9ace28..172640e2 100644
|
||||||
--- a/src/llama.cpp
|
--- a/src/llama.cpp
|
||||||
+++ b/src/llama.cpp
|
+++ b/src/llama.cpp
|
||||||
@@ -5433,16 +5433,7 @@ static void llm_load_vocab(
|
@@ -5357,16 +5357,7 @@ static void llm_load_vocab(
|
||||||
if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
|
if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
|
||||||
vocab.tokenizer_add_space_prefix = false;
|
vocab.tokenizer_add_space_prefix = false;
|
||||||
vocab.tokenizer_clean_spaces = true;
|
vocab.tokenizer_clean_spaces = true;
|
||||||
@@ -20,9 +20,9 @@ index 8fe51971..7113ba64 100644
|
|||||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||||
} else if (
|
} else if (
|
||||||
tokenizer_pre == "llama3" ||
|
tokenizer_pre == "llama3" ||
|
||||||
@@ -5526,7 +5517,8 @@ static void llm_load_vocab(
|
@@ -5439,7 +5430,8 @@ static void llm_load_vocab(
|
||||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMOLLM;
|
tokenizer_pre == "jais") {
|
||||||
vocab.tokenizer_clean_spaces = false;
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
|
||||||
} else {
|
} else {
|
||||||
- throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
|
- throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
|
||||||
+ LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
|
+ LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
|
||||||
|
|||||||
13
llm/patches/06-qwen2.diff
Normal file
13
llm/patches/06-qwen2.diff
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
diff --git a/src/llama.cpp b/src/llama.cpp
|
||||||
|
index 40d2ec2c..f34eb79a 100644
|
||||||
|
--- a/src/llama.cpp
|
||||||
|
+++ b/src/llama.cpp
|
||||||
|
@@ -6943,7 +6943,7 @@ static struct ggml_tensor * llm_build_kqv(
|
||||||
|
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
||||||
|
cb(kq, "kq", il);
|
||||||
|
|
||||||
|
- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
|
||||||
|
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
|
||||||
|
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
|
||||||
|
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
|
||||||
|
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
|
||||||
@@ -1,360 +0,0 @@
|
|||||||
diff --git a/common/common.cpp b/common/common.cpp
|
|
||||||
index dbb724fb..c26fe6ee 100644
|
|
||||||
--- a/common/common.cpp
|
|
||||||
+++ b/common/common.cpp
|
|
||||||
@@ -2087,14 +2087,29 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
|
|
||||||
for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
|
|
||||||
const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
|
|
||||||
float lora_scale = std::get<1>(params.lora_adapter[i]);
|
|
||||||
+
|
|
||||||
+ // try to load as gguf
|
|
||||||
auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
|
|
||||||
if (adapter == nullptr) {
|
|
||||||
- fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
|
|
||||||
- llama_free(lctx);
|
|
||||||
- llama_free_model(model);
|
|
||||||
- return std::make_tuple(nullptr, nullptr);
|
|
||||||
+ fprintf(stderr, "%s: error: failed to apply lora adapter, trying ggla\n", __func__);
|
|
||||||
+
|
|
||||||
+ // if that fails, try loading as ggla for compatibility
|
|
||||||
+ int err = llama_model_apply_lora_from_file(model,
|
|
||||||
+ lora_adapter.c_str(),
|
|
||||||
+ lora_scale,
|
|
||||||
+ ((i > 0) || params.lora_base.empty())
|
|
||||||
+ ? NULL
|
|
||||||
+ : params.lora_base.c_str(),
|
|
||||||
+ params.n_threads);
|
|
||||||
+ if (err != 0) {
|
|
||||||
+ fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
|
|
||||||
+ llama_free(lctx);
|
|
||||||
+ llama_free_model(model);
|
|
||||||
+ return std::make_tuple(nullptr, nullptr);
|
|
||||||
+ }
|
|
||||||
+ } else {
|
|
||||||
+ llama_lora_adapter_set(lctx, adapter, lora_scale);
|
|
||||||
}
|
|
||||||
- llama_lora_adapter_set(lctx, adapter, lora_scale);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.ignore_eos) {
|
|
||||||
diff --git a/include/llama.h b/include/llama.h
|
|
||||||
index 93fd77ca..b0fb37a6 100644
|
|
||||||
--- a/include/llama.h
|
|
||||||
+++ b/include/llama.h
|
|
||||||
@@ -1160,6 +1160,20 @@ extern "C" {
|
|
||||||
|
|
||||||
LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
|
|
||||||
|
|
||||||
+ // Apply a LoRA adapter to a loaded model
|
|
||||||
+ // path_base_model is the path to a higher quality model to use as a base for
|
|
||||||
+ // the layers modified by the adapter. Can be NULL to use the current loaded model.
|
|
||||||
+ // The model needs to be reloaded before applying a new adapter, otherwise the adapter
|
|
||||||
+ // will be applied on top of the previous one
|
|
||||||
+ // Returns 0 on success
|
|
||||||
+ LLAMA_API int32_t llama_model_apply_lora_from_file(
|
|
||||||
+ const struct llama_model * model,
|
|
||||||
+ const char * path_lora,
|
|
||||||
+ float scale,
|
|
||||||
+ const char * path_base_model,
|
|
||||||
+ int32_t n_threads);
|
|
||||||
+
|
|
||||||
+
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
diff --git a/src/llama.cpp b/src/llama.cpp
|
|
||||||
index 80a0dd0f..9d7b0e17 100644
|
|
||||||
--- a/src/llama.cpp
|
|
||||||
+++ b/src/llama.cpp
|
|
||||||
@@ -21880,3 +21880,290 @@ static void llama_log_callback_default(ggml_log_level level, const char * text,
|
|
||||||
fputs(text, stderr);
|
|
||||||
fflush(stderr);
|
|
||||||
}
|
|
||||||
+
|
|
||||||
+static int llama_apply_lora_from_file_internal(
|
|
||||||
+ const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads
|
|
||||||
+) {
|
|
||||||
+ LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
|
|
||||||
+
|
|
||||||
+ const int64_t t_start_lora_us = ggml_time_us();
|
|
||||||
+
|
|
||||||
+ llama_file fin(path_lora, "rb");
|
|
||||||
+
|
|
||||||
+ // verify magic and version
|
|
||||||
+ {
|
|
||||||
+ uint32_t magic = fin.read_u32();
|
|
||||||
+ if (magic != LLAMA_FILE_MAGIC_GGLA) {
|
|
||||||
+ LLAMA_LOG_ERROR("%s: bad file magic\n", __func__);
|
|
||||||
+ return 1;
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ uint32_t format_version = fin.read_u32();
|
|
||||||
+ if (format_version != 1) {
|
|
||||||
+ LLAMA_LOG_ERROR("%s: unsupported file version\n", __func__ );
|
|
||||||
+ return 1;
|
|
||||||
+ }
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ int32_t lora_r = fin.read_u32();
|
|
||||||
+ int32_t lora_alpha = fin.read_u32();
|
|
||||||
+ float scaling = scale * (float)lora_alpha / (float)lora_r;
|
|
||||||
+
|
|
||||||
+ LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
|
|
||||||
+
|
|
||||||
+ // load base model
|
|
||||||
+ std::unique_ptr<llama_model_loader> ml;
|
|
||||||
+ if (path_base_model) {
|
|
||||||
+ LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
|
|
||||||
+ ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*check_tensors*/ false, /*kv_overrides*/ nullptr));
|
|
||||||
+ ml->init_mappings(/*prefetch*/ false); // no prefetching
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ struct tensor_meta {
|
|
||||||
+ std::string name;
|
|
||||||
+ ggml_type type;
|
|
||||||
+ int32_t ne[2];
|
|
||||||
+ size_t offset;
|
|
||||||
+ };
|
|
||||||
+ std::map<std::string, tensor_meta> tensor_meta_map;
|
|
||||||
+
|
|
||||||
+ // load all tensor meta
|
|
||||||
+ while (true) {
|
|
||||||
+ if (fin.tell() == fin.size) {
|
|
||||||
+ // eof
|
|
||||||
+ break;
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ int32_t n_dims;
|
|
||||||
+ int32_t name_len;
|
|
||||||
+ int32_t ftype;
|
|
||||||
+
|
|
||||||
+ fin.read_raw(&n_dims, sizeof(n_dims));
|
|
||||||
+ fin.read_raw(&name_len, sizeof(name_len));
|
|
||||||
+ fin.read_raw(&ftype, sizeof(ftype));
|
|
||||||
+
|
|
||||||
+ if (n_dims != 1 && n_dims != 2) {
|
|
||||||
+ LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims);
|
|
||||||
+ return 1;
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ int32_t ne[2] = { 1, 1 };
|
|
||||||
+ for (int i = 0; i < n_dims; ++i) {
|
|
||||||
+ fin.read_raw(&ne[i], sizeof(ne[i]));
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ std::string name;
|
|
||||||
+ {
|
|
||||||
+ GGML_ASSERT(name_len < GGML_MAX_NAME);
|
|
||||||
+ char buf[GGML_MAX_NAME];
|
|
||||||
+ fin.read_raw(buf, name_len);
|
|
||||||
+ name = std::string(buf, name_len);
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ // check for lora suffix
|
|
||||||
+ std::string lora_suffix;
|
|
||||||
+ if (name.length() > 6) {
|
|
||||||
+ lora_suffix = name.substr(name.length() - 6);
|
|
||||||
+ }
|
|
||||||
+ if (lora_suffix != ".loraA" && lora_suffix != ".loraB") {
|
|
||||||
+ LLAMA_LOG_ERROR("%s: error: '%s' is not a lora tensor\n", __func__, name.c_str());
|
|
||||||
+ return 1;
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ // tensor type
|
|
||||||
+ ggml_type wtype;
|
|
||||||
+ switch (ftype) {
|
|
||||||
+ case 0: wtype = GGML_TYPE_F32; break;
|
|
||||||
+ case 1: wtype = GGML_TYPE_F16; break;
|
|
||||||
+ default:
|
|
||||||
+ {
|
|
||||||
+ LLAMA_LOG_ERROR("%s: invalid tensor data type '%d'\n",
|
|
||||||
+ __func__, ftype);
|
|
||||||
+ return 1;
|
|
||||||
+ }
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ // data offset
|
|
||||||
+ size_t offset = fin.tell();
|
|
||||||
+ offset = (offset + 31) & -32;
|
|
||||||
+
|
|
||||||
+ // skip tensor data
|
|
||||||
+ fin.seek(offset + ggml_row_size(wtype, ne[0]) * ne[1], SEEK_SET);
|
|
||||||
+
|
|
||||||
+ tensor_meta_map.emplace(name, tensor_meta{ name, wtype, { ne[0], ne[1] }, offset });
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ bool warned = false;
|
|
||||||
+ int n_tensors = 0;
|
|
||||||
+
|
|
||||||
+ // apply
|
|
||||||
+ ggml_backend_t backend_cpu = ggml_backend_cpu_init();
|
|
||||||
+ if (backend_cpu == nullptr) {
|
|
||||||
+ LLAMA_LOG_ERROR("%s: error: failed to initialize cpu backend\n", __func__);
|
|
||||||
+ return 1;
|
|
||||||
+ }
|
|
||||||
+ ggml_backend_cpu_set_n_threads(backend_cpu, n_threads);
|
|
||||||
+
|
|
||||||
+ std::vector<no_init<uint8_t>> read_buf;
|
|
||||||
+ for (const auto & it : model.tensors_by_name) {
|
|
||||||
+ const std::string & base_name = it.first;
|
|
||||||
+ ggml_tensor * model_t = it.second;
|
|
||||||
+
|
|
||||||
+ if (tensor_meta_map.find(base_name + ".loraA") == tensor_meta_map.end() ||
|
|
||||||
+ tensor_meta_map.find(base_name + ".loraB") == tensor_meta_map.end()) {
|
|
||||||
+ continue;
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ tensor_meta & metaA = tensor_meta_map.at(base_name + ".loraA");
|
|
||||||
+ tensor_meta & metaB = tensor_meta_map.at(base_name + ".loraB");
|
|
||||||
+
|
|
||||||
+ ggml_init_params lora_init_params = {
|
|
||||||
+ /* .mem_size */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
|
|
||||||
+ /* .mem_buffer */ nullptr,
|
|
||||||
+ /* .no_alloc */ true,
|
|
||||||
+ };
|
|
||||||
+ ggml_context * lora_ctx = ggml_init(lora_init_params);
|
|
||||||
+ if (lora_ctx == nullptr) {
|
|
||||||
+ LLAMA_LOG_ERROR("%s: error: failed to initialize lora context\n", __func__);
|
|
||||||
+ ggml_backend_free(backend_cpu);
|
|
||||||
+ return 1;
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ // create tensors
|
|
||||||
+ ggml_tensor * loraA = ggml_new_tensor_2d(lora_ctx, metaA.type, metaA.ne[0], metaA.ne[1]);
|
|
||||||
+ ggml_tensor * loraB = ggml_new_tensor_2d(lora_ctx, metaB.type, metaB.ne[0], metaB.ne[1]);
|
|
||||||
+ ggml_set_name(loraA, metaA.name.c_str());
|
|
||||||
+ ggml_set_name(loraB, metaB.name.c_str());
|
|
||||||
+
|
|
||||||
+ ggml_tensor * base_t;
|
|
||||||
+ if (ml) {
|
|
||||||
+ if (!ml->get_tensor_meta(base_name.c_str())) {
|
|
||||||
+ LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
|
|
||||||
+ return 1;
|
|
||||||
+ }
|
|
||||||
+ base_t = ggml_dup_tensor(lora_ctx, ml->get_tensor_meta(base_name.c_str()));
|
|
||||||
+ } else {
|
|
||||||
+ base_t = ggml_dup_tensor(lora_ctx, model_t);
|
|
||||||
+ }
|
|
||||||
+ ggml_set_name(base_t, base_name.c_str());
|
|
||||||
+
|
|
||||||
+ // allocate in backend buffer
|
|
||||||
+ ggml_backend_buffer_t lora_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type());
|
|
||||||
+ if (lora_buf == nullptr) {
|
|
||||||
+ LLAMA_LOG_ERROR("%s: error: failed to allocate lora tensors\n", __func__);
|
|
||||||
+ return 1;
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ // load tensor data
|
|
||||||
+ auto load_tensor = [&read_buf, &fin](const tensor_meta & tensor_meta, ggml_tensor * tensor) {
|
|
||||||
+ read_buf.resize(ggml_nbytes(tensor));
|
|
||||||
+ fin.seek(tensor_meta.offset, SEEK_SET);
|
|
||||||
+ fin.read_raw(read_buf.data(), ggml_nbytes(tensor));
|
|
||||||
+ ggml_backend_tensor_set(tensor, read_buf.data(), 0, read_buf.size());
|
|
||||||
+ };
|
|
||||||
+ load_tensor(metaA, loraA);
|
|
||||||
+ load_tensor(metaB, loraB);
|
|
||||||
+
|
|
||||||
+ // load base model tensor data
|
|
||||||
+ if (ml) {
|
|
||||||
+ ml->load_data_for(base_t);
|
|
||||||
+ } else {
|
|
||||||
+ ggml_backend_tensor_copy(model_t, base_t);
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ if (ggml_is_quantized(base_t->type) && !warned) {
|
|
||||||
+ LLAMA_LOG_WARN("%s: warning: using a lora adapter with a quantized model may result in poor quality, "
|
|
||||||
+ "use a f16 or f32 base model with --lora-base\n", __func__);
|
|
||||||
+ warned = true;
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) {
|
|
||||||
+ LLAMA_LOG_ERROR("%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
|
|
||||||
+ " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]);
|
|
||||||
+ ggml_free(lora_ctx);
|
|
||||||
+ ggml_backend_buffer_free(lora_buf);
|
|
||||||
+ ggml_backend_free(backend_cpu);
|
|
||||||
+ return 1;
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ auto build_lora_graph = [&]() {
|
|
||||||
+ // w = w + BA*s
|
|
||||||
+ ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB);
|
|
||||||
+ ggml_set_name(BA, "BA");
|
|
||||||
+
|
|
||||||
+ if (scaling != 1.0f) {
|
|
||||||
+ BA = ggml_scale(lora_ctx, BA, scaling);
|
|
||||||
+ ggml_set_name(BA, "BA_scaled");
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ ggml_tensor * r;
|
|
||||||
+ r = ggml_add_inplace(lora_ctx, base_t, BA);
|
|
||||||
+ ggml_set_name(r, "r_add");
|
|
||||||
+
|
|
||||||
+ if (base_t->type != model_t->type) {
|
|
||||||
+ // convert the result to the model type
|
|
||||||
+ r = ggml_cast(lora_ctx, r, model_t->type);
|
|
||||||
+ ggml_set_name(r, "r_cast");
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ return r;
|
|
||||||
+ };
|
|
||||||
+
|
|
||||||
+ ggml_cgraph * gf = ggml_new_graph(lora_ctx);
|
|
||||||
+ ggml_tensor * r = build_lora_graph();
|
|
||||||
+ ggml_build_forward_expand(gf, r);
|
|
||||||
+
|
|
||||||
+ ggml_backend_buffer_t graph_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type());
|
|
||||||
+ if (graph_buf == nullptr) {
|
|
||||||
+ LLAMA_LOG_ERROR("%s: error: failed to allocate graph tensors\n", __func__);
|
|
||||||
+ ggml_free(lora_ctx);
|
|
||||||
+ ggml_backend_buffer_free(lora_buf);
|
|
||||||
+ ggml_backend_free(backend_cpu);
|
|
||||||
+ return 1;
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ ggml_backend_graph_compute(backend_cpu, gf);
|
|
||||||
+
|
|
||||||
+ ggml_backend_tensor_set(model_t, r->data, 0, ggml_nbytes(r));
|
|
||||||
+
|
|
||||||
+#if 0
|
|
||||||
+ // TODO: use scheduler with fallback to CPU for less copies between CPU and GPU
|
|
||||||
+ //ggml_backend_sched_t sched = ggml_backend_sched_new(backends.data(), backends.size(), GGML_DEFAULT_GRAPH_SIZE);
|
|
||||||
+
|
|
||||||
+ // sched compute
|
|
||||||
+ ggml_build_forward_expand(gf, build_graph());
|
|
||||||
+ ggml_backend_sched_init_measure(sched, gf);
|
|
||||||
+
|
|
||||||
+ // create the graph again, since the previous one was destroyed by the measure
|
|
||||||
+ ggml_graph_clear(gf);
|
|
||||||
+ ggml_build_forward_expand(gf, build_graph());
|
|
||||||
+ ggml_backend_sched_graph_compute(sched, gf);
|
|
||||||
+ ggml_backend_sched_free(sched);
|
|
||||||
+#endif
|
|
||||||
+
|
|
||||||
+ ggml_backend_buffer_free(lora_buf);
|
|
||||||
+ ggml_backend_buffer_free(graph_buf);
|
|
||||||
+ ggml_free(lora_ctx);
|
|
||||||
+
|
|
||||||
+ n_tensors++;
|
|
||||||
+ if (n_tensors % 4 == 0) {
|
|
||||||
+ LLAMA_LOG_INFO(".");
|
|
||||||
+ }
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ ggml_backend_free(backend_cpu);
|
|
||||||
+
|
|
||||||
+ const int64_t t_lora_us = ggml_time_us() - t_start_lora_us;
|
|
||||||
+ LLAMA_LOG_INFO(" done (%.2f ms)\n", t_lora_us / 1000.0);
|
|
||||||
+
|
|
||||||
+ return 0;
|
|
||||||
+}
|
|
||||||
+
|
|
||||||
+int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, float scale, const char * path_base_model, int32_t n_threads) {
|
|
||||||
+ try {
|
|
||||||
+ return llama_apply_lora_from_file_internal(*model, path_lora, scale, path_base_model, n_threads);
|
|
||||||
+ } catch (const std::exception & err) {
|
|
||||||
+ LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
|
|
||||||
+ return 1;
|
|
||||||
+ }
|
|
||||||
+}
|
|
||||||
\ No newline at end of file
|
|
||||||
@@ -385,10 +385,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
|||||||
filteredEnv := []string{}
|
filteredEnv := []string{}
|
||||||
for _, ev := range s.cmd.Env {
|
for _, ev := range s.cmd.Env {
|
||||||
if strings.HasPrefix(ev, "CUDA_") ||
|
if strings.HasPrefix(ev, "CUDA_") ||
|
||||||
strings.HasPrefix(ev, "ROCR_") ||
|
|
||||||
strings.HasPrefix(ev, "ROCM_") ||
|
strings.HasPrefix(ev, "ROCM_") ||
|
||||||
strings.HasPrefix(ev, "HIP_") ||
|
strings.HasPrefix(ev, "HIP_") ||
|
||||||
strings.HasPrefix(ev, "GPU_") ||
|
|
||||||
strings.HasPrefix(ev, "HSA_") ||
|
strings.HasPrefix(ev, "HSA_") ||
|
||||||
strings.HasPrefix(ev, "GGML_") ||
|
strings.HasPrefix(ev, "GGML_") ||
|
||||||
strings.HasPrefix(ev, "PATH=") ||
|
strings.HasPrefix(ev, "PATH=") ||
|
||||||
@@ -417,17 +415,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
|||||||
|
|
||||||
// reap subprocess when it exits
|
// reap subprocess when it exits
|
||||||
go func() {
|
go func() {
|
||||||
err := s.cmd.Wait()
|
s.done <- s.cmd.Wait()
|
||||||
// Favor a more detailed message over the process exit status
|
|
||||||
if err != nil && s.status != nil && s.status.LastErrMsg != "" {
|
|
||||||
slog.Debug("llama runner terminated", "error", err)
|
|
||||||
if strings.Contains(s.status.LastErrMsg, "unknown model") {
|
|
||||||
s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
|
|
||||||
}
|
|
||||||
s.done <- fmt.Errorf(s.status.LastErrMsg)
|
|
||||||
} else {
|
|
||||||
s.done <- err
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
@@ -590,7 +578,14 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
|||||||
slog.Warn("client connection closed before server finished loading, aborting load")
|
slog.Warn("client connection closed before server finished loading, aborting load")
|
||||||
return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
|
return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
|
||||||
case err := <-s.done:
|
case err := <-s.done:
|
||||||
return fmt.Errorf("llama runner process has terminated: %w", err)
|
msg := ""
|
||||||
|
if s.status != nil && s.status.LastErrMsg != "" {
|
||||||
|
msg = s.status.LastErrMsg
|
||||||
|
}
|
||||||
|
if strings.Contains(msg, "unknown model") {
|
||||||
|
return fmt.Errorf("this model is not supported by your version of Ollama. You may need to upgrade")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
if time.Now().After(stallTimer) {
|
if time.Now().After(stallTimer) {
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -32,7 +31,6 @@ type ErrorResponse struct {
|
|||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content any `json:"content"`
|
Content any `json:"content"`
|
||||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Choice struct {
|
type Choice struct {
|
||||||
@@ -80,7 +78,6 @@ type ChatCompletionRequest struct {
|
|||||||
PresencePenalty *float64 `json:"presence_penalty_penalty"`
|
PresencePenalty *float64 `json:"presence_penalty_penalty"`
|
||||||
TopP *float64 `json:"top_p"`
|
TopP *float64 `json:"top_p"`
|
||||||
ResponseFormat *ResponseFormat `json:"response_format"`
|
ResponseFormat *ResponseFormat `json:"response_format"`
|
||||||
Tools []api.Tool `json:"tools"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatCompletion struct {
|
type ChatCompletion struct {
|
||||||
@@ -114,7 +111,6 @@ type CompletionRequest struct {
|
|||||||
Stream bool `json:"stream"`
|
Stream bool `json:"stream"`
|
||||||
Temperature *float32 `json:"temperature"`
|
Temperature *float32 `json:"temperature"`
|
||||||
TopP float32 `json:"top_p"`
|
TopP float32 `json:"top_p"`
|
||||||
Suffix string `json:"suffix"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Completion struct {
|
type Completion struct {
|
||||||
@@ -136,15 +132,6 @@ type CompletionChunk struct {
|
|||||||
SystemFingerprint string `json:"system_fingerprint"`
|
SystemFingerprint string `json:"system_fingerprint"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolCall struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
Function struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Arguments string `json:"arguments"`
|
|
||||||
} `json:"function"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
@@ -183,31 +170,7 @@ func NewError(code int, message string) ErrorResponse {
|
|||||||
return ErrorResponse{Error{Type: etype, Message: message}}
|
return ErrorResponse{Error{Type: etype, Message: message}}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toolCallId() string {
|
|
||||||
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
|
|
||||||
b := make([]byte, 8)
|
|
||||||
for i := range b {
|
|
||||||
b[i] = letterBytes[rand.Intn(len(letterBytes))]
|
|
||||||
}
|
|
||||||
return "call_" + strings.ToLower(string(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||||
toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
|
|
||||||
for i, tc := range r.Message.ToolCalls {
|
|
||||||
toolCalls[i].ID = toolCallId()
|
|
||||||
toolCalls[i].Type = "function"
|
|
||||||
toolCalls[i].Function.Name = tc.Function.Name
|
|
||||||
|
|
||||||
args, err := json.Marshal(tc.Function.Arguments)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("could not marshall function arguments to json", "error", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
toolCalls[i].Function.Arguments = string(args)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ChatCompletion{
|
return ChatCompletion{
|
||||||
Id: id,
|
Id: id,
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
@@ -216,7 +179,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
|||||||
SystemFingerprint: "fp_ollama",
|
SystemFingerprint: "fp_ollama",
|
||||||
Choices: []Choice{{
|
Choices: []Choice{{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
|
Message: Message{Role: r.Message.Role, Content: r.Message.Content},
|
||||||
FinishReason: func(reason string) *string {
|
FinishReason: func(reason string) *string {
|
||||||
if len(reason) > 0 {
|
if len(reason) > 0 {
|
||||||
return &reason
|
return &reason
|
||||||
@@ -225,6 +188,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
|||||||
}(r.DoneReason),
|
}(r.DoneReason),
|
||||||
}},
|
}},
|
||||||
Usage: Usage{
|
Usage: Usage{
|
||||||
|
// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
|
||||||
PromptTokens: r.PromptEvalCount,
|
PromptTokens: r.PromptEvalCount,
|
||||||
CompletionTokens: r.EvalCount,
|
CompletionTokens: r.EvalCount,
|
||||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||||
@@ -270,6 +234,7 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
|
|||||||
}(r.DoneReason),
|
}(r.DoneReason),
|
||||||
}},
|
}},
|
||||||
Usage: Usage{
|
Usage: Usage{
|
||||||
|
// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
|
||||||
PromptTokens: r.PromptEvalCount,
|
PromptTokens: r.PromptEvalCount,
|
||||||
CompletionTokens: r.EvalCount,
|
CompletionTokens: r.EvalCount,
|
||||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||||
@@ -351,6 +316,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|||||||
case string:
|
case string:
|
||||||
messages = append(messages, api.Message{Role: msg.Role, Content: content})
|
messages = append(messages, api.Message{Role: msg.Role, Content: content})
|
||||||
case []any:
|
case []any:
|
||||||
|
message := api.Message{Role: msg.Role}
|
||||||
for _, c := range content {
|
for _, c := range content {
|
||||||
data, ok := c.(map[string]any)
|
data, ok := c.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -362,7 +328,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid message format")
|
return nil, fmt.Errorf("invalid message format")
|
||||||
}
|
}
|
||||||
messages = append(messages, api.Message{Role: msg.Role, Content: text})
|
message.Content = text
|
||||||
case "image_url":
|
case "image_url":
|
||||||
var url string
|
var url string
|
||||||
if urlMap, ok := data["image_url"].(map[string]any); ok {
|
if urlMap, ok := data["image_url"].(map[string]any); ok {
|
||||||
@@ -394,27 +360,15 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid message format")
|
return nil, fmt.Errorf("invalid message format")
|
||||||
}
|
}
|
||||||
|
message.Images = append(message.Images, img)
|
||||||
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
|
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid message format")
|
return nil, fmt.Errorf("invalid message format")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
messages = append(messages, message)
|
||||||
default:
|
default:
|
||||||
if msg.ToolCalls == nil {
|
|
||||||
return nil, fmt.Errorf("invalid message content type: %T", content)
|
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||||
}
|
}
|
||||||
|
|
||||||
toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
|
|
||||||
for i, tc := range msg.ToolCalls {
|
|
||||||
toolCalls[i].Function.Name = tc.Function.Name
|
|
||||||
err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid tool call arguments")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
options := make(map[string]interface{})
|
options := make(map[string]interface{})
|
||||||
@@ -471,7 +425,6 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|||||||
Format: format,
|
Format: format,
|
||||||
Options: options,
|
Options: options,
|
||||||
Stream: &r.Stream,
|
Stream: &r.Stream,
|
||||||
Tools: r.Tools,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -522,7 +475,6 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
|||||||
Prompt: r.Prompt,
|
Prompt: r.Prompt,
|
||||||
Options: options,
|
Options: options,
|
||||||
Stream: &r.Stream,
|
Stream: &r.Stream,
|
||||||
Suffix: r.Suffix,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -877,7 +829,6 @@ func ChatMiddleware() gin.HandlerFunc {
|
|||||||
chatReq, err := fromChatRequest(req)
|
chatReq, err := fromChatRequest(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||||
|
|||||||
@@ -20,59 +20,108 @@ const prefix = `data:image/jpeg;base64,`
|
|||||||
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||||
const imageURL = prefix + image
|
const imageURL = prefix + image
|
||||||
|
|
||||||
func prepareRequest(req *http.Request, body any) {
|
func TestMiddlewareRequests(t *testing.T) {
|
||||||
bodyBytes, _ := json.Marshal(body)
|
type testCase struct {
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
Name string
|
||||||
req.Header.Set("Content-Type", "application/json")
|
Method string
|
||||||
}
|
Path string
|
||||||
|
Handler func() gin.HandlerFunc
|
||||||
|
Setup func(t *testing.T, req *http.Request)
|
||||||
|
Expected func(t *testing.T, req *http.Request)
|
||||||
|
}
|
||||||
|
|
||||||
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
|
var capturedRequest *http.Request
|
||||||
|
|
||||||
|
captureRequestMiddleware := func() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
err := json.Unmarshal(bodyBytes, capturedRequest)
|
capturedRequest = c.Request
|
||||||
if err != nil {
|
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
|
|
||||||
}
|
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func TestChatMiddleware(t *testing.T) {
|
|
||||||
type testCase struct {
|
|
||||||
Name string
|
|
||||||
Setup func(t *testing.T, req *http.Request)
|
|
||||||
Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var capturedRequest *api.ChatRequest
|
|
||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{
|
{
|
||||||
Name: "chat handler",
|
Name: "chat handler",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: "/api/chat",
|
||||||
|
Handler: ChatMiddleware,
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
body := ChatCompletionRequest{
|
body := ChatCompletionRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
Messages: []Message{{Role: "user", Content: "Hello"}},
|
Messages: []Message{{Role: "user", Content: "Hello"}},
|
||||||
}
|
}
|
||||||
prepareRequest(req, body)
|
|
||||||
|
bodyBytes, _ := json.Marshal(body)
|
||||||
|
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
Expected: func(t *testing.T, req *http.Request) {
|
||||||
if resp.Code != http.StatusOK {
|
var chatReq api.ChatRequest
|
||||||
t.Fatalf("expected 200, got %d", resp.Code)
|
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Messages[0].Role != "user" {
|
if chatReq.Messages[0].Role != "user" {
|
||||||
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
|
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Messages[0].Content != "Hello" {
|
if chatReq.Messages[0].Content != "Hello" {
|
||||||
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
|
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "completions handler",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: "/api/generate",
|
||||||
|
Handler: CompletionsMiddleware,
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
temp := float32(0.8)
|
||||||
|
body := CompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "Hello",
|
||||||
|
Temperature: &temp,
|
||||||
|
Stop: []string{"\n", "stop"},
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, _ := json.Marshal(body)
|
||||||
|
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *http.Request) {
|
||||||
|
var genReq api.GenerateRequest
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if genReq.Prompt != "Hello" {
|
||||||
|
t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
if genReq.Options["temperature"] != 1.6 {
|
||||||
|
t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
|
||||||
|
}
|
||||||
|
|
||||||
|
stopTokens, ok := genReq.Options["stop"].([]any)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected stop tokens to be a list")
|
||||||
|
}
|
||||||
|
|
||||||
|
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
|
||||||
|
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "chat handler with image content",
|
Name: "chat handler with image content",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: "/api/chat",
|
||||||
|
Handler: ChatMiddleware,
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
body := ChatCompletionRequest{
|
body := ChatCompletionRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
@@ -85,254 +134,87 @@ func TestChatMiddleware(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
prepareRequest(req, body)
|
|
||||||
|
bodyBytes, _ := json.Marshal(body)
|
||||||
|
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
Expected: func(t *testing.T, req *http.Request) {
|
||||||
if resp.Code != http.StatusOK {
|
var chatReq api.ChatRequest
|
||||||
t.Fatalf("expected 200, got %d", resp.Code)
|
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Messages[0].Role != "user" {
|
if chatReq.Messages[0].Role != "user" {
|
||||||
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
|
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Messages[0].Content != "Hello" {
|
if chatReq.Messages[0].Content != "Hello" {
|
||||||
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
|
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
|
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
|
||||||
|
|
||||||
if req.Messages[1].Role != "user" {
|
if !bytes.Equal(chatReq.Messages[0].Images[0], img) {
|
||||||
t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
|
t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0])
|
||||||
}
|
|
||||||
|
|
||||||
if !bytes.Equal(req.Messages[1].Images[0], img) {
|
|
||||||
t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
Name: "chat handler with tools",
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
|
||||||
body := ChatCompletionRequest{
|
|
||||||
Model: "test-model",
|
|
||||||
Messages: []Message{
|
|
||||||
{Role: "user", Content: "What's the weather like in Paris Today?"},
|
|
||||||
{Role: "assistant", ToolCalls: []ToolCall{{
|
|
||||||
ID: "id",
|
|
||||||
Type: "function",
|
|
||||||
Function: struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Arguments string `json:"arguments"`
|
|
||||||
}{
|
|
||||||
Name: "get_current_weather",
|
|
||||||
Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
|
|
||||||
},
|
|
||||||
}}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
prepareRequest(req, body)
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
|
||||||
if resp.Code != 200 {
|
|
||||||
t.Fatalf("expected 200, got %d", resp.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Messages[0].Content != "What's the weather like in Paris Today?" {
|
|
||||||
t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
|
|
||||||
t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
|
|
||||||
t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "chat handler error forwarding",
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
|
||||||
body := ChatCompletionRequest{
|
|
||||||
Model: "test-model",
|
|
||||||
Messages: []Message{{Role: "user", Content: 2}},
|
|
||||||
}
|
|
||||||
prepareRequest(req, body)
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
|
||||||
if resp.Code != http.StatusBadRequest {
|
|
||||||
t.Fatalf("expected 400, got %d", resp.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(resp.Body.String(), "invalid message content type") {
|
|
||||||
t.Fatalf("error was not forwarded")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
endpoint := func(c *gin.Context) {
|
|
||||||
c.Status(http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
router := gin.New()
|
|
||||||
router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
|
|
||||||
router.Handle(http.MethodPost, "/api/chat", endpoint)
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
|
||||||
req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
|
|
||||||
|
|
||||||
tc.Setup(t, req)
|
|
||||||
|
|
||||||
resp := httptest.NewRecorder()
|
|
||||||
router.ServeHTTP(resp, req)
|
|
||||||
|
|
||||||
tc.Expected(t, capturedRequest, resp)
|
|
||||||
|
|
||||||
capturedRequest = nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCompletionsMiddleware(t *testing.T) {
|
|
||||||
type testCase struct {
|
|
||||||
Name string
|
|
||||||
Setup func(t *testing.T, req *http.Request)
|
|
||||||
Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
|
|
||||||
}
|
|
||||||
|
|
||||||
var capturedRequest *api.GenerateRequest
|
|
||||||
|
|
||||||
testCases := []testCase{
|
|
||||||
{
|
|
||||||
Name: "completions handler",
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
|
||||||
temp := float32(0.8)
|
|
||||||
body := CompletionRequest{
|
|
||||||
Model: "test-model",
|
|
||||||
Prompt: "Hello",
|
|
||||||
Temperature: &temp,
|
|
||||||
Stop: []string{"\n", "stop"},
|
|
||||||
Suffix: "suffix",
|
|
||||||
}
|
|
||||||
prepareRequest(req, body)
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
|
|
||||||
if req.Prompt != "Hello" {
|
|
||||||
t.Fatalf("expected 'Hello', got %s", req.Prompt)
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Options["temperature"] != 1.6 {
|
|
||||||
t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
|
|
||||||
}
|
|
||||||
|
|
||||||
stopTokens, ok := req.Options["stop"].([]any)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected stop tokens to be a list")
|
|
||||||
}
|
|
||||||
|
|
||||||
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
|
|
||||||
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Suffix != "suffix" {
|
|
||||||
t.Fatalf("expected 'suffix', got %s", req.Suffix)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "completions handler error forwarding",
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
|
||||||
body := CompletionRequest{
|
|
||||||
Model: "test-model",
|
|
||||||
Prompt: "Hello",
|
|
||||||
Temperature: nil,
|
|
||||||
Stop: []int{1, 2},
|
|
||||||
Suffix: "suffix",
|
|
||||||
}
|
|
||||||
prepareRequest(req, body)
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
|
|
||||||
if resp.Code != http.StatusBadRequest {
|
|
||||||
t.Fatalf("expected 400, got %d", resp.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
|
|
||||||
t.Fatalf("error was not forwarded")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
endpoint := func(c *gin.Context) {
|
|
||||||
c.Status(http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
router := gin.New()
|
|
||||||
router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
|
||||||
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
|
||||||
req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
|
|
||||||
|
|
||||||
tc.Setup(t, req)
|
|
||||||
|
|
||||||
resp := httptest.NewRecorder()
|
|
||||||
router.ServeHTTP(resp, req)
|
|
||||||
|
|
||||||
tc.Expected(t, capturedRequest, resp)
|
|
||||||
|
|
||||||
capturedRequest = nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEmbeddingsMiddleware(t *testing.T) {
|
|
||||||
type testCase struct {
|
|
||||||
Name string
|
|
||||||
Setup func(t *testing.T, req *http.Request)
|
|
||||||
Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
|
|
||||||
}
|
|
||||||
|
|
||||||
var capturedRequest *api.EmbedRequest
|
|
||||||
|
|
||||||
testCases := []testCase{
|
|
||||||
{
|
{
|
||||||
Name: "embed handler single input",
|
Name: "embed handler single input",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: "/api/embed",
|
||||||
|
Handler: EmbeddingsMiddleware,
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
body := EmbedRequest{
|
body := EmbedRequest{
|
||||||
Input: "Hello",
|
Input: "Hello",
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
}
|
}
|
||||||
prepareRequest(req, body)
|
|
||||||
|
bodyBytes, _ := json.Marshal(body)
|
||||||
|
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
Expected: func(t *testing.T, req *http.Request) {
|
||||||
if req.Input != "Hello" {
|
var embedReq api.EmbedRequest
|
||||||
t.Fatalf("expected 'Hello', got %s", req.Input)
|
if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Model != "test-model" {
|
if embedReq.Input != "Hello" {
|
||||||
t.Fatalf("expected 'test-model', got %s", req.Model)
|
t.Fatalf("expected 'Hello', got %s", embedReq.Input)
|
||||||
|
}
|
||||||
|
|
||||||
|
if embedReq.Model != "test-model" {
|
||||||
|
t.Fatalf("expected 'test-model', got %s", embedReq.Model)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "embed handler batch input",
|
Name: "embed handler batch input",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: "/api/embed",
|
||||||
|
Handler: EmbeddingsMiddleware,
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
body := EmbedRequest{
|
body := EmbedRequest{
|
||||||
Input: []string{"Hello", "World"},
|
Input: []string{"Hello", "World"},
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
}
|
}
|
||||||
prepareRequest(req, body)
|
|
||||||
|
bodyBytes, _ := json.Marshal(body)
|
||||||
|
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
Expected: func(t *testing.T, req *http.Request) {
|
||||||
input, ok := req.Input.([]any)
|
var embedReq api.EmbedRequest
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
input, ok := embedReq.Input.([]any)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("expected input to be a list")
|
t.Fatalf("expected input to be a list")
|
||||||
@@ -346,52 +228,36 @@ func TestEmbeddingsMiddleware(t *testing.T) {
|
|||||||
t.Fatalf("expected 'World', got %s", input[1])
|
t.Fatalf("expected 'World', got %s", input[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Model != "test-model" {
|
if embedReq.Model != "test-model" {
|
||||||
t.Fatalf("expected 'test-model', got %s", req.Model)
|
t.Fatalf("expected 'test-model', got %s", embedReq.Model)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
Name: "embed handler error forwarding",
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
|
||||||
body := EmbedRequest{
|
|
||||||
Model: "test-model",
|
|
||||||
}
|
|
||||||
prepareRequest(req, body)
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
|
||||||
if resp.Code != http.StatusBadRequest {
|
|
||||||
t.Fatalf("expected 400, got %d", resp.Code)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.Contains(resp.Body.String(), "invalid input") {
|
gin.SetMode(gin.TestMode)
|
||||||
t.Fatalf("error was not forwarded")
|
router := gin.New()
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
endpoint := func(c *gin.Context) {
|
endpoint := func(c *gin.Context) {
|
||||||
c.Status(http.StatusOK)
|
c.Status(http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
router := gin.New()
|
|
||||||
router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
|
||||||
router.Handle(http.MethodPost, "/api/embed", endpoint)
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)
|
router = gin.New()
|
||||||
|
router.Use(captureRequestMiddleware())
|
||||||
|
router.Use(tc.Handler())
|
||||||
|
router.Handle(tc.Method, tc.Path, endpoint)
|
||||||
|
req, _ := http.NewRequest(tc.Method, tc.Path, nil)
|
||||||
|
|
||||||
|
if tc.Setup != nil {
|
||||||
tc.Setup(t, req)
|
tc.Setup(t, req)
|
||||||
|
}
|
||||||
|
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
router.ServeHTTP(resp, req)
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
tc.Expected(t, capturedRequest, resp)
|
tc.Expected(t, capturedRequest)
|
||||||
|
|
||||||
capturedRequest = nil
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -409,6 +275,36 @@ func TestMiddlewareResponses(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
|
{
|
||||||
|
Name: "completions handler error forwarding",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: "/api/generate",
|
||||||
|
TestPath: "/api/generate",
|
||||||
|
Handler: CompletionsMiddleware,
|
||||||
|
Endpoint: func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
||||||
|
},
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := CompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "Hello",
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, _ := json.Marshal(body)
|
||||||
|
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||||
|
if resp.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", resp.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(resp.Body.String(), `"invalid request"`) {
|
||||||
|
t.Fatalf("error was not forwarded")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Name: "list handler",
|
Name: "list handler",
|
||||||
Method: http.MethodGet,
|
Method: http.MethodGet,
|
||||||
@@ -425,6 +321,8 @@ func TestMiddlewareResponses(t *testing.T) {
|
|||||||
})
|
})
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||||
|
assert.Equal(t, http.StatusOK, resp.Code)
|
||||||
|
|
||||||
var listResp ListCompletion
|
var listResp ListCompletion
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -488,8 +386,6 @@ func TestMiddlewareResponses(t *testing.T) {
|
|||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
router.ServeHTTP(resp, req)
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, resp.Code)
|
|
||||||
|
|
||||||
tc.Expected(t, resp)
|
tc.Expected(t, resp)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,23 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
func localCopy(src, target string) error {
|
|
||||||
dirPath := filepath.Dir(target)
|
|
||||||
|
|
||||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err := unix.Clonefile(src, target, 0)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import "errors"
|
|
||||||
|
|
||||||
func localCopy(src, target string) error {
|
|
||||||
return errors.New("no local copy implementation for linux")
|
|
||||||
}
|
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
//go:build windows
|
|
||||||
// +build windows
|
|
||||||
|
|
||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
)
|
|
||||||
|
|
||||||
func localCopy(src, target string) error {
|
|
||||||
// Create target directory if it doesn't exist
|
|
||||||
dirPath := filepath.Dir(target)
|
|
||||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Open source file
|
|
||||||
sourceFile, err := os.Open(src)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer sourceFile.Close()
|
|
||||||
|
|
||||||
// Create target file
|
|
||||||
targetFile, err := os.Create(target)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer targetFile.Close()
|
|
||||||
|
|
||||||
// Use CopyFileExW to copy the file
|
|
||||||
err = copyFileEx(src, target)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func copyFileEx(src, dst string) error {
|
|
||||||
kernel32 := syscall.NewLazyDLL("kernel32.dll")
|
|
||||||
copyFileEx := kernel32.NewProc("CopyFileExW")
|
|
||||||
|
|
||||||
srcPtr, err := syscall.UTF16PtrFromString(src)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
dstPtr, err := syscall.UTF16PtrFromString(dst)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
r1, _, err := copyFileEx.Call(
|
|
||||||
uintptr(unsafe.Pointer(srcPtr)),
|
|
||||||
uintptr(unsafe.Pointer(dstPtr)),
|
|
||||||
0, 0, 0, 0)
|
|
||||||
|
|
||||||
if r1 == 0 {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -32,7 +32,6 @@ import (
|
|||||||
"github.com/ollama/ollama/types/errtypes"
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -493,12 +492,6 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
|||||||
layers = append(layers, baseLayer.Layer)
|
layers = append(layers, baseLayer.Layer)
|
||||||
}
|
}
|
||||||
case "license", "template", "system":
|
case "license", "template", "system":
|
||||||
if c.Name == "template" {
|
|
||||||
if _, err := template.Parse(c.Args); err != nil {
|
|
||||||
return fmt.Errorf("%w: %s", errBadTemplate, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.Name != "license" {
|
if c.Name != "license" {
|
||||||
// replace
|
// replace
|
||||||
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
|
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
|
||||||
@@ -1089,12 +1082,11 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
|||||||
if anonymous {
|
if anonymous {
|
||||||
// no user is associated with the public key, and the request requires non-anonymous access
|
// no user is associated with the public key, and the request requires non-anonymous access
|
||||||
pubKey, nestedErr := auth.GetPublicKey()
|
pubKey, nestedErr := auth.GetPublicKey()
|
||||||
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubKey)))
|
|
||||||
if nestedErr != nil {
|
if nestedErr != nil {
|
||||||
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
||||||
return nil, errUnauthorized
|
return nil, errUnauthorized
|
||||||
}
|
}
|
||||||
return nil, &errtypes.UnknownOllamaKey{Key: localPubKey}
|
return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
|
||||||
}
|
}
|
||||||
// user is associated with the public key, but is not authorized to make the request
|
// user is associated with the public key, but is not authorized to make the request
|
||||||
return nil, errUnauthorized
|
return nil, errUnauthorized
|
||||||
|
|||||||
@@ -311,14 +311,12 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
|
if err := tmpl.Execute(&b, map[string][]map[string]any{
|
||||||
"ToolCalls": {
|
"ToolCalls": {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
"Function": map[string]any{
|
||||||
Name: "@@name@@",
|
"Name": "@@name@@",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
"Arguments": "@@arguments@@",
|
||||||
"@@argument@@": 1,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -326,7 +324,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
var kv map[string]any
|
var kv map[string]string
|
||||||
// execute the subtree with placeholders to identify the keys
|
// execute the subtree with placeholders to identify the keys
|
||||||
// trim any commands that might exist in the template
|
// trim any commands that might exist in the template
|
||||||
if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
|
if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
|
||||||
@@ -336,23 +334,17 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
|||||||
// find the keys that correspond to the name and arguments fields
|
// find the keys that correspond to the name and arguments fields
|
||||||
var name, arguments string
|
var name, arguments string
|
||||||
for k, v := range kv {
|
for k, v := range kv {
|
||||||
switch v.(type) {
|
switch v {
|
||||||
case string:
|
case "@@name@@":
|
||||||
name = k
|
name = k
|
||||||
case map[string]any:
|
case "@@arguments@@":
|
||||||
arguments = k
|
arguments = k
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if name == "" || arguments == "" {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
var objs []map[string]any
|
var objs []map[string]any
|
||||||
for offset := 0; offset < len(s); {
|
for offset := 0; offset < len(s); {
|
||||||
var obj map[string]any
|
if err := json.NewDecoder(strings.NewReader(s[offset:])).Decode(&objs); errors.Is(err, io.EOF) {
|
||||||
decoder := json.NewDecoder(strings.NewReader(s[offset:]))
|
|
||||||
if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
|
|
||||||
break
|
break
|
||||||
} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
|
} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
|
||||||
// skip over any syntax errors
|
// skip over any syntax errors
|
||||||
@@ -361,45 +353,27 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
|||||||
// skip over any unmarshalable types
|
// skip over any unmarshalable types
|
||||||
offset += int(unmarshalType.Offset)
|
offset += int(unmarshalType.Offset)
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
slog.Error("parseToolCalls", "error", err)
|
|
||||||
return nil, false
|
return nil, false
|
||||||
} else {
|
} else {
|
||||||
offset += int(decoder.InputOffset())
|
// break when an object is decoded
|
||||||
|
break
|
||||||
// collect all nested objects
|
|
||||||
var collect func(any) []map[string]any
|
|
||||||
collect = func(obj any) (all []map[string]any) {
|
|
||||||
switch o := obj.(type) {
|
|
||||||
case map[string]any:
|
|
||||||
all = append(all, o)
|
|
||||||
for _, v := range o {
|
|
||||||
all = append(all, collect(v)...)
|
|
||||||
}
|
|
||||||
case []any:
|
|
||||||
for _, v := range o {
|
|
||||||
all = append(all, collect(v)...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return all
|
|
||||||
}
|
|
||||||
objs = append(objs, collect(obj)...)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var toolCalls []api.ToolCall
|
var toolCalls []api.ToolCall
|
||||||
for _, kv := range objs {
|
for _, kv := range objs {
|
||||||
n, nok := kv[name].(string)
|
var call api.ToolCall
|
||||||
a, aok := kv[arguments].(map[string]any)
|
for k, v := range kv {
|
||||||
if nok && aok {
|
switch k {
|
||||||
toolCalls = append(toolCalls, api.ToolCall{
|
case name:
|
||||||
Function: api.ToolCallFunction{
|
call.Function.Name = v.(string)
|
||||||
Name: n,
|
case arguments:
|
||||||
Arguments: a,
|
call.Function.Arguments = v.(map[string]any)
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
toolCalls = append(toolCalls, call)
|
||||||
|
}
|
||||||
|
|
||||||
return toolCalls, len(toolCalls) > 0
|
return toolCalls, len(toolCalls) > 0
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -115,6 +115,11 @@ func TestExtractFromZipFile(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type function struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments map[string]any `json:"arguments"`
|
||||||
|
}
|
||||||
|
|
||||||
func readFile(t *testing.T, base, name string) *bytes.Buffer {
|
func readFile(t *testing.T, base, name string) *bytes.Buffer {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
@@ -162,11 +167,6 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
|
|||||||
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
||||||
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
||||||
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
||||||
{"llama3-groq-tool-use", `<tool_call>
|
|
||||||
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
|
|
||||||
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
|
|
||||||
</tool_call>`, true},
|
|
||||||
{"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var tools []api.Tool
|
var tools []api.Tool
|
||||||
@@ -181,18 +181,18 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
|
|||||||
|
|
||||||
calls := []api.ToolCall{
|
calls := []api.ToolCall{
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: function{
|
||||||
Name: "get_current_weather",
|
Name: "get_current_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: map[string]any{
|
||||||
"format": "fahrenheit",
|
"format": "fahrenheit",
|
||||||
"location": "San Francisco, CA",
|
"location": "San Francisco, CA",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: function{
|
||||||
Name: "get_current_weather",
|
Name: "get_current_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: map[string]any{
|
||||||
"format": "celsius",
|
"format": "celsius",
|
||||||
"location": "Toronto, Canada",
|
"location": "Toronto, Canada",
|
||||||
},
|
},
|
||||||
|
|||||||
136
server/routes.go
136
server/routes.go
@@ -4,7 +4,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"cmp"
|
"cmp"
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -24,10 +23,8 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-contrib/cors"
|
"github.com/gin-contrib/cors"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/auth"
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/gpu"
|
"github.com/ollama/ollama/gpu"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
@@ -59,7 +56,6 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var errRequired = errors.New("is required")
|
var errRequired = errors.New("is required")
|
||||||
var errBadTemplate = errors.New("template error")
|
|
||||||
|
|
||||||
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
|
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
|
||||||
opts := api.DefaultOptions()
|
opts := api.DefaultOptions()
|
||||||
@@ -279,6 +275,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.Response = sb.String()
|
r.Response = sb.String()
|
||||||
|
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||||
|
r.ToolCalls = toolCalls
|
||||||
|
r.Response = ""
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, r)
|
c.JSON(http.StatusOK, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -613,9 +614,6 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
|
|||||||
|
|
||||||
quantization := cmp.Or(r.Quantize, r.Quantization)
|
quantization := cmp.Or(r.Quantize, r.Quantization)
|
||||||
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
|
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
|
||||||
if errors.Is(err, errBadTemplate) {
|
|
||||||
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
|
|
||||||
}
|
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -931,6 +929,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
|||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = os.Stat(path)
|
_, err = os.Stat(path)
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, os.ErrNotExist):
|
case errors.Is(err, os.ErrNotExist):
|
||||||
@@ -943,14 +942,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.GetHeader("X-Ollama-File") != "" && s.isLocal(c) {
|
|
||||||
err = localBlobCopy(c.GetHeader("X-Ollama-File"), path)
|
|
||||||
if err == nil {
|
|
||||||
c.Status(http.StatusCreated)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
layer, err := NewLayer(c.Request.Body, "")
|
layer, err := NewLayer(c.Request.Body, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@@ -965,108 +956,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
|||||||
c.Status(http.StatusCreated)
|
c.Status(http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
func localBlobCopy (src, dest string) error {
|
|
||||||
_, err := os.Stat(src)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = localCopy(src, dest)
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err = defaultCopy(src, dest)
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("failed to copy blob")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) isLocal(c *gin.Context) bool {
|
|
||||||
if authz := c.GetHeader("Authorization"); authz != "" {
|
|
||||||
parts := strings.Split(authz, ":")
|
|
||||||
if len(parts) != 3 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
|
|
||||||
requestData, err := base64.StdEncoding.DecodeString(parts[1])
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
partialRequestDataParts := strings.Split(string(requestData), ",")
|
|
||||||
if len(partialRequestDataParts) != 3 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
signature, err := base64.StdEncoding.DecodeString(parts[2])
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := clientPublicKey.Verify(requestData, &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
serverPublicKey, err := auth.GetPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
slog.Error(fmt.Sprintf("failed to get server public key: %v", err))
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func defaultCopy(path string, dest string) error {
|
|
||||||
// This function should be called if the server is local
|
|
||||||
// It should find the model directory, copy the blob over, and return the digest
|
|
||||||
dirPath := filepath.Dir(dest)
|
|
||||||
|
|
||||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy blob over
|
|
||||||
sourceFile, err := os.Open(path)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("could not open source file: %v", err)
|
|
||||||
}
|
|
||||||
defer sourceFile.Close()
|
|
||||||
|
|
||||||
destFile, err := os.Create(dest)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("could not create destination file: %v", err)
|
|
||||||
}
|
|
||||||
defer destFile.Close()
|
|
||||||
|
|
||||||
_, err = io.CopyBuffer(destFile, sourceFile, make([]byte, 4*1024*1024))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error copying file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = destFile.Sync()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error flushing file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isLocalIP(ip netip.Addr) bool {
|
func isLocalIP(ip netip.Addr) bool {
|
||||||
if interfaces, err := net.Interfaces(); err == nil {
|
if interfaces, err := net.Interfaces(); err == nil {
|
||||||
for _, iface := range interfaces {
|
for _, iface := range interfaces {
|
||||||
@@ -1312,15 +1201,11 @@ func waitForStream(c *gin.Context, ch chan interface{}) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case gin.H:
|
case gin.H:
|
||||||
status, ok := r["status"].(int)
|
|
||||||
if !ok {
|
|
||||||
status = http.StatusInternalServerError
|
|
||||||
}
|
|
||||||
if errorMsg, ok := r["error"].(string); ok {
|
if errorMsg, ok := r["error"].(string); ok {
|
||||||
c.JSON(status, gin.H{"error": errorMsg})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@@ -1410,7 +1295,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
caps := []Capability{CapabilityCompletion}
|
caps := []Capability{CapabilityCompletion}
|
||||||
if len(req.Tools) > 0 {
|
if req.Tools != nil {
|
||||||
caps = append(caps, CapabilityTools)
|
caps = append(caps, CapabilityTools)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1505,13 +1390,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
resp.Message.Content = sb.String()
|
resp.Message.Content = sb.String()
|
||||||
|
|
||||||
if len(req.Tools) > 0 {
|
|
||||||
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||||
resp.Message.ToolCalls = toolCalls
|
resp.Message.ToolCalls = toolCalls
|
||||||
resp.Message.Content = ""
|
resp.Message.Content = ""
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, resp)
|
c.JSON(http.StatusOK, resp)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -491,42 +491,6 @@ func TestCreateTemplateSystem(t *testing.T) {
|
|||||||
if string(system) != "Say bye!" {
|
if string(system) != "Say bye!" {
|
||||||
t.Errorf("expected \"Say bye!\", actual %s", system)
|
t.Errorf("expected \"Say bye!\", actual %s", system)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("incomplete template", func(t *testing.T) {
|
|
||||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
|
||||||
Name: "test",
|
|
||||||
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)),
|
|
||||||
Stream: &stream,
|
|
||||||
})
|
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
|
||||||
t.Fatalf("expected status code 400, actual %d", w.Code)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("template with unclosed if", func(t *testing.T) {
|
|
||||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
|
||||||
Name: "test",
|
|
||||||
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)),
|
|
||||||
Stream: &stream,
|
|
||||||
})
|
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
|
||||||
t.Fatalf("expected status code 400, actual %d", w.Code)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("template with undefined function", func(t *testing.T) {
|
|
||||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
|
||||||
Name: "test",
|
|
||||||
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)),
|
|
||||||
Stream: &stream,
|
|
||||||
})
|
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
|
||||||
t.Fatalf("expected status code 400, actual %d", w.Code)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateLicenses(t *testing.T) {
|
func TestCreateLicenses(t *testing.T) {
|
||||||
|
|||||||
@@ -73,8 +73,8 @@ func TestGenerateChat(t *testing.T) {
|
|||||||
getCpuFn: gpu.GetCPUInfo,
|
getCpuFn: gpu.GetCPUInfo,
|
||||||
reschedDelay: 250 * time.Millisecond,
|
reschedDelay: 250 * time.Millisecond,
|
||||||
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
|
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
|
||||||
// add small delay to simulate loading
|
// add 10ms delay to simulate loading
|
||||||
time.Sleep(time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
req.successCh <- &runnerRef{
|
req.successCh <- &runnerRef{
|
||||||
llama: &mock,
|
llama: &mock,
|
||||||
}
|
}
|
||||||
@@ -371,8 +371,6 @@ func TestGenerate(t *testing.T) {
|
|||||||
getCpuFn: gpu.GetCPUInfo,
|
getCpuFn: gpu.GetCPUInfo,
|
||||||
reschedDelay: 250 * time.Millisecond,
|
reschedDelay: 250 * time.Millisecond,
|
||||||
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
|
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
|
||||||
// add small delay to simulate loading
|
|
||||||
time.Sleep(time.Millisecond)
|
|
||||||
req.successCh <- &runnerRef{
|
req.successCh <- &runnerRef{
|
||||||
llama: &mock,
|
llama: &mock,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,18 +10,15 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/auth"
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/openai"
|
"github.com/ollama/ollama/openai"
|
||||||
@@ -530,64 +527,3 @@ func TestNormalize(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsLocalReal(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
clientPubLoc := t.TempDir()
|
|
||||||
t.Setenv("HOME", clientPubLoc)
|
|
||||||
t.Setenv("USERPROFILE", clientPubLoc)
|
|
||||||
|
|
||||||
_, err := auth.GetPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
ctx, _ := gin.CreateTestContext(w)
|
|
||||||
ctx.Request = &http.Request{
|
|
||||||
Header: make(http.Header),
|
|
||||||
}
|
|
||||||
|
|
||||||
requestURL := url.URL{
|
|
||||||
Scheme: "http",
|
|
||||||
Host: "localhost:8080",
|
|
||||||
Path: "/api/blobs",
|
|
||||||
}
|
|
||||||
request := &http.Request{
|
|
||||||
Method: http.MethodPost,
|
|
||||||
URL: &requestURL,
|
|
||||||
}
|
|
||||||
s := &Server{}
|
|
||||||
|
|
||||||
authz, err := api.Authorization(ctx, request)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set client authorization header
|
|
||||||
ctx.Request.Header.Set("Authorization", authz)
|
|
||||||
if !s.isLocal(ctx) {
|
|
||||||
t.Fatal("Expected isLocal to return true")
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("different server pubkey", func(t *testing.T) {
|
|
||||||
serverPubLoc := t.TempDir()
|
|
||||||
t.Setenv("HOME", serverPubLoc)
|
|
||||||
t.Setenv("USERPROFILE", serverPubLoc)
|
|
||||||
_, err := auth.GetPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.isLocal(ctx) {
|
|
||||||
t.Fatal("Expected isLocal to return false")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("invalid pubkey", func(t *testing.T) {
|
|
||||||
ctx.Request.Header.Set("Authorization", "sha-25616:invalid")
|
|
||||||
if s.isLocal(ctx) {
|
|
||||||
t.Fatal("Expected isLocal to return false")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ func TestLoad(t *testing.T) {
|
|||||||
require.Len(t, s.expiredCh, 1)
|
require.Len(t, s.expiredCh, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
type reqBundle struct {
|
type bundle struct {
|
||||||
ctx context.Context //nolint:containedctx
|
ctx context.Context //nolint:containedctx
|
||||||
ctxDone func()
|
ctxDone func()
|
||||||
srv *mockLlm
|
srv *mockLlm
|
||||||
@@ -102,13 +102,13 @@ type reqBundle struct {
|
|||||||
ggml *llm.GGML
|
ggml *llm.GGML
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scenario *reqBundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
||||||
return scenario.srv, nil
|
return scenario.srv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64, duration *api.Duration) *reqBundle {
|
func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
|
||||||
b := &reqBundle{}
|
scenario := &bundle{}
|
||||||
b.ctx, b.ctxDone = context.WithCancel(ctx)
|
scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
f, err := os.CreateTemp(t.TempDir(), modelName)
|
f, err := os.CreateTemp(t.TempDir(), modelName)
|
||||||
@@ -135,154 +135,124 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
|
|||||||
|
|
||||||
fname := f.Name()
|
fname := f.Name()
|
||||||
model := &Model{Name: modelName, ModelPath: fname}
|
model := &Model{Name: modelName, ModelPath: fname}
|
||||||
b.ggml, err = llm.LoadModel(model.ModelPath, 0)
|
scenario.ggml, err = llm.LoadModel(model.ModelPath, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
if duration == nil {
|
scenario.req = &LlmRequest{
|
||||||
duration = &api.Duration{Duration: 5 * time.Millisecond}
|
ctx: scenario.ctx,
|
||||||
}
|
|
||||||
b.req = &LlmRequest{
|
|
||||||
ctx: b.ctx,
|
|
||||||
model: model,
|
model: model,
|
||||||
opts: api.DefaultOptions(),
|
opts: api.DefaultOptions(),
|
||||||
sessionDuration: duration,
|
sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
|
||||||
successCh: make(chan *runnerRef, 1),
|
successCh: make(chan *runnerRef, 1),
|
||||||
errCh: make(chan error, 1),
|
errCh: make(chan error, 1),
|
||||||
}
|
}
|
||||||
b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
|
scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
|
||||||
return b
|
return scenario
|
||||||
}
|
}
|
||||||
|
|
||||||
func getGpuFn() gpu.GpuInfoList {
|
func TestRequests(t *testing.T) {
|
||||||
|
ctx, done := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer done()
|
||||||
|
|
||||||
|
// Same model, same request
|
||||||
|
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
|
||||||
|
scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
|
||||||
|
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
|
||||||
|
scenario1b.req.model = scenario1a.req.model
|
||||||
|
scenario1b.ggml = scenario1a.ggml
|
||||||
|
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
|
||||||
|
|
||||||
|
// simple reload of same model
|
||||||
|
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
|
||||||
|
tmpModel := *scenario1a.req.model
|
||||||
|
scenario2a.req.model = &tmpModel
|
||||||
|
scenario2a.ggml = scenario1a.ggml
|
||||||
|
scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
|
||||||
|
|
||||||
|
// Multiple loaded models
|
||||||
|
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
|
||||||
|
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
|
||||||
|
scenario3c := newScenario(t, ctx, "ollama-model-4a", 30)
|
||||||
|
scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed
|
||||||
|
scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
|
||||||
|
|
||||||
|
s := InitScheduler(ctx)
|
||||||
|
s.getGpuFn = func() gpu.GpuInfoList {
|
||||||
g := gpu.GpuInfo{Library: "metal"}
|
g := gpu.GpuInfo{Library: "metal"}
|
||||||
g.TotalMemory = 24 * format.GigaByte
|
g.TotalMemory = 24 * format.GigaByte
|
||||||
g.FreeMemory = 12 * format.GigaByte
|
g.FreeMemory = 12 * format.GigaByte
|
||||||
return []gpu.GpuInfo{g}
|
return []gpu.GpuInfo{g}
|
||||||
}
|
}
|
||||||
|
s.getCpuFn = func() gpu.GpuInfoList {
|
||||||
func getCpuFn() gpu.GpuInfoList {
|
|
||||||
g := gpu.GpuInfo{Library: "cpu"}
|
g := gpu.GpuInfo{Library: "cpu"}
|
||||||
g.TotalMemory = 32 * format.GigaByte
|
g.TotalMemory = 32 * format.GigaByte
|
||||||
g.FreeMemory = 26 * format.GigaByte
|
g.FreeMemory = 26 * format.GigaByte
|
||||||
return []gpu.GpuInfo{g}
|
return []gpu.GpuInfo{g}
|
||||||
}
|
}
|
||||||
|
s.newServerFn = scenario1a.newServer
|
||||||
func TestRequestsSameModelSameRequest(t *testing.T) {
|
slog.Info("scenario1a")
|
||||||
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
s.pendingReqCh <- scenario1a.req
|
||||||
defer done()
|
|
||||||
s := InitScheduler(ctx)
|
|
||||||
s.getGpuFn = getGpuFn
|
|
||||||
s.getCpuFn = getCpuFn
|
|
||||||
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
|
|
||||||
b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0})
|
|
||||||
b.req.model = a.req.model
|
|
||||||
b.ggml = a.ggml
|
|
||||||
|
|
||||||
s.newServerFn = a.newServer
|
|
||||||
slog.Info("a")
|
|
||||||
s.pendingReqCh <- a.req
|
|
||||||
require.Len(t, s.pendingReqCh, 1)
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
s.Run(ctx)
|
s.Run(ctx)
|
||||||
select {
|
select {
|
||||||
case resp := <-a.req.successCh:
|
case resp := <-scenario1a.req.successCh:
|
||||||
require.Equal(t, resp.llama, a.srv)
|
require.Equal(t, resp.llama, scenario1a.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, a.req.errCh)
|
require.Empty(t, scenario1a.req.errCh)
|
||||||
case err := <-a.req.errCh:
|
case err := <-scenario1a.req.errCh:
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Same runner as first request due to not needing a reload
|
// Same runner as first request due to not needing a reload
|
||||||
s.newServerFn = b.newServer
|
s.newServerFn = scenario1b.newServer
|
||||||
slog.Info("b")
|
slog.Info("scenario1b")
|
||||||
s.pendingReqCh <- b.req
|
s.pendingReqCh <- scenario1b.req
|
||||||
select {
|
select {
|
||||||
case resp := <-b.req.successCh:
|
case resp := <-scenario1b.req.successCh:
|
||||||
require.Equal(t, resp.llama, a.srv)
|
require.Equal(t, resp.llama, scenario1a.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, b.req.errCh)
|
require.Empty(t, scenario1b.req.errCh)
|
||||||
case err := <-b.req.errCh:
|
case err := <-scenario1b.req.errCh:
|
||||||
t.Fatal(err.Error())
|
|
||||||
case <-ctx.Done():
|
|
||||||
t.Fatal("timeout")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRequestsSimpleReloadSameModel(t *testing.T) {
|
|
||||||
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
|
||||||
defer done()
|
|
||||||
s := InitScheduler(ctx)
|
|
||||||
s.getGpuFn = getGpuFn
|
|
||||||
s.getCpuFn = getCpuFn
|
|
||||||
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
|
|
||||||
b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond})
|
|
||||||
tmpModel := *a.req.model
|
|
||||||
b.req.model = &tmpModel
|
|
||||||
b.ggml = a.ggml
|
|
||||||
|
|
||||||
s.newServerFn = a.newServer
|
|
||||||
slog.Info("a")
|
|
||||||
s.pendingReqCh <- a.req
|
|
||||||
require.Len(t, s.pendingReqCh, 1)
|
|
||||||
s.Run(ctx)
|
|
||||||
select {
|
|
||||||
case resp := <-a.req.successCh:
|
|
||||||
require.Equal(t, resp.llama, a.srv)
|
|
||||||
require.Empty(t, s.pendingReqCh)
|
|
||||||
require.Empty(t, a.req.errCh)
|
|
||||||
case err := <-a.req.errCh:
|
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trigger a reload
|
// Trigger a reload
|
||||||
s.newServerFn = b.newServer
|
s.newServerFn = scenario2a.newServer
|
||||||
b.req.model.AdapterPaths = []string{"new"}
|
scenario2a.req.model.AdapterPaths = []string{"new"}
|
||||||
slog.Info("b")
|
slog.Info("scenario2a")
|
||||||
s.pendingReqCh <- b.req
|
s.pendingReqCh <- scenario2a.req
|
||||||
// finish first two requests, so model can reload
|
// finish first two requests, so model can reload
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
a.ctxDone()
|
scenario1a.ctxDone()
|
||||||
|
scenario1b.ctxDone()
|
||||||
select {
|
select {
|
||||||
case resp := <-b.req.successCh:
|
case resp := <-scenario2a.req.successCh:
|
||||||
require.Equal(t, resp.llama, b.srv)
|
require.Equal(t, resp.llama, scenario2a.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, b.req.errCh)
|
require.Empty(t, scenario2a.req.errCh)
|
||||||
case err := <-b.req.errCh:
|
case err := <-scenario2a.req.errCh:
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func TestRequestsMultipleLoadedModels(t *testing.T) {
|
|
||||||
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
|
||||||
defer done()
|
|
||||||
s := InitScheduler(ctx)
|
|
||||||
s.getGpuFn = getGpuFn
|
|
||||||
s.getCpuFn = getCpuFn
|
|
||||||
|
|
||||||
// Multiple loaded models
|
|
||||||
a := newScenarioRequest(t, ctx, "ollama-model-3a", 1*format.GigaByte, nil)
|
|
||||||
b := newScenarioRequest(t, ctx, "ollama-model-3b", 24*format.GigaByte, nil)
|
|
||||||
c := newScenarioRequest(t, ctx, "ollama-model-4a", 30, nil)
|
|
||||||
c.req.opts.NumGPU = 0 // CPU load, will be allowed
|
|
||||||
d := newScenarioRequest(t, ctx, "ollama-model-3c", 30, nil) // Needs prior unloaded
|
|
||||||
|
|
||||||
envconfig.MaxRunners = 1
|
envconfig.MaxRunners = 1
|
||||||
s.newServerFn = a.newServer
|
s.newServerFn = scenario3a.newServer
|
||||||
slog.Info("a")
|
slog.Info("scenario3a")
|
||||||
s.pendingReqCh <- a.req
|
s.pendingReqCh <- scenario3a.req
|
||||||
s.Run(ctx)
|
// finish prior request, so new model can load
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
scenario2a.ctxDone()
|
||||||
select {
|
select {
|
||||||
case resp := <-a.req.successCh:
|
case resp := <-scenario3a.req.successCh:
|
||||||
require.Equal(t, resp.llama, a.srv)
|
require.Equal(t, resp.llama, scenario3a.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, a.req.errCh)
|
require.Empty(t, scenario3a.req.errCh)
|
||||||
case err := <-a.req.errCh:
|
case err := <-scenario3a.req.errCh:
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
@@ -292,15 +262,15 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
|
|||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
envconfig.MaxRunners = 0
|
envconfig.MaxRunners = 0
|
||||||
s.newServerFn = b.newServer
|
s.newServerFn = scenario3b.newServer
|
||||||
slog.Info("b")
|
slog.Info("scenario3b")
|
||||||
s.pendingReqCh <- b.req
|
s.pendingReqCh <- scenario3b.req
|
||||||
select {
|
select {
|
||||||
case resp := <-b.req.successCh:
|
case resp := <-scenario3b.req.successCh:
|
||||||
require.Equal(t, resp.llama, b.srv)
|
require.Equal(t, resp.llama, scenario3b.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, b.req.errCh)
|
require.Empty(t, scenario3b.req.errCh)
|
||||||
case err := <-b.req.errCh:
|
case err := <-scenario3b.req.errCh:
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
@@ -310,15 +280,15 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
|
|||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
// This is a CPU load with NumGPU = 0 so it should load
|
// This is a CPU load with NumGPU = 0 so it should load
|
||||||
s.newServerFn = c.newServer
|
s.newServerFn = scenario3c.newServer
|
||||||
slog.Info("c")
|
slog.Info("scenario3c")
|
||||||
s.pendingReqCh <- c.req
|
s.pendingReqCh <- scenario3c.req
|
||||||
select {
|
select {
|
||||||
case resp := <-c.req.successCh:
|
case resp := <-scenario3c.req.successCh:
|
||||||
require.Equal(t, resp.llama, c.srv)
|
require.Equal(t, resp.llama, scenario3c.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, c.req.errCh)
|
require.Empty(t, scenario3c.req.errCh)
|
||||||
case err := <-c.req.errCh:
|
case err := <-scenario3c.req.errCh:
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
@@ -328,25 +298,25 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
|
|||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
// Try to load a model that wont fit
|
// Try to load a model that wont fit
|
||||||
s.newServerFn = d.newServer
|
s.newServerFn = scenario3d.newServer
|
||||||
slog.Info("d")
|
slog.Info("scenario3d")
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
require.Len(t, s.loaded, 3)
|
require.Len(t, s.loaded, 3)
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
a.ctxDone() // Won't help since this one isn't big enough to make room
|
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
|
||||||
time.Sleep(2 * time.Millisecond)
|
time.Sleep(2 * time.Millisecond)
|
||||||
s.pendingReqCh <- d.req
|
s.pendingReqCh <- scenario3d.req
|
||||||
// finish prior request, so new model can load
|
// finish prior request, so new model can load
|
||||||
time.Sleep(6 * time.Millisecond)
|
time.Sleep(6 * time.Millisecond)
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
require.Len(t, s.loaded, 2)
|
require.Len(t, s.loaded, 2)
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
b.ctxDone()
|
scenario3b.ctxDone()
|
||||||
select {
|
select {
|
||||||
case resp := <-d.req.successCh:
|
case resp := <-scenario3d.req.successCh:
|
||||||
require.Equal(t, resp.llama, d.srv)
|
require.Equal(t, resp.llama, scenario3d.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, d.req.errCh)
|
require.Empty(t, scenario3d.req.errCh)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
@@ -359,19 +329,26 @@ func TestGetRunner(t *testing.T) {
|
|||||||
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
defer done()
|
defer done()
|
||||||
|
|
||||||
a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond})
|
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
|
||||||
b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond})
|
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
|
||||||
c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond})
|
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
|
||||||
|
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
|
||||||
|
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
|
||||||
|
scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
|
||||||
envconfig.MaxQueuedRequests = 1
|
envconfig.MaxQueuedRequests = 1
|
||||||
s := InitScheduler(ctx)
|
s := InitScheduler(ctx)
|
||||||
s.getGpuFn = getGpuFn
|
s.getGpuFn = func() gpu.GpuInfoList {
|
||||||
s.getCpuFn = getCpuFn
|
g := gpu.GpuInfo{Library: "metal"}
|
||||||
s.newServerFn = a.newServer
|
g.TotalMemory = 24 * format.GigaByte
|
||||||
slog.Info("a")
|
g.FreeMemory = 12 * format.GigaByte
|
||||||
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
|
return []gpu.GpuInfo{g}
|
||||||
|
}
|
||||||
|
s.newServerFn = scenario1a.newServer
|
||||||
|
slog.Info("scenario1a")
|
||||||
|
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
|
||||||
require.Len(t, s.pendingReqCh, 1)
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
slog.Info("b")
|
slog.Info("scenario1b")
|
||||||
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration)
|
successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
|
||||||
require.Len(t, s.pendingReqCh, 1)
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
require.Empty(t, successCh1b)
|
require.Empty(t, successCh1b)
|
||||||
require.Len(t, errCh1b, 1)
|
require.Len(t, errCh1b, 1)
|
||||||
@@ -380,24 +357,22 @@ func TestGetRunner(t *testing.T) {
|
|||||||
s.Run(ctx)
|
s.Run(ctx)
|
||||||
select {
|
select {
|
||||||
case resp := <-successCh1a:
|
case resp := <-successCh1a:
|
||||||
require.Equal(t, resp.llama, a.srv)
|
require.Equal(t, resp.llama, scenario1a.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, errCh1a)
|
require.Empty(t, errCh1a)
|
||||||
case err := <-errCh1a:
|
|
||||||
t.Fatal(err.Error())
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
a.ctxDone() // Set "a" model to idle so it can unload
|
scenario1a.ctxDone()
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
require.Len(t, s.loaded, 1)
|
require.Len(t, s.loaded, 1)
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
c.req.model.ModelPath = "bad path"
|
scenario1c.req.model.ModelPath = "bad path"
|
||||||
slog.Info("c")
|
slog.Info("scenario1c")
|
||||||
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration)
|
successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
|
||||||
// Starts in pending channel, then should be quickly processsed to return an error
|
// Starts in pending channel, then should be quickly processsed to return an error
|
||||||
time.Sleep(20 * time.Millisecond) // Long enough for the "a" model to expire and unload
|
time.Sleep(5 * time.Millisecond)
|
||||||
require.Empty(t, successCh1c)
|
require.Empty(t, successCh1c)
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
require.Empty(t, s.loaded)
|
require.Empty(t, s.loaded)
|
||||||
@@ -405,7 +380,7 @@ func TestGetRunner(t *testing.T) {
|
|||||||
require.Len(t, errCh1c, 1)
|
require.Len(t, errCh1c, 1)
|
||||||
err = <-errCh1c
|
err = <-errCh1c
|
||||||
require.Contains(t, err.Error(), "bad path")
|
require.Contains(t, err.Error(), "bad path")
|
||||||
b.ctxDone()
|
scenario1b.ctxDone()
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO - add one scenario that triggers the bogus finished event with positive ref count
|
// TODO - add one scenario that triggers the bogus finished event with positive ref count
|
||||||
@@ -414,7 +389,7 @@ func TestPrematureExpired(t *testing.T) {
|
|||||||
defer done()
|
defer done()
|
||||||
|
|
||||||
// Same model, same request
|
// Same model, same request
|
||||||
scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil)
|
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
|
||||||
s := InitScheduler(ctx)
|
s := InitScheduler(ctx)
|
||||||
s.getGpuFn = func() gpu.GpuInfoList {
|
s.getGpuFn = func() gpu.GpuInfoList {
|
||||||
g := gpu.GpuInfo{Library: "metal"}
|
g := gpu.GpuInfo{Library: "metal"}
|
||||||
@@ -436,8 +411,6 @@ func TestPrematureExpired(t *testing.T) {
|
|||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
slog.Info("sending premature expired event now")
|
slog.Info("sending premature expired event now")
|
||||||
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
|
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
|
||||||
case err := <-errCh1a:
|
|
||||||
t.Fatal(err.Error())
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
@@ -473,8 +446,6 @@ func TestUseLoadedRunner(t *testing.T) {
|
|||||||
select {
|
select {
|
||||||
case success := <-req.successCh:
|
case success := <-req.successCh:
|
||||||
require.Equal(t, r1, success)
|
require.Equal(t, r1, success)
|
||||||
case err := <-req.errCh:
|
|
||||||
t.Fatal(err.Error())
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
@@ -654,7 +625,8 @@ func TestAlreadyCanceled(t *testing.T) {
|
|||||||
defer done()
|
defer done()
|
||||||
dctx, done2 := context.WithCancel(ctx)
|
dctx, done2 := context.WithCancel(ctx)
|
||||||
done2()
|
done2()
|
||||||
scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0})
|
scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
|
||||||
|
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
|
||||||
s := InitScheduler(ctx)
|
s := InitScheduler(ctx)
|
||||||
slog.Info("scenario1a")
|
slog.Info("scenario1a")
|
||||||
s.pendingReqCh <- scenario1a.req
|
s.pendingReqCh <- scenario1a.req
|
||||||
|
|||||||
2
server/testdata/tools/command-r-plus.gotmpl
vendored
2
server/testdata/tools/command-r-plus.gotmpl
vendored
@@ -46,7 +46,7 @@ Action: ```json
|
|||||||
{{- range .ToolCalls }}
|
{{- range .ToolCalls }}
|
||||||
{
|
{
|
||||||
"tool_name": "{{ .Function.Name }}",
|
"tool_name": "{{ .Function.Name }}",
|
||||||
"parameters": {{ .Function.Arguments }}
|
"parameters": {{ json .Function.Arguments }}
|
||||||
}
|
}
|
||||||
{{- end }}
|
{{- end }}
|
||||||
]```
|
]```
|
||||||
|
|||||||
4
server/testdata/tools/firefunction.gotmpl
vendored
4
server/testdata/tools/firefunction.gotmpl
vendored
@@ -17,7 +17,7 @@ If you decide to call functions:
|
|||||||
|
|
||||||
Available functions as JSON spec:
|
Available functions as JSON spec:
|
||||||
{{- if .Tools }}
|
{{- if .Tools }}
|
||||||
{{ .Tools }}
|
{{ json .Tools }}
|
||||||
{{- end }}<|eot_id|>
|
{{- end }}<|eot_id|>
|
||||||
{{- end }}
|
{{- end }}
|
||||||
{{- range .Messages }}<|start_header_id|>
|
{{- range .Messages }}<|start_header_id|>
|
||||||
@@ -25,7 +25,7 @@ Available functions as JSON spec:
|
|||||||
{{- end }}<|end_header_id|>
|
{{- end }}<|end_header_id|>
|
||||||
{{- if .Content }}{{ .Content }}
|
{{- if .Content }}{{ .Content }}
|
||||||
{{- else if .ToolCalls }} functools[
|
{{- else if .ToolCalls }} functools[
|
||||||
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }}
|
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}{{ "}" }}
|
||||||
{{- end }}]
|
{{- end }}]
|
||||||
{{- end }}<|eot_id|>
|
{{- end }}<|eot_id|>
|
||||||
{{- end }}<|start_header_id|>assistant<|end_header_id|>
|
{{- end }}<|start_header_id|>assistant<|end_header_id|>
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
{{- if .Messages }}
|
|
||||||
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
|
|
||||||
|
|
||||||
{{ .System }}
|
|
||||||
{{- if .Tools }} You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
|
||||||
<tool_call>
|
|
||||||
{"name": <function-name>,"arguments": <args-dict>}
|
|
||||||
</tool_call>
|
|
||||||
|
|
||||||
Here are the available tools:
|
|
||||||
<tools>
|
|
||||||
{{- range .Tools }} {{ .Function }}
|
|
||||||
{{- end }} </tools>
|
|
||||||
{{- end }}
|
|
||||||
{{- end }}<|eot_id|>
|
|
||||||
{{- range .Messages }}
|
|
||||||
{{- if ne .Role "system" }}<|start_header_id|>{{ .Role }}<|end_header_id|>
|
|
||||||
|
|
||||||
{{ if eq .Role "user" }}{{ .Content }}
|
|
||||||
{{- else if eq .Role "assistant" }}
|
|
||||||
{{- if .Content }}{{ .Content }}
|
|
||||||
{{- else if .ToolCalls }}<tool_call>
|
|
||||||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
|
||||||
{{- end }}
|
|
||||||
</tool_call>
|
|
||||||
{{- end }}
|
|
||||||
{{- else if eq .Role "tool" }}<tool_response>
|
|
||||||
{{ .Content }}
|
|
||||||
</tool_response>
|
|
||||||
{{- end }}<|eot_id|>
|
|
||||||
{{- end }}
|
|
||||||
{{- end }}<|start_header_id|>assistant<|end_header_id|>
|
|
||||||
|
|
||||||
{{ else }}
|
|
||||||
{{ if .System }}<|start_header_id|>system<|end_header_id|>
|
|
||||||
|
|
||||||
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
|
|
||||||
|
|
||||||
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
|
|
||||||
|
|
||||||
{{ end }}{{ .Response }}
|
|
||||||
{{- if .Response }}<|eot_id|>
|
|
||||||
{{- end }}
|
|
||||||
24
server/testdata/tools/llama3-groq-tool-use.out
vendored
24
server/testdata/tools/llama3-groq-tool-use.out
vendored
@@ -1,24 +0,0 @@
|
|||||||
<|start_header_id|>system<|end_header_id|>
|
|
||||||
|
|
||||||
You are a knowledgable assistant. You can answer questions and perform tasks. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
|
||||||
<tool_call>
|
|
||||||
{"name": <function-name>,"arguments": <args-dict>}
|
|
||||||
</tool_call>
|
|
||||||
|
|
||||||
Here are the available tools:
|
|
||||||
<tools> {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}} </tools><|eot_id|><|start_header_id|>user<|end_header_id|>
|
|
||||||
|
|
||||||
What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
|
||||||
|
|
||||||
<tool_call>
|
|
||||||
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
|
|
||||||
</tool_call><|eot_id|><|start_header_id|>tool<|end_header_id|>
|
|
||||||
|
|
||||||
<tool_response>
|
|
||||||
22
|
|
||||||
</tool_response><|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
|
||||||
|
|
||||||
The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>
|
|
||||||
|
|
||||||
What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
|
||||||
|
|
||||||
4
server/testdata/tools/mistral.gotmpl
vendored
4
server/testdata/tools/mistral.gotmpl
vendored
@@ -1,13 +1,13 @@
|
|||||||
{{- range $index, $_ := .Messages }}
|
{{- range $index, $_ := .Messages }}
|
||||||
{{- if eq .Role "user" }}
|
{{- if eq .Role "user" }}
|
||||||
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]
|
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
|
||||||
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
|
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
|
||||||
|
|
||||||
{{ end }}{{ .Content }}[/INST]
|
{{ end }}{{ .Content }}[/INST]
|
||||||
{{- else if eq .Role "assistant" }}
|
{{- else if eq .Role "assistant" }}
|
||||||
{{- if .Content }} {{ .Content }}</s>
|
{{- if .Content }} {{ .Content }}</s>
|
||||||
{{- else if .ToolCalls }}[TOOL_CALLS] [
|
{{- else if .ToolCalls }}[TOOL_CALLS] [
|
||||||
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
|
||||||
{{- end }}]</s>
|
{{- end }}]</s>
|
||||||
{{- end }}
|
{{- end }}
|
||||||
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
|
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
|
||||||
|
|||||||
45
server/testdata/tools/xlam.gotmpl
vendored
45
server/testdata/tools/xlam.gotmpl
vendored
@@ -1,45 +0,0 @@
|
|||||||
{{- if .System }}{{ .System }}
|
|
||||||
{{ end }}
|
|
||||||
{{- range $i, $_ := .Messages }}
|
|
||||||
{{- if eq .Role "user" }}### Instruction:
|
|
||||||
{{- if and $.Tools (le (len (slice $.Messages $i)) 2) }}
|
|
||||||
[BEGIN OF TASK INSTRUCTION]
|
|
||||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
|
||||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
|
||||||
If none of the functions can be used, point it out and refuse to answer.
|
|
||||||
If the given question lacks the parameters required by the function, also point it out.
|
|
||||||
[END OF TASK INSTRUCTION]
|
|
||||||
|
|
||||||
[BEGIN OF AVAILABLE TOOLS]
|
|
||||||
{{ $.Tools }}
|
|
||||||
[END OF AVAILABLE TOOLS]
|
|
||||||
|
|
||||||
[BEGIN OF FORMAT INSTRUCTION]
|
|
||||||
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
|
|
||||||
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
|
|
||||||
```
|
|
||||||
{
|
|
||||||
"tool_calls": [
|
|
||||||
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
|
|
||||||
... (more tool calls as required)
|
|
||||||
]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
[END OF FORMAT INSTRUCTION]
|
|
||||||
|
|
||||||
[BEGIN OF QUERY]
|
|
||||||
{{ .Content }}
|
|
||||||
[END OF QUERY]
|
|
||||||
|
|
||||||
|
|
||||||
{{ else }}
|
|
||||||
{{ .Content }}
|
|
||||||
{{ end }}
|
|
||||||
{{- else if .ToolCalls }}### Response:
|
|
||||||
{"tool_calls": [{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{ end }}]}
|
|
||||||
<|EOT|>
|
|
||||||
{{ else if eq .Role "assistant" }}### Response:
|
|
||||||
{{ .Content }}
|
|
||||||
<|EOT|>
|
|
||||||
{{ end }}
|
|
||||||
{{- end }}### Response:
|
|
||||||
40
server/testdata/tools/xlam.out
vendored
40
server/testdata/tools/xlam.out
vendored
@@ -1,40 +0,0 @@
|
|||||||
You are a knowledgable assistant. You can answer questions and perform tasks.
|
|
||||||
### Instruction:
|
|
||||||
What's the weather like today in Paris?
|
|
||||||
### Response:
|
|
||||||
{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]}
|
|
||||||
<|EOT|>
|
|
||||||
### Response:
|
|
||||||
The current temperature in Paris, France is 22 degrees Celsius.
|
|
||||||
<|EOT|>
|
|
||||||
### Instruction:
|
|
||||||
[BEGIN OF TASK INSTRUCTION]
|
|
||||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
|
||||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
|
||||||
If none of the functions can be used, point it out and refuse to answer.
|
|
||||||
If the given question lacks the parameters required by the function, also point it out.
|
|
||||||
[END OF TASK INSTRUCTION]
|
|
||||||
|
|
||||||
[BEGIN OF AVAILABLE TOOLS]
|
|
||||||
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]
|
|
||||||
[END OF AVAILABLE TOOLS]
|
|
||||||
|
|
||||||
[BEGIN OF FORMAT INSTRUCTION]
|
|
||||||
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
|
|
||||||
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
|
|
||||||
```
|
|
||||||
{
|
|
||||||
"tool_calls": [
|
|
||||||
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
|
|
||||||
... (more tool calls as required)
|
|
||||||
]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
[END OF FORMAT INSTRUCTION]
|
|
||||||
|
|
||||||
[BEGIN OF QUERY]
|
|
||||||
What's the weather like today in San Francisco and Toronto?
|
|
||||||
[END OF QUERY]
|
|
||||||
|
|
||||||
|
|
||||||
### Response:
|
|
||||||
@@ -150,7 +150,7 @@ func (t *Template) Vars() []string {
|
|||||||
|
|
||||||
type Values struct {
|
type Values struct {
|
||||||
Messages []api.Message
|
Messages []api.Message
|
||||||
api.Tools
|
Tools []api.Tool
|
||||||
Prompt string
|
Prompt string
|
||||||
Suffix string
|
Suffix string
|
||||||
|
|
||||||
@@ -217,7 +217,6 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
|||||||
"System": system,
|
"System": system,
|
||||||
"Messages": messages,
|
"Messages": messages,
|
||||||
"Tools": v.Tools,
|
"Tools": v.Tools,
|
||||||
"Response": "",
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -264,7 +263,6 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
|||||||
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
|
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
|
||||||
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
|
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
|
||||||
cut = true
|
cut = true
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return cut
|
return cut
|
||||||
@@ -274,7 +272,6 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
|||||||
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
|
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
|
||||||
"System": system,
|
"System": system,
|
||||||
"Prompt": prompt,
|
"Prompt": prompt,
|
||||||
"Response": response,
|
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -260,26 +260,6 @@ func TestExecuteWithMessages(t *testing.T) {
|
|||||||
|
|
||||||
Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
|
Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"mistral assistant",
|
|
||||||
[]template{
|
|
||||||
{"no response", `[INST] {{ .Prompt }}[/INST] `},
|
|
||||||
{"response", `[INST] {{ .Prompt }}[/INST] {{ .Response }}`},
|
|
||||||
{"messages", `
|
|
||||||
{{- range $i, $m := .Messages }}
|
|
||||||
{{- if eq .Role "user" }}[INST] {{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}{{ end }}
|
|
||||||
{{- end }}`},
|
|
||||||
},
|
|
||||||
Values{
|
|
||||||
Messages: []api.Message{
|
|
||||||
{Role: "user", Content: "Hello friend!"},
|
|
||||||
{Role: "assistant", Content: "Hello human!"},
|
|
||||||
{Role: "user", Content: "What is your name?"},
|
|
||||||
{Role: "assistant", Content: "My name is Ollama and I"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] My name is Ollama and I`,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"chatml",
|
"chatml",
|
||||||
[]template{
|
[]template{
|
||||||
|
|||||||
Reference in New Issue
Block a user