Compare commits
53 Commits
v0.2.0
...
jyan/progr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b6c7d01af3 | ||
|
|
9d517cf556 | ||
|
|
6bab0e2368 | ||
|
|
c4cccaf936 | ||
|
|
9fe5c393e4 | ||
|
|
007c988dba | ||
|
|
91d21e7c7b | ||
|
|
3e64284f69 | ||
|
|
39910f2ab2 | ||
|
|
96d0cd92f2 | ||
|
|
3a724a7c80 | ||
|
|
f520f0056e | ||
|
|
d25f85ede4 | ||
|
|
b48420b74b | ||
|
|
784958a1cb | ||
|
|
ae65cc8dea | ||
|
|
a037528bba | ||
|
|
04bf41deb5 | ||
|
|
c23cec9547 | ||
|
|
8377dc48d0 | ||
|
|
3aee405dfa | ||
|
|
9b3f47b674 | ||
|
|
f5441f01a2 | ||
|
|
ab165df43a | ||
|
|
79cc4c9585 | ||
|
|
bc3f59a6ad | ||
|
|
1a85cb904c | ||
|
|
10ea0987e9 | ||
|
|
413d368a6a | ||
|
|
cabf375059 | ||
|
|
ca0ee1d4fe | ||
|
|
1142999aab | ||
|
|
0d5a72aba9 | ||
|
|
ea837412c2 | ||
|
|
736ad6f438 | ||
|
|
64607d16a5 | ||
|
|
a6cfe7f00b | ||
|
|
c3b411a515 | ||
|
|
928f37e3ae | ||
|
|
2d1e3c3229 | ||
|
|
4918fae535 | ||
|
|
0aff67877e | ||
|
|
9544a57ee4 | ||
|
|
b51e3b63ac | ||
|
|
6bbbc50f10 | ||
|
|
9bbddc37a7 | ||
|
|
e4ff73297d | ||
|
|
0bacb30007 | ||
|
|
fb6cbc02fb | ||
|
|
326363b3a7 | ||
|
|
ac7a842e55 | ||
|
|
2c3fe1fd97 | ||
|
|
269ed6e6a2 |
5
.github/workflows/release.yaml
vendored
5
.github/workflows/release.yaml
vendored
@@ -304,11 +304,6 @@ jobs:
|
||||
write-host "Installing plugin"
|
||||
& "${env:RUNNER_TEMP}\plugin\*\kmscng.msi" /quiet
|
||||
write-host "plugin installed"
|
||||
- name: remove unwanted mingw dll.a files
|
||||
run: |
|
||||
Get-ChildItem -Path "C:\mingw64" -Recurse -Filter "libpthread.dll.a" -File | Remove-Item -Force
|
||||
Get-ChildItem -Path "C:\mingw64" -Recurse -Filter "libwinpthread.dll.a" -File | Remove-Item -Force
|
||||
Get-ChildItem -Path "C:\mingw64" -Recurse -Filter "libstdc++.dll.a" -File | Remove-Item -Force
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
@@ -17,14 +17,20 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/version"
|
||||
@@ -374,3 +380,27 @@ func (c *Client) Version(ctx context.Context) (string, error) {
|
||||
|
||||
return version.Version, nil
|
||||
}
|
||||
|
||||
func Authorization(ctx context.Context, request *http.Request) (string, error) {
|
||||
data := []byte(fmt.Sprintf("%s,%s,%d", request.Method, request.URL.RequestURI(), time.Now().Unix()))
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
knownHostsFile, err := os.OpenFile(filepath.Join(home, ".ollama", "known_hosts"), os.O_CREATE|os.O_RDWR|os.O_APPEND, 0600)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer knownHostsFile.Close()
|
||||
|
||||
token, err := auth.Sign(ctx, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// interleave request data into the token
|
||||
key, sig, _ := strings.Cut(token, ":")
|
||||
return fmt.Sprintf("%s:%s:%s", key, base64.StdEncoding.EncodeToString(data), sig), nil
|
||||
}
|
||||
|
||||
54
auth/auth.go
54
auth/auth.go
@@ -10,42 +10,37 @@ import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const defaultPrivateKey = "id_ed25519"
|
||||
|
||||
func keyPath() (string, error) {
|
||||
func keyPath() (ssh.Signer, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
||||
}
|
||||
|
||||
func GetPublicKey() (string, error) {
|
||||
keyPath, err := keyPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||
privateKeyFile, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
||||
return ssh.ParsePrivateKey(privateKeyFile)
|
||||
}
|
||||
|
||||
func GetPublicKey() (ssh.PublicKey, error) {
|
||||
privateKey, err := keyPath()
|
||||
// if privateKey, try public key directly
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
||||
|
||||
return strings.TrimSpace(string(publicKey)), nil
|
||||
return privateKey.PublicKey(), nil
|
||||
}
|
||||
|
||||
func NewNonce(r io.Reader, length int) (string, error) {
|
||||
@@ -58,25 +53,20 @@ func NewNonce(r io.Reader, length int) (string, error) {
|
||||
}
|
||||
|
||||
func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||
keyPath, err := keyPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
privateKeyFile, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||
return "", err
|
||||
}
|
||||
|
||||
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
||||
privateKey, err := keyPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// get the pubkey, but remove the type
|
||||
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
||||
parts := bytes.Split(publicKey, []byte(" "))
|
||||
publicKey, err := GetPublicKey()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
|
||||
|
||||
parts := bytes.Split(publicKeyBytes, []byte(" "))
|
||||
if len(parts) < 2 {
|
||||
return "", fmt.Errorf("malformed public key")
|
||||
}
|
||||
|
||||
163
cmd/cmd.go
163
cmd/cmd.go
@@ -7,6 +7,7 @@ import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
@@ -78,6 +80,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
status := "transferring model data"
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
defer p.Stop()
|
||||
|
||||
for i := range modelfile.Commands {
|
||||
switch modelfile.Commands[i].Name {
|
||||
@@ -112,11 +115,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
path = tempfile
|
||||
}
|
||||
|
||||
digest, err := createBlob(cmd, client, path)
|
||||
digest, err := createBlob(cmd, client, path, spinner)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
modelfile.Commands[i].Args = "@" + digest
|
||||
}
|
||||
}
|
||||
@@ -138,7 +140,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
spinner.Stop()
|
||||
|
||||
status = resp.Status
|
||||
spinner = progress.NewSpinner(status)
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
}
|
||||
|
||||
@@ -263,13 +265,22 @@ func tempZipFiles(path string) (string, error) {
|
||||
return tempfile.Name(), nil
|
||||
}
|
||||
|
||||
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
|
||||
var ErrBlobExists = errors.New("blob exists")
|
||||
|
||||
func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) {
|
||||
bin, err := os.Open(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer bin.Close()
|
||||
|
||||
// Get file info to retrieve the size
|
||||
fileInfo, err := bin.Stat()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
fileSize := fileInfo.Size()
|
||||
|
||||
hash := sha256.New()
|
||||
if _, err := io.Copy(hash, bin); err != nil {
|
||||
return "", err
|
||||
@@ -279,13 +290,151 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
|
||||
return "", err
|
||||
}
|
||||
|
||||
var pw progressWriter
|
||||
status := "transferring model data 0%"
|
||||
spinner.SetMessage(status)
|
||||
|
||||
ticker := time.NewTicker(60 * time.Millisecond)
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n/fileSize)))
|
||||
case <-done:
|
||||
spinner.SetMessage("transferring model data 100%")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
|
||||
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
|
||||
|
||||
// We check if we can find the models directory locally
|
||||
// If we can, we return the path to the directory
|
||||
// If we can't, we return an error
|
||||
// If the blob exists already, we return the digest
|
||||
dest, err := getLocalPath(cmd.Context(), digest)
|
||||
|
||||
if errors.Is(err, ErrBlobExists) {
|
||||
return digest, nil
|
||||
}
|
||||
|
||||
// Successfully found the model directory
|
||||
if err == nil {
|
||||
// Copy blob in via OS specific copy
|
||||
// Linux errors out to use io.copy
|
||||
err = localCopy(path, dest)
|
||||
if err == nil {
|
||||
return digest, nil
|
||||
}
|
||||
|
||||
// Default copy using io.copy
|
||||
err = defaultCopy(path, dest)
|
||||
if err == nil {
|
||||
return digest, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If at any point copying the blob over locally fails, we default to the copy through the server
|
||||
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return digest, nil
|
||||
}
|
||||
|
||||
type progressWriter struct {
|
||||
n int64
|
||||
}
|
||||
|
||||
func (w *progressWriter) Write(p []byte) (n int, err error) {
|
||||
w.n += int64(len(p))
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func getLocalPath(ctx context.Context, digest string) (string, error) {
|
||||
ollamaHost := envconfig.Host
|
||||
|
||||
client := http.DefaultClient
|
||||
base := &url.URL{
|
||||
Scheme: ollamaHost.Scheme,
|
||||
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(digest)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
reqBody := bytes.NewReader(data)
|
||||
path := fmt.Sprintf("/api/blobs/%s", digest)
|
||||
requestURL := base.JoinPath(path)
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), reqBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
authz, err := api.Authorization(ctx, request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
request.Header.Set("Authorization", authz)
|
||||
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||
request.Header.Set("X-Redirect-Create", "1")
|
||||
|
||||
resp, err := client.Do(request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusTemporaryRedirect {
|
||||
dest := resp.Header.Get("LocalLocation")
|
||||
|
||||
return dest, nil
|
||||
}
|
||||
return "", ErrBlobExists
|
||||
}
|
||||
|
||||
func defaultCopy(path string, dest string) error {
|
||||
// This function should be called if the server is local
|
||||
// It should find the model directory, copy the blob over, and return the digest
|
||||
dirPath := filepath.Dir(dest)
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Copy blob over
|
||||
sourceFile, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not open source file: %v", err)
|
||||
}
|
||||
defer sourceFile.Close()
|
||||
|
||||
destFile, err := os.Create(dest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create destination file: %v", err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
_, err = io.CopyBuffer(destFile, sourceFile, make([]byte, 4*1024*1024))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error copying file: %v", err)
|
||||
}
|
||||
|
||||
err = destFile.Sync()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error flushing file: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
interactive := true
|
||||
|
||||
@@ -379,11 +528,13 @@ func errFromUnknownKey(unknownKeyErr error) error {
|
||||
if len(matches) > 0 {
|
||||
serverPubKey := matches[0]
|
||||
|
||||
localPubKey, err := auth.GetPublicKey()
|
||||
publicKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(publicKey)))
|
||||
|
||||
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
||||
// try the ollama service public key
|
||||
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
||||
|
||||
23
cmd/copy_darwin.go
Normal file
23
cmd/copy_darwin.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func localCopy(src, target string) error {
|
||||
dirPath := filepath.Dir(target)
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := unix.Clonefile(src, target, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
7
cmd/copy_linux.go
Normal file
7
cmd/copy_linux.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package cmd
|
||||
|
||||
import "errors"
|
||||
|
||||
func localCopy(src, target string) error {
|
||||
return errors.New("no local copy implementation for linux")
|
||||
}
|
||||
67
cmd/copy_windows.go
Normal file
67
cmd/copy_windows.go
Normal file
@@ -0,0 +1,67 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func localCopy(src, target string) error {
|
||||
// Create target directory if it doesn't exist
|
||||
dirPath := filepath.Dir(target)
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Open source file
|
||||
sourceFile, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sourceFile.Close()
|
||||
|
||||
// Create target file
|
||||
targetFile, err := os.Create(target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer targetFile.Close()
|
||||
|
||||
// Use CopyFileExW to copy the file
|
||||
err = copyFileEx(src, target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyFileEx(src, dst string) error {
|
||||
kernel32 := syscall.NewLazyDLL("kernel32.dll")
|
||||
copyFileEx := kernel32.NewProc("CopyFileExW")
|
||||
|
||||
srcPtr, err := syscall.UTF16PtrFromString(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dstPtr, err := syscall.UTF16PtrFromString(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r1, _, err := copyFileEx.Call(
|
||||
uintptr(unsafe.Pointer(srcPtr)),
|
||||
uintptr(unsafe.Pointer(dstPtr)),
|
||||
0, 0, 0, 0)
|
||||
|
||||
if r1 == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
3
go.mod
3
go.mod
@@ -18,6 +18,7 @@ require (
|
||||
require (
|
||||
github.com/agnivade/levenshtein v1.1.1
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/google/go-cmp v0.6.0
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
@@ -71,7 +72,7 @@ require (
|
||||
golang.org/x/net v0.25.0 // indirect
|
||||
golang.org/x/sys v0.20.0
|
||||
golang.org/x/term v0.20.0
|
||||
golang.org/x/text v0.15.0 // indirect
|
||||
golang.org/x/text v0.15.0
|
||||
google.golang.org/protobuf v1.34.1
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -254,7 +254,7 @@ if [ -z "${OLLAMA_SKIP_ROCM_GENERATE}" -a -d "${ROCM_PATH}" ]; then
|
||||
ROCM_VARIANT=_v$(ls ${ROCM_PATH}/lib/librocblas.so.*.*.????? | cut -f5 -d. || true)
|
||||
fi
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
|
||||
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DLLAMA_CUDA_NO_PEER_COPY=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
|
||||
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
|
||||
if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then
|
||||
echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\""
|
||||
|
||||
@@ -366,6 +366,7 @@ function build_rocm() {
|
||||
"-DCMAKE_C_COMPILER=clang.exe",
|
||||
"-DCMAKE_CXX_COMPILER=clang++.exe",
|
||||
"-DGGML_HIPBLAS=on",
|
||||
"-DLLAMA_CUDA_NO_PEER_COPY=on",
|
||||
"-DHIP_PLATFORM=amd",
|
||||
"-DGGML_AVX=on",
|
||||
"-DGGML_AVX2=off",
|
||||
|
||||
@@ -4,8 +4,8 @@ package llm
|
||||
// #cgo LDFLAGS: -lllama -lggml -lstdc++ -lpthread
|
||||
// #cgo darwin,arm64 LDFLAGS: -L${SRCDIR}/build/darwin/arm64_static -L${SRCDIR}/build/darwin/arm64_static/src -L${SRCDIR}/build/darwin/arm64_static/ggml/src -framework Accelerate -framework Metal
|
||||
// #cgo darwin,amd64 LDFLAGS: -L${SRCDIR}/build/darwin/x86_64_static -L${SRCDIR}/build/darwin/x86_64_static/src -L${SRCDIR}/build/darwin/x86_64_static/ggml/src
|
||||
// #cgo windows,amd64 LDFLAGS: -L${SRCDIR}/build/windows/amd64_static -L${SRCDIR}/build/windows/amd64_static/src -L${SRCDIR}/build/windows/amd64_static/ggml/src
|
||||
// #cgo windows,arm64 LDFLAGS: -L${SRCDIR}/build/windows/arm64_static -L${SRCDIR}/build/windows/arm64_static/src -L${SRCDIR}/build/windows/arm64_static/ggml/src
|
||||
// #cgo windows,amd64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/amd64_static -L${SRCDIR}/build/windows/amd64_static/src -L${SRCDIR}/build/windows/amd64_static/ggml/src
|
||||
// #cgo windows,arm64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/arm64_static -L${SRCDIR}/build/windows/arm64_static/src -L${SRCDIR}/build/windows/arm64_static/ggml/src
|
||||
// #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src
|
||||
// #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
|
||||
// #include <stdlib.h>
|
||||
|
||||
@@ -679,7 +679,7 @@ type CompletionRequest struct {
|
||||
Prompt string
|
||||
Format string
|
||||
Images []ImageData
|
||||
Options api.Options
|
||||
Options *api.Options
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
|
||||
@@ -338,12 +338,16 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||
switch stop := r.Stop.(type) {
|
||||
case string:
|
||||
options["stop"] = []string{stop}
|
||||
case []string:
|
||||
options["stop"] = stop
|
||||
default:
|
||||
if r.Stop != nil {
|
||||
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", r.Stop)
|
||||
case []any:
|
||||
var stops []string
|
||||
for _, s := range stop {
|
||||
if str, ok := s.(string); ok {
|
||||
stops = append(stops, str)
|
||||
} else {
|
||||
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
|
||||
}
|
||||
}
|
||||
options["stop"] = stops
|
||||
}
|
||||
|
||||
if r.MaxTokens != nil {
|
||||
|
||||
@@ -3,7 +3,6 @@ package openai
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -16,7 +15,133 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMiddleware(t *testing.T) {
|
||||
func TestMiddlewareRequests(t *testing.T) {
|
||||
type testCase struct {
|
||||
Name string
|
||||
Method string
|
||||
Path string
|
||||
Handler func() gin.HandlerFunc
|
||||
Setup func(t *testing.T, req *http.Request)
|
||||
Expected func(t *testing.T, req *http.Request)
|
||||
}
|
||||
|
||||
var capturedRequest *http.Request
|
||||
|
||||
captureRequestMiddleware := func() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
capturedRequest = c.Request
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
Name: "chat handler",
|
||||
Method: http.MethodPost,
|
||||
Path: "/api/chat",
|
||||
Handler: ChatMiddleware,
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
body := ChatCompletionRequest{
|
||||
Model: "test-model",
|
||||
Messages: []Message{{Role: "user", Content: "Hello"}},
|
||||
}
|
||||
|
||||
bodyBytes, _ := json.Marshal(body)
|
||||
|
||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
},
|
||||
Expected: func(t *testing.T, req *http.Request) {
|
||||
var chatReq api.ChatRequest
|
||||
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if chatReq.Messages[0].Role != "user" {
|
||||
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
|
||||
}
|
||||
|
||||
if chatReq.Messages[0].Content != "Hello" {
|
||||
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "completions handler",
|
||||
Method: http.MethodPost,
|
||||
Path: "/api/generate",
|
||||
Handler: CompletionsMiddleware,
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
temp := float32(0.8)
|
||||
body := CompletionRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello",
|
||||
Temperature: &temp,
|
||||
Stop: []string{"\n", "stop"},
|
||||
}
|
||||
|
||||
bodyBytes, _ := json.Marshal(body)
|
||||
|
||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
},
|
||||
Expected: func(t *testing.T, req *http.Request) {
|
||||
var genReq api.GenerateRequest
|
||||
if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if genReq.Prompt != "Hello" {
|
||||
t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
|
||||
}
|
||||
|
||||
if genReq.Options["temperature"] != 1.6 {
|
||||
t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
|
||||
}
|
||||
|
||||
stopTokens, ok := genReq.Options["stop"].([]any)
|
||||
|
||||
if !ok {
|
||||
t.Fatalf("expected stop tokens to be a list")
|
||||
}
|
||||
|
||||
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
|
||||
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
router = gin.New()
|
||||
router.Use(captureRequestMiddleware())
|
||||
router.Use(tc.Handler())
|
||||
router.Handle(tc.Method, tc.Path, endpoint)
|
||||
req, _ := http.NewRequest(tc.Method, tc.Path, nil)
|
||||
|
||||
if tc.Setup != nil {
|
||||
tc.Setup(t, req)
|
||||
}
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
tc.Expected(t, capturedRequest)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareResponses(t *testing.T) {
|
||||
type testCase struct {
|
||||
Name string
|
||||
Method string
|
||||
@@ -30,159 +155,7 @@ func TestMiddleware(t *testing.T) {
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
Name: "chat handler",
|
||||
Method: http.MethodPost,
|
||||
Path: "/api/chat",
|
||||
TestPath: "/api/chat",
|
||||
Handler: ChatMiddleware,
|
||||
Endpoint: func(c *gin.Context) {
|
||||
var chatReq api.ChatRequest
|
||||
if err := c.ShouldBindJSON(&chatReq); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
userMessage := chatReq.Messages[0].Content
|
||||
var assistantMessage string
|
||||
|
||||
switch userMessage {
|
||||
case "Hello":
|
||||
assistantMessage = "Hello!"
|
||||
default:
|
||||
assistantMessage = "I'm not sure how to respond to that."
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, api.ChatResponse{
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: assistantMessage,
|
||||
},
|
||||
})
|
||||
},
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
body := ChatCompletionRequest{
|
||||
Model: "test-model",
|
||||
Messages: []Message{{Role: "user", Content: "Hello"}},
|
||||
}
|
||||
|
||||
bodyBytes, _ := json.Marshal(body)
|
||||
|
||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
},
|
||||
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
|
||||
var chatResp ChatCompletion
|
||||
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if chatResp.Object != "chat.completion" {
|
||||
t.Fatalf("expected chat.completion, got %s", chatResp.Object)
|
||||
}
|
||||
|
||||
if chatResp.Choices[0].Message.Content != "Hello!" {
|
||||
t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "completions handler",
|
||||
Method: http.MethodPost,
|
||||
Path: "/api/generate",
|
||||
TestPath: "/api/generate",
|
||||
Handler: CompletionsMiddleware,
|
||||
Endpoint: func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||
Response: "Hello!",
|
||||
})
|
||||
},
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
body := CompletionRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello",
|
||||
}
|
||||
|
||||
bodyBytes, _ := json.Marshal(body)
|
||||
|
||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
},
|
||||
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
var completionResp Completion
|
||||
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if completionResp.Object != "text_completion" {
|
||||
t.Fatalf("expected text_completion, got %s", completionResp.Object)
|
||||
}
|
||||
|
||||
if completionResp.Choices[0].Text != "Hello!" {
|
||||
t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "completions handler with params",
|
||||
Method: http.MethodPost,
|
||||
Path: "/api/generate",
|
||||
TestPath: "/api/generate",
|
||||
Handler: CompletionsMiddleware,
|
||||
Endpoint: func(c *gin.Context) {
|
||||
var generateReq api.GenerateRequest
|
||||
if err := c.ShouldBindJSON(&generateReq); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
temperature := generateReq.Options["temperature"].(float64)
|
||||
var assistantMessage string
|
||||
|
||||
switch temperature {
|
||||
case 1.6:
|
||||
assistantMessage = "Received temperature of 1.6"
|
||||
default:
|
||||
assistantMessage = fmt.Sprintf("Received temperature of %f", temperature)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||
Response: assistantMessage,
|
||||
})
|
||||
},
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
temp := float32(0.8)
|
||||
body := CompletionRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello",
|
||||
Temperature: &temp,
|
||||
}
|
||||
|
||||
bodyBytes, _ := json.Marshal(body)
|
||||
|
||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
},
|
||||
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
var completionResp Completion
|
||||
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if completionResp.Object != "text_completion" {
|
||||
t.Fatalf("expected text_completion, got %s", completionResp.Object)
|
||||
}
|
||||
|
||||
if completionResp.Choices[0].Text != "Received temperature of 1.6" {
|
||||
t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "completions handler with error",
|
||||
Name: "completions handler error forwarding",
|
||||
Method: http.MethodPost,
|
||||
Path: "/api/generate",
|
||||
TestPath: "/api/generate",
|
||||
|
||||
@@ -31,6 +31,10 @@ func NewSpinner(message string) *Spinner {
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Spinner) SetMessage(message string) {
|
||||
s.message = message
|
||||
}
|
||||
|
||||
func (s *Spinner) String() string {
|
||||
var sb strings.Builder
|
||||
if len(s.message) > 0 {
|
||||
|
||||
@@ -32,8 +32,11 @@ import (
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var errCapabilityCompletion = errors.New("completion")
|
||||
|
||||
type Capability string
|
||||
|
||||
const CapabilityCompletion = Capability("completion")
|
||||
@@ -62,7 +65,10 @@ type Model struct {
|
||||
Template *template.Template
|
||||
}
|
||||
|
||||
func (m *Model) Has(caps ...Capability) bool {
|
||||
// CheckCapabilities checks if the model has the specified capabilities returning an error describing
|
||||
// any missing or unknown capabilities
|
||||
func (m *Model) CheckCapabilities(caps ...Capability) error {
|
||||
var errs []error
|
||||
for _, cap := range caps {
|
||||
switch cap {
|
||||
case CapabilityCompletion:
|
||||
@@ -81,15 +87,19 @@ func (m *Model) Has(caps ...Capability) bool {
|
||||
}
|
||||
|
||||
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
|
||||
return false
|
||||
errs = append(errs, errCapabilityCompletion)
|
||||
}
|
||||
default:
|
||||
slog.Error("unknown capability", "capability", cap)
|
||||
return false
|
||||
return fmt.Errorf("unknown capability: %s", cap)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
if err := errors.Join(errs...); err != nil {
|
||||
return fmt.Errorf("missing capabilities: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) String() string {
|
||||
@@ -1055,11 +1065,12 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
if anonymous {
|
||||
// no user is associated with the public key, and the request requires non-anonymous access
|
||||
pubKey, nestedErr := auth.GetPublicKey()
|
||||
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubKey)))
|
||||
if nestedErr != nil {
|
||||
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
||||
return nil, errUnauthorized
|
||||
}
|
||||
return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
|
||||
return nil, &errtypes.UnknownOllamaKey{Key: localPubKey}
|
||||
}
|
||||
// user is associated with the public key, but is not authorized to make the request
|
||||
return nil, errUnauthorized
|
||||
|
||||
250
server/prompt.go
250
server/prompt.go
@@ -1,217 +1,83 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"bytes"
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"text/template/parse"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
// isResponseNode checks if the node contains .Response
|
||||
func isResponseNode(node *parse.ActionNode) bool {
|
||||
for _, cmd := range node.Pipe.Cmds {
|
||||
for _, arg := range cmd.Args {
|
||||
if fieldNode, ok := arg.(*parse.FieldNode); ok && len(fieldNode.Ident) > 0 {
|
||||
if fieldNode.Ident[0] == "Response" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
type tokenizeFunc func(context.Context, string) ([]int, error)
|
||||
|
||||
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
|
||||
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
|
||||
// latest message and 2) system messages
|
||||
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
|
||||
// pull out any system messages which should always be included in the prompt
|
||||
var system []api.Message
|
||||
msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
|
||||
if m.Role == "system" {
|
||||
system = append(system, m)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// formatTemplateForResponse formats the template AST to:
|
||||
// 1. remove all nodes after the first .Response (if generate=true)
|
||||
// 2. add a .Response node to the end if it doesn't exist
|
||||
// TODO(jmorganca): this should recursively cut the template before the first .Response
|
||||
func formatTemplateForResponse(tmpl *template.Template, generate bool) {
|
||||
var found bool
|
||||
for i, node := range tmpl.Tree.Root.Nodes {
|
||||
if actionNode, ok := node.(*parse.ActionNode); ok {
|
||||
if isResponseNode(actionNode) {
|
||||
found = true
|
||||
if generate {
|
||||
tmpl.Tree.Root.Nodes = tmpl.Tree.Root.Nodes[:i+1]
|
||||
break
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
|
||||
if len(system) == 0 && m.System != "" {
|
||||
// add model system prompt since it wasn't provided
|
||||
system = append(system, api.Message{Role: "system", Content: m.System})
|
||||
}
|
||||
|
||||
// always include the last message
|
||||
n := len(msgs) - 1
|
||||
// in reverse, find all messages that fit into context window
|
||||
for i := n - 1; i >= 0; i-- {
|
||||
var b bytes.Buffer
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
// add the response node if it doesn't exist
|
||||
responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}}
|
||||
responsePipeNode := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}}}}
|
||||
responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode}
|
||||
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode)
|
||||
}
|
||||
}
|
||||
|
||||
// Prompt renders a prompt from a template. If generate is set to true,
|
||||
// the response and parts of the template following it are not rendered
|
||||
func Prompt(tmpl *template.Template, system, prompt, response string, generate bool) (string, error) {
|
||||
formatTemplateForResponse(tmpl, generate)
|
||||
|
||||
vars := map[string]any{
|
||||
"System": system,
|
||||
"Prompt": prompt,
|
||||
"Response": response,
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
if err := tmpl.Execute(&sb, vars); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
func countTokens(tmpl *template.Template, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
|
||||
rendered, err := Prompt(tmpl, system, prompt, response, false)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
tokens, err := encode(rendered)
|
||||
if err != nil {
|
||||
slog.Error("failed to encode prompt", "err", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(tokens), err
|
||||
}
|
||||
|
||||
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
|
||||
func ChatPrompt(tmpl *template.Template, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
|
||||
type prompt struct {
|
||||
System string
|
||||
Prompt string
|
||||
Response string
|
||||
|
||||
images []int
|
||||
tokens int
|
||||
}
|
||||
|
||||
var p prompt
|
||||
|
||||
// iterate through messages to build up {system,user,response} prompts
|
||||
var imgId int
|
||||
var prompts []prompt
|
||||
for _, msg := range messages {
|
||||
switch strings.ToLower(msg.Role) {
|
||||
case "system":
|
||||
if p.System != "" || p.Prompt != "" || p.Response != "" {
|
||||
prompts = append(prompts, p)
|
||||
p = prompt{}
|
||||
}
|
||||
|
||||
p.System = msg.Content
|
||||
case "user":
|
||||
if p.Prompt != "" || p.Response != "" {
|
||||
prompts = append(prompts, p)
|
||||
p = prompt{}
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
for range msg.Images {
|
||||
fmt.Fprintf(&sb, "[img-%d] ", imgId)
|
||||
p.images = append(p.images, imgId)
|
||||
imgId += 1
|
||||
}
|
||||
|
||||
sb.WriteString(msg.Content)
|
||||
p.Prompt = sb.String()
|
||||
case "assistant":
|
||||
if p.Response != "" {
|
||||
prompts = append(prompts, p)
|
||||
p = prompt{}
|
||||
}
|
||||
|
||||
p.Response = msg.Content
|
||||
default:
|
||||
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
||||
}
|
||||
}
|
||||
|
||||
// add final prompt
|
||||
if p.System != "" || p.Prompt != "" || p.Response != "" {
|
||||
prompts = append(prompts, p)
|
||||
}
|
||||
|
||||
// calculate token lengths for each prompt, estimating 768 tokens per images
|
||||
for i, p := range prompts {
|
||||
tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
|
||||
s, err := tokenize(ctx, b.String())
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
prompts[i].tokens = tokens + len(prompts[i].images)*768
|
||||
}
|
||||
|
||||
// truncate images and prompts starting from the beginning of the list
|
||||
// until either one prompt remains or the total tokens fits the context window
|
||||
// TODO (jmorganca): this doesn't account for the context window room required for the response
|
||||
for {
|
||||
var required int
|
||||
for _, p := range prompts {
|
||||
required += p.tokens
|
||||
c := len(s)
|
||||
if m.ProjectorPaths != nil {
|
||||
for _, m := range msgs[i:] {
|
||||
// images are represented as 768 sized embeddings
|
||||
// TODO: get embedding length from project metadata
|
||||
c += 768 * len(m.Images)
|
||||
}
|
||||
}
|
||||
|
||||
required += 1 // for bos token
|
||||
|
||||
if required <= window {
|
||||
slog.Debug("prompt now fits in context window", "required", required, "window", window)
|
||||
if c > opts.NumCtx {
|
||||
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
|
||||
break
|
||||
} else {
|
||||
n = i
|
||||
}
|
||||
|
||||
prompt := &prompts[0]
|
||||
|
||||
if len(prompt.images) > 1 {
|
||||
img := prompt.images[0]
|
||||
slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window)
|
||||
prompt.images = prompt.images[1:]
|
||||
prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1)
|
||||
prompt.tokens -= 768
|
||||
continue
|
||||
}
|
||||
|
||||
if len(prompts) > 1 {
|
||||
slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window)
|
||||
system := prompt.System
|
||||
prompts = prompts[1:]
|
||||
|
||||
if system != "" && prompts[0].System == "" {
|
||||
prompts[0].System = system
|
||||
|
||||
tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
prompts[0].tokens = tokens + len(prompts[0].images)*768
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// stop truncating if there's only one prompt left
|
||||
break
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
for i, p := range prompts {
|
||||
// last prompt should leave the response unrendered (for completion)
|
||||
rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.WriteString(rendered)
|
||||
// truncate any messages that do not fit into the context window
|
||||
var b bytes.Buffer
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
for _, m := range msgs[n:] {
|
||||
for _, i := range m.Images {
|
||||
images = append(images, llm.ImageData{
|
||||
ID: len(images),
|
||||
Data: i,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return b.String(), images, nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -8,208 +10,195 @@ import (
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
func TestPrompt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
template string
|
||||
system string
|
||||
prompt string
|
||||
response string
|
||||
generate bool
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "simple prompt",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
system: "You are a Wizard.",
|
||||
prompt: "What are the potion ingredients?",
|
||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "implicit response",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
system: "You are a Wizard.",
|
||||
prompt: "What are the potion ingredients?",
|
||||
response: "I don't know.",
|
||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
|
||||
},
|
||||
{
|
||||
name: "response",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
|
||||
system: "You are a Wizard.",
|
||||
prompt: "What are the potion ingredients?",
|
||||
response: "I don't know.",
|
||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
|
||||
},
|
||||
{
|
||||
name: "cut",
|
||||
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
|
||||
system: "You are a Wizard.",
|
||||
prompt: "What are the potion ingredients?",
|
||||
response: "I don't know.",
|
||||
generate: true,
|
||||
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.",
|
||||
},
|
||||
{
|
||||
name: "nocut",
|
||||
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
|
||||
system: "You are a Wizard.",
|
||||
prompt: "What are the potion ingredients?",
|
||||
response: "I don't know.",
|
||||
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.</assistant>",
|
||||
},
|
||||
func tokenize(_ context.Context, s string) (tokens []int, err error) {
|
||||
for range strings.Fields(s) {
|
||||
tokens = append(tokens, len(tokens))
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.Parse(tc.template)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got, err := Prompt(tmpl, tc.system, tc.prompt, tc.response, tc.generate)
|
||||
if err != nil {
|
||||
t.Errorf("error = %v", err)
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Errorf("got = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TestChatPrompt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
template string
|
||||
messages []api.Message
|
||||
window int
|
||||
want string
|
||||
type expect struct {
|
||||
prompt string
|
||||
images [][]byte
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
limit int
|
||||
msgs []api.Message
|
||||
expect
|
||||
}{
|
||||
{
|
||||
name: "simple prompt",
|
||||
template: "[INST] {{ .Prompt }} [/INST]",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
name: "messages",
|
||||
limit: 64,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||
},
|
||||
window: 1024,
|
||||
want: "[INST] Hello [/INST]",
|
||||
},
|
||||
{
|
||||
name: "with system message",
|
||||
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are a Wizard."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
name: "truncate messages",
|
||||
limit: 1,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "A test. And a thumping good one at that, I'd wager. ",
|
||||
},
|
||||
window: 1024,
|
||||
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
|
||||
},
|
||||
{
|
||||
name: "with response",
|
||||
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are a Wizard."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: "I am?"},
|
||||
name: "truncate messages with image",
|
||||
limit: 64,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("something"),
|
||||
},
|
||||
},
|
||||
window: 1024,
|
||||
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST] I am?",
|
||||
},
|
||||
{
|
||||
name: "with implicit response",
|
||||
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are a Wizard."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: "I am?"},
|
||||
name: "truncate messages with images",
|
||||
limit: 64,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("somethingelse"),
|
||||
},
|
||||
},
|
||||
window: 1024,
|
||||
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]I am?",
|
||||
},
|
||||
{
|
||||
name: "with conversation",
|
||||
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are a Wizard."},
|
||||
{Role: "user", Content: "What are the potion ingredients?"},
|
||||
{Role: "assistant", Content: "sugar"},
|
||||
{Role: "user", Content: "Anything else?"},
|
||||
name: "messages with images",
|
||||
limit: 2048,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "[img-0] You're a test, Harry! I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("something"),
|
||||
[]byte("somethingelse"),
|
||||
},
|
||||
},
|
||||
window: 1024,
|
||||
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
|
||||
},
|
||||
{
|
||||
name: "with truncation",
|
||||
template: "{{ .System }} {{ .Prompt }} {{ .Response }} ",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are a Wizard."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: "I am?"},
|
||||
{Role: "user", Content: "Why is the sky blue?"},
|
||||
{Role: "assistant", Content: "The sky is blue from rayleigh scattering"},
|
||||
name: "message with image tag",
|
||||
limit: 2048,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "You're a test, Harry! [img-0] I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("something"),
|
||||
[]byte("somethingelse"),
|
||||
},
|
||||
},
|
||||
window: 10,
|
||||
want: "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
|
||||
},
|
||||
{
|
||||
name: "images",
|
||||
template: "{{ .System }} {{ .Prompt }}",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are a Wizard."},
|
||||
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
|
||||
name: "messages with interleaved images",
|
||||
limit: 2048,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "user", Images: []api.ImageData{[]byte("something")}},
|
||||
{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "You're a test, Harry!\n\n[img-0]\n\n[img-1] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("something"),
|
||||
[]byte("somethingelse"),
|
||||
},
|
||||
},
|
||||
window: 1024,
|
||||
want: "You are a Wizard. [img-0] Hello",
|
||||
},
|
||||
{
|
||||
name: "images truncated",
|
||||
template: "{{ .System }} {{ .Prompt }}",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are a Wizard."},
|
||||
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
|
||||
name: "truncate message with interleaved images",
|
||||
limit: 1024,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "user", Images: []api.ImageData{[]byte("something")}},
|
||||
{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("somethingelse"),
|
||||
},
|
||||
},
|
||||
window: 1024,
|
||||
want: "You are a Wizard. [img-0] [img-1] Hello",
|
||||
},
|
||||
{
|
||||
name: "empty list",
|
||||
template: "{{ .System }} {{ .Prompt }}",
|
||||
messages: []api.Message{},
|
||||
window: 1024,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty prompt",
|
||||
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: ""},
|
||||
name: "message with system prompt",
|
||||
limit: 2048,
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are the Test Who Lived."},
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
|
||||
},
|
||||
window: 1024,
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
encode := func(s string) ([]int, error) {
|
||||
words := strings.Fields(s)
|
||||
return make([]int, len(words)), nil
|
||||
tmpl, err := template.Parse(`
|
||||
{{- if .System }}{{ .System }} {{ end }}
|
||||
{{- if .Prompt }}{{ .Prompt }} {{ end }}
|
||||
{{- if .Response }}{{ .Response }} {{ end }}`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.Parse(tc.template)
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||
prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got, err := ChatPrompt(tmpl, tc.messages, tc.window, encode)
|
||||
if err != nil {
|
||||
t.Errorf("error = %v", err)
|
||||
if tt.prompt != prompt {
|
||||
t.Errorf("expected %q, got %q", tt.prompt, prompt)
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Errorf("got: %q, want: %q", got, tc.want)
|
||||
if len(images) != len(tt.images) {
|
||||
t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
|
||||
}
|
||||
|
||||
for i := range images {
|
||||
if images[i].ID != i {
|
||||
t.Errorf("expected ID %d, got %d", i, images[i].ID)
|
||||
}
|
||||
|
||||
if !bytes.Equal(images[i].Data, tt.images[i]) {
|
||||
t.Errorf("expected %q, got %q", tt.images[i], images[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
557
server/routes.go
557
server/routes.go
@@ -1,13 +1,15 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -22,8 +24,10 @@ import (
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/gpu"
|
||||
"github.com/ollama/ollama/llm"
|
||||
@@ -54,6 +58,8 @@ func init() {
|
||||
gin.SetMode(mode)
|
||||
}
|
||||
|
||||
var errRequired = errors.New("is required")
|
||||
|
||||
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
|
||||
opts := api.DefaultOptions()
|
||||
if err := opts.FromMap(model.Options); err != nil {
|
||||
@@ -67,163 +73,140 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func isSupportedImageType(image []byte) bool {
|
||||
contentType := http.DetectContentType(image)
|
||||
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
|
||||
return slices.Contains(allowedTypes, contentType)
|
||||
// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
|
||||
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
|
||||
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
|
||||
if name == "" {
|
||||
return nil, nil, nil, fmt.Errorf("model %w", errRequired)
|
||||
}
|
||||
|
||||
model, err := GetModel(name)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if err := model.CheckCapabilities(caps...); err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
|
||||
}
|
||||
|
||||
opts, err := modelOptions(model, requestOpts)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-runnerCh:
|
||||
case err = <-errCh:
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
return runner.llama, model, &opts, nil
|
||||
}
|
||||
|
||||
func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
checkpointStart := time.Now()
|
||||
var req api.GenerateRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
case err != nil:
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// validate the request
|
||||
switch {
|
||||
case req.Model == "":
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
if req.Format != "" && req.Format != "json" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
|
||||
return
|
||||
case len(req.Format) > 0 && req.Format != "json":
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
|
||||
return
|
||||
case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
|
||||
} else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
|
||||
return
|
||||
}
|
||||
|
||||
for _, img := range req.Images {
|
||||
if !isSupportedImageType(img) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
model, err := GetModel(req.Model)
|
||||
if err != nil {
|
||||
var pErr *fs.PathError
|
||||
if errors.As(err, &pErr) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
caps := []Capability{CapabilityCompletion}
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
||||
if errors.Is(err, errCapabilityCompletion) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
|
||||
return
|
||||
} else if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !model.Has(CapabilityCompletion) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)})
|
||||
return
|
||||
}
|
||||
|
||||
opts, err := modelOptions(model, req.Options)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-rCh:
|
||||
case err = <-eCh:
|
||||
handleErrorResponse(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// an empty request loads the model
|
||||
// note: for a short while template was used in lieu
|
||||
// of `raw` mode so we need to check for it too
|
||||
if req.Prompt == "" && req.Template == "" && req.System == "" {
|
||||
if req.Prompt == "" {
|
||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
tmpl, err := template.Parse(req.Template)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
images := make([]llm.ImageData, len(req.Images))
|
||||
for i := range req.Images {
|
||||
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
|
||||
}
|
||||
|
||||
checkpointLoaded := time.Now()
|
||||
|
||||
var prompt string
|
||||
switch {
|
||||
case req.Raw:
|
||||
prompt = req.Prompt
|
||||
case req.Prompt != "":
|
||||
if req.Template == "" {
|
||||
tmpl = model.Template
|
||||
prompt := req.Prompt
|
||||
if !req.Raw {
|
||||
var msgs []api.Message
|
||||
if req.System != "" {
|
||||
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
|
||||
} else if m.System != "" {
|
||||
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
|
||||
}
|
||||
|
||||
if req.System == "" {
|
||||
req.System = model.System
|
||||
for _, i := range images {
|
||||
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
|
||||
}
|
||||
|
||||
slog.Debug("generate handler", "prompt", req.Prompt)
|
||||
slog.Debug("generate handler", "template", req.Template)
|
||||
slog.Debug("generate handler", "system", req.System)
|
||||
msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
|
||||
|
||||
var sb strings.Builder
|
||||
for i := range req.Images {
|
||||
fmt.Fprintf(&sb, "[img-%d] ", i)
|
||||
tmpl := m.Template
|
||||
if req.Template != "" {
|
||||
tmpl, err = template.Parse(req.Template)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(req.Prompt)
|
||||
|
||||
p, err := Prompt(tmpl, req.System, sb.String(), "", true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
sb.Reset()
|
||||
var b bytes.Buffer
|
||||
if req.Context != nil {
|
||||
prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
|
||||
s, err := r.Detokenize(c.Request.Context(), req.Context)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
sb.WriteString(prev)
|
||||
b.WriteString(s)
|
||||
}
|
||||
|
||||
sb.WriteString(p)
|
||||
if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
prompt = sb.String()
|
||||
prompt = b.String()
|
||||
}
|
||||
|
||||
slog.Debug("generate handler", "prompt", prompt)
|
||||
slog.Debug("generate request", "prompt", prompt, "images", images)
|
||||
|
||||
ch := make(chan any)
|
||||
var generated strings.Builder
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
fn := func(r llm.CompletionResponse) {
|
||||
// Build up the full response
|
||||
if _, err := generated.WriteString(r.Content); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
resp := api.GenerateResponse{
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
}, func(r llm.CompletionResponse) {
|
||||
ch <- api.GenerateResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Done: r.Done,
|
||||
Response: r.Content,
|
||||
Done: r.Done,
|
||||
DoneReason: r.DoneReason,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: r.PromptEvalCount,
|
||||
@@ -232,77 +215,35 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
EvalDuration: r.EvalDuration,
|
||||
},
|
||||
}
|
||||
|
||||
if r.Done {
|
||||
resp.TotalDuration = time.Since(checkpointStart)
|
||||
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
|
||||
if !req.Raw {
|
||||
p, err := Prompt(tmpl, req.System, req.Prompt, generated.String(), false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// TODO (jmorganca): encode() should not strip special tokens
|
||||
tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
resp.Context = append(req.Context, tokens...)
|
||||
}
|
||||
}
|
||||
|
||||
ch <- resp
|
||||
}
|
||||
|
||||
var images []llm.ImageData
|
||||
for i := range req.Images {
|
||||
images = append(images, llm.ImageData{
|
||||
ID: i,
|
||||
Data: req.Images[i],
|
||||
})
|
||||
}
|
||||
|
||||
// Start prediction
|
||||
req := llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Format: req.Format,
|
||||
Images: images,
|
||||
Options: opts,
|
||||
}
|
||||
if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
|
||||
}); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
|
||||
if req.Stream != nil && !*req.Stream {
|
||||
// Accumulate responses into the final response
|
||||
var final api.GenerateResponse
|
||||
var r api.GenerateResponse
|
||||
var sb strings.Builder
|
||||
for resp := range ch {
|
||||
switch r := resp.(type) {
|
||||
for rr := range ch {
|
||||
switch t := rr.(type) {
|
||||
case api.GenerateResponse:
|
||||
sb.WriteString(r.Response)
|
||||
final = r
|
||||
sb.WriteString(t.Response)
|
||||
r = t
|
||||
case gin.H:
|
||||
if errorMsg, ok := r["error"].(string); ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
|
||||
return
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
|
||||
return
|
||||
msg, ok := t["error"].(string)
|
||||
if !ok {
|
||||
msg = "unexpected error format in response"
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
|
||||
return
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
final.Response = sb.String()
|
||||
c.JSON(http.StatusOK, final)
|
||||
r.Response = sb.String()
|
||||
c.JSON(http.StatusOK, r)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -311,44 +252,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
|
||||
func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
var req api.EmbeddingRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
case err != nil:
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
model, err := GetModel(req.Model)
|
||||
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
||||
if err != nil {
|
||||
var pErr *fs.PathError
|
||||
if errors.As(err, &pErr) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
opts, err := modelOptions(model, req.Options)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-rCh:
|
||||
case err = <-eCh:
|
||||
handleErrorResponse(c, err)
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -358,17 +272,14 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
|
||||
embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||
return
|
||||
}
|
||||
|
||||
resp := api.EmbeddingResponse{
|
||||
Embedding: embedding,
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
|
||||
}
|
||||
|
||||
func (s *Server) PullModelHandler(c *gin.Context) {
|
||||
@@ -649,9 +560,9 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
}
|
||||
}
|
||||
|
||||
msgs := make([]api.Message, 0)
|
||||
for _, msg := range m.Messages {
|
||||
msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
|
||||
msgs := make([]api.Message, len(m.Messages))
|
||||
for i, msg := range m.Messages {
|
||||
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
|
||||
}
|
||||
|
||||
n := model.ParseName(req.Model)
|
||||
@@ -863,7 +774,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
_, err = os.Stat(path)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
@@ -876,6 +786,12 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if c.GetHeader("X-Redirect-Create") == "1" && s.IsLocal(c) {
|
||||
c.Header("LocalLocation", path)
|
||||
c.Status(http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
layer, err := NewLayer(c.Request.Body, "")
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -890,6 +806,54 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
c.Status(http.StatusCreated)
|
||||
}
|
||||
|
||||
func (s *Server) IsLocal(c *gin.Context) bool {
|
||||
if authz := c.GetHeader("Authorization"); authz != "" {
|
||||
parts := strings.Split(authz, ":")
|
||||
if len(parts) != 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
|
||||
requestData, err := base64.StdEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
partialRequestDataParts := strings.Split(string(requestData), ",")
|
||||
if len(partialRequestDataParts) != 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
signature, err := base64.StdEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := clientPublicKey.Verify(requestData, &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
serverPublicKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
|
||||
return true
|
||||
}
|
||||
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return false
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isLocalIP(ip netip.Addr) bool {
|
||||
if interfaces, err := net.Interfaces(); err == nil {
|
||||
for _, iface := range interfaces {
|
||||
@@ -1214,132 +1178,55 @@ func (s *Server) ProcessHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
|
||||
}
|
||||
|
||||
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
|
||||
func chatPrompt(ctx context.Context, runner *runnerRef, template *template.Template, messages []api.Message, numCtx int) (string, error) {
|
||||
encode := func(s string) ([]int, error) {
|
||||
return runner.llama.Tokenize(ctx, s)
|
||||
}
|
||||
|
||||
prompt, err := ChatPrompt(template, messages, numCtx, encode)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return prompt, nil
|
||||
}
|
||||
|
||||
func (s *Server) ChatHandler(c *gin.Context) {
|
||||
checkpointStart := time.Now()
|
||||
|
||||
var req api.ChatRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
case err != nil:
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// validate the request
|
||||
switch {
|
||||
case req.Model == "":
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
caps := []Capability{CapabilityCompletion}
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
||||
if errors.Is(err, errCapabilityCompletion) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
|
||||
return
|
||||
case len(req.Format) > 0 && req.Format != "json":
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
|
||||
} else if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
}
|
||||
|
||||
model, err := GetModel(req.Model)
|
||||
if err != nil {
|
||||
var pErr *fs.PathError
|
||||
if errors.As(err, &pErr) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if !model.Has(CapabilityCompletion) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)})
|
||||
return
|
||||
}
|
||||
|
||||
opts, err := modelOptions(model, req.Options)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-rCh:
|
||||
case err = <-eCh:
|
||||
handleErrorResponse(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
checkpointLoaded := time.Now()
|
||||
|
||||
// if the first message is not a system message, then add the model's default system message
|
||||
if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
|
||||
req.Messages = append([]api.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: model.System,
|
||||
},
|
||||
}, req.Messages...)
|
||||
}
|
||||
|
||||
prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// an empty request loads the model
|
||||
if len(req.Messages) == 0 || prompt == "" {
|
||||
resp := api.ChatResponse{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
if len(req.Messages) == 0 {
|
||||
c.JSON(http.StatusOK, api.ChatResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: api.Message{Role: "assistant"},
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
Message: api.Message{Role: "assistant"},
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// only send images that are in the prompt
|
||||
var i int
|
||||
var images []llm.ImageData
|
||||
for _, m := range req.Messages {
|
||||
for _, img := range m.Images {
|
||||
if !isSupportedImageType(img) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
|
||||
images = append(images, llm.ImageData{Data: img, ID: i})
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("chat handler", "prompt", prompt, "images", len(images))
|
||||
slog.Debug("chat request", "images", len(images), "prompt", prompt)
|
||||
|
||||
ch := make(chan any)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
fn := func(r llm.CompletionResponse) {
|
||||
resp := api.ChatResponse{
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
}, func(r llm.CompletionResponse) {
|
||||
ch <- api.ChatResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: api.Message{Role: "assistant", Content: r.Content},
|
||||
@@ -1352,64 +1239,52 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
EvalDuration: r.EvalDuration,
|
||||
},
|
||||
}
|
||||
|
||||
if r.Done {
|
||||
resp.TotalDuration = time.Since(checkpointStart)
|
||||
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
}
|
||||
|
||||
ch <- resp
|
||||
}
|
||||
|
||||
if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Format: req.Format,
|
||||
Images: images,
|
||||
Options: opts,
|
||||
}, fn); err != nil {
|
||||
}); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
|
||||
if req.Stream != nil && !*req.Stream {
|
||||
// Accumulate responses into the final response
|
||||
var final api.ChatResponse
|
||||
var r api.ChatResponse
|
||||
var sb strings.Builder
|
||||
for resp := range ch {
|
||||
switch r := resp.(type) {
|
||||
for rr := range ch {
|
||||
switch t := rr.(type) {
|
||||
case api.ChatResponse:
|
||||
sb.WriteString(r.Message.Content)
|
||||
final = r
|
||||
sb.WriteString(t.Message.Content)
|
||||
r = t
|
||||
case gin.H:
|
||||
if errorMsg, ok := r["error"].(string); ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
|
||||
return
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
|
||||
return
|
||||
msg, ok := t["error"].(string)
|
||||
if !ok {
|
||||
msg = "unexpected error format in response"
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
|
||||
return
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
final.Message = api.Message{Role: "assistant", Content: sb.String()}
|
||||
c.JSON(http.StatusOK, final)
|
||||
r.Message.Content = sb.String()
|
||||
c.JSON(http.StatusOK, r)
|
||||
return
|
||||
}
|
||||
|
||||
streamResponse(c, ch)
|
||||
}
|
||||
|
||||
func handleErrorResponse(c *gin.Context, err error) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
func handleScheduleError(c *gin.Context, name string, err error) {
|
||||
switch {
|
||||
case errors.Is(err, errRequired):
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
case errors.Is(err, context.Canceled):
|
||||
c.JSON(499, gin.H{"error": "request canceled"})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, ErrMaxQueue) {
|
||||
case errors.Is(err, ErrMaxQueue):
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
|
||||
return
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
|
||||
@@ -545,9 +545,9 @@ func TestCreateDetectTemplate(t *testing.T) {
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-2f8e594e6f34b1b4d36a246628eeb3365ce442303d656f1fcc69e821722acea0"),
|
||||
filepath.Join(p, "blobs", "sha256-542b217f179c7825eeb5bca3c77d2b75ed05bafbd3451d9188891a60a85337c6"),
|
||||
filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
|
||||
filepath.Join(p, "blobs", "sha256-9512c372dfc7d84d6065b8dd2b601aeed8cc1a78e7a7aa784a42fff37f5524b7"),
|
||||
filepath.Join(p, "blobs", "sha256-b8b78cb8c6eefd14c06f1af042e6161255bf87bbf2dd14fce57cdac893db8139"),
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -133,10 +133,6 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
numParallel = 1
|
||||
slog.Warn("multimodal models don't support parallel requests yet")
|
||||
}
|
||||
// Keep NumCtx and numParallel in sync
|
||||
if numParallel > 1 {
|
||||
pending.opts.NumCtx = pending.origNumCtx * numParallel
|
||||
}
|
||||
|
||||
for {
|
||||
cpus := s.getCpuFn()
|
||||
@@ -234,9 +230,10 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
// simplifying assumption of defaultParallel when in CPU mode
|
||||
if numParallel <= 0 {
|
||||
numParallel = defaultParallel
|
||||
pending.opts.NumCtx = pending.origNumCtx * numParallel
|
||||
}
|
||||
|
||||
pending.opts.NumCtx = pending.origNumCtx * numParallel
|
||||
|
||||
if loadedCount == 0 {
|
||||
slog.Debug("cpu mode with first model, loading")
|
||||
s.loadFn(pending, ggml, gpus, numParallel)
|
||||
|
||||
@@ -1 +1,8 @@
|
||||
{{ if .System }}<start_system>{{ .System }}<end_message>{{ end }}{{ if .Prompt }}<start_user>{{ .Prompt }}<end_message>{{ end }}<start_assistant>{{ .Response }}<end_message>
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}<start_system>{{ .System }}<end_message>
|
||||
{{- end }}
|
||||
{{- range .Messages }}<start_{{ .Role }}>{{ .Content }}<end_message>
|
||||
{{- end }}<start_assistant>
|
||||
{{- else }}
|
||||
{{ if .System }}<start_system>{{ .System }}<end_message>{{ end }}{{ if .Prompt }}<start_user>{{ .Prompt }}<end_message>{{ end }}<start_assistant>{{ .Response }}<end_message>
|
||||
{{- end }}
|
||||
@@ -1,7 +1,19 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}{{ .System }}
|
||||
{{- end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}### Instruction:
|
||||
{{- else if eq .Role "assistant" }}### Response:
|
||||
{{- end }}
|
||||
{{ .Content }}
|
||||
|
||||
{{ end }}### Response:
|
||||
{{ else }}
|
||||
{{ if .System }}{{ .System }}
|
||||
|
||||
{{ end }}{{ if .Prompt }}### Instruction:
|
||||
{{ .Prompt }}
|
||||
|
||||
{{ end }}### Response:
|
||||
{{ .Response }}
|
||||
{{ .Response }}
|
||||
{{- end }}
|
||||
@@ -1,6 +1,15 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}<|im_start|>system
|
||||
{{ .System }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- range .Messages }}<|im_start|>{{ .Role }}
|
||||
{{ .Content }}<|im_end|>
|
||||
{{ end }}<|im_start|>assistant
|
||||
{{ else }}
|
||||
{{ if .System }}<|im_start|>system
|
||||
{{ .System }}<|im_end|>
|
||||
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||||
{{ .Prompt }}<|im_end|>
|
||||
{{ end }}<|im_start|>assistant
|
||||
{{ .Response }}<|im_end|>
|
||||
{{ .Response }}<|im_end|>
|
||||
{{- end }}
|
||||
@@ -1,5 +1,17 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}System: {{ .System }}
|
||||
|
||||
{{ end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}User:
|
||||
{{- else if eq .Role "assistant" }}Assistant:
|
||||
{{- end }} {{ .Content }}
|
||||
|
||||
{{ end }}Assistant:
|
||||
{{- else }}
|
||||
{{ if .System }}System: {{ .System }}
|
||||
|
||||
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
|
||||
|
||||
{{ end }}Assistant: <|begin_of_text|>{{ .Response }}
|
||||
{{ end }}Assistant: <|begin_of_text|>{{ .Response }}
|
||||
{{- end }}
|
||||
@@ -1,3 +1,13 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}Source: system
|
||||
|
||||
{{ .System }} <step> {{ end }}
|
||||
{{- range .Messages }}Source: {{ .Role }}
|
||||
|
||||
{{ .Content }} <step> {{ end }}Source: assistant
|
||||
Destination: user
|
||||
|
||||
{{ else }}
|
||||
{{ if .System }} Source: system
|
||||
|
||||
{{ .System }} <step>{{ end }} Source: user
|
||||
@@ -5,4 +15,5 @@
|
||||
{{ .Prompt }} <step> Source: assistant
|
||||
Destination: user
|
||||
|
||||
{{ .Response }}<step>
|
||||
{{ .Response }}<step>
|
||||
{{- end }}
|
||||
@@ -1,3 +1,13 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}System: {{ .System }}
|
||||
{{ end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}User:
|
||||
{{ else if eq .Role "assistant" }}Falcon:
|
||||
{{ end }}{{ .Content }}
|
||||
{{ end }}Falcon:
|
||||
{{ else }}
|
||||
{{ if .System }}{{ .System }}
|
||||
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
|
||||
{{ end }}Assistant: {{ .Response }}
|
||||
{{ end }}Assistant: {{ .Response }}
|
||||
{{- end }}
|
||||
@@ -1,4 +1,16 @@
|
||||
{{- if .Messages }}
|
||||
{{- range $index, $_ := .Messages }}<start_of_turn>
|
||||
{{- if eq .Role "user" }}user
|
||||
{{- if and $.System (eq $index 0) }}
|
||||
{{ $.System }}
|
||||
{{- end }}
|
||||
{{- else if eq .Role "assistant" }}model
|
||||
{{- end }}
|
||||
{{ .Content }}<end_of_turn>
|
||||
{{ end }}<start_of_turn>model
|
||||
{{ else }}
|
||||
<start_of_turn>user
|
||||
{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn>
|
||||
<start_of_turn>model
|
||||
{{ .Response }}<end_of_turn>
|
||||
{{ .Response }}<end_of_turn>
|
||||
{{- end }}
|
||||
@@ -1,3 +1,16 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}System:
|
||||
{{ .System }}
|
||||
|
||||
{{ end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}Question:
|
||||
{{- else if eq .Role "assistant" }}Answer:
|
||||
{{- end }}
|
||||
{{ .Content }}
|
||||
|
||||
{{ end }}Answer:
|
||||
{{ else }}
|
||||
{{ if .System }}
|
||||
System:
|
||||
{{ .System }}
|
||||
@@ -6,4 +19,5 @@ System:
|
||||
{{ .Prompt }}
|
||||
|
||||
{{ end }}Answer:
|
||||
{{ .Response }}
|
||||
{{ .Response }}
|
||||
{{- end }}
|
||||
@@ -1,3 +1,16 @@
|
||||
{{- if .Messages }}
|
||||
{{- range $index, $_ := .Messages }}
|
||||
{{- if eq .Role "user" }}[INST] {{ if eq $index 0 }}<<SYS>>
|
||||
{{- if $.System }}
|
||||
{{ $.System }}
|
||||
{{ end }}<</SYS>>
|
||||
|
||||
{{ end }}{{ .Content }}
|
||||
{{- else }} [/INST] {{ .Content }}</s><s>
|
||||
{{- end }}
|
||||
{{- end }} [/INST]
|
||||
{{- else }}
|
||||
[INST] <<SYS>>{{ .System }}<</SYS>>
|
||||
|
||||
{{ .Prompt }} [/INST] {{ .Response }}
|
||||
{{ .Prompt }} [/INST] {{ .Response }}
|
||||
{{- end }}
|
||||
@@ -1,7 +1,19 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
{{ .System }}<|eot_id|>
|
||||
{{- end }}
|
||||
{{- range .Messages }}<|start_header_id|>{{ .Role }}<|end_header_id|>
|
||||
|
||||
{{ .Content }}<|eot_id|>
|
||||
{{- end }}<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{{ else }}
|
||||
{{ if .System }}<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
|
||||
|
||||
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{{ .Response }}<|eot_id|>
|
||||
{{ .Response }}<|eot_id|>
|
||||
{{- end }}
|
||||
@@ -1,7 +1,20 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}{{ .System }}
|
||||
|
||||
{{ end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}@@ Instruction
|
||||
{{- else if eq .Role "assistant" }}@@ Response
|
||||
{{- end }}
|
||||
{{ .Content }}
|
||||
|
||||
{{ end }}@@ Response
|
||||
{{ else }}
|
||||
{{ if .System }}{{ .System }}
|
||||
|
||||
{{ end }}{{ if .Prompt }}@@ Instruction
|
||||
{{ .Prompt }}
|
||||
|
||||
{{ end }}@@ Response
|
||||
{{ .Response }}
|
||||
{{ .Response }}
|
||||
{{- end }}
|
||||
@@ -1,6 +1,9 @@
|
||||
{{ if .System }}<|im_start|>system
|
||||
{{ .System }}<|im_end|>
|
||||
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||||
{{ .Prompt }}<|im_end|>
|
||||
{{ end }}<|im_start|>assistant
|
||||
{{ .Response }}<|im_end|>
|
||||
{{- if .Messages }}
|
||||
{{- range $index, $_ := .Messages }}
|
||||
{{- if eq .Role "user" }}[INST] {{ if and $.System (eq (len (slice $.Messages $index)) 1) }}{{ $.System }}
|
||||
{{ end }}{{ .Content }}
|
||||
{{- else if eq .Role "assistant" }}[/INST] {{ .Content }}</s>
|
||||
{{- end }}
|
||||
{{- end }}[/INST]
|
||||
{{- else }}[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST] {{ .Response }}
|
||||
{{- end }}
|
||||
@@ -1 +1,11 @@
|
||||
{{ .System }}<|end_of_turn|>GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}GPT Correct System: {{ .System }}<|end_of_turn|>
|
||||
{{- end }}
|
||||
{{- range .Messages }}GPT Correct
|
||||
{{- if eq .Role "user" }} User:
|
||||
{{- else if eq .Role "assistant" }} Assistant:
|
||||
{{- end }} {{ .Content }}<|end_of_turn|>
|
||||
{{- end }}GPT Correct Assistant:
|
||||
{{- else }}
|
||||
{{ .System }}<|end_of_turn|>GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>
|
||||
{{- end }}
|
||||
@@ -1,6 +1,15 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}<|system|>
|
||||
{{ .System }}<|end|>
|
||||
{{ end }}
|
||||
{{- range .Messages }}<|{{ .Role }}|>
|
||||
{{ .Content }}<|end|>
|
||||
{{ end }}<|assistant|>
|
||||
{{ else }}
|
||||
{{ if .System }}<|system|>
|
||||
{{ .System }}<|end|>
|
||||
{{ end }}{{ if .Prompt }}<|user|>
|
||||
{{ .Prompt }}<|end|>
|
||||
{{ end }}<|assistant|>
|
||||
{{ .Response }}<|end|>
|
||||
{{ .Response }}<|end|>
|
||||
{{- end }}
|
||||
@@ -1,3 +1,16 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}### System:
|
||||
{{ .System }}
|
||||
|
||||
{{ end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}### User:
|
||||
{{ .Content }}
|
||||
{{ else if eq .Role "assistant" }}### Assistant:
|
||||
{{ .Content }}</s>
|
||||
{{ end }}
|
||||
{{ end }}### Assistant:
|
||||
{{ else }}
|
||||
{{ if .System }}### System:
|
||||
{{ .System }}
|
||||
|
||||
@@ -5,4 +18,5 @@
|
||||
{{ .Prompt }}
|
||||
|
||||
{{ end }}### Assistant:
|
||||
{{ .Response }}
|
||||
{{ .Response }}
|
||||
{{- end }}
|
||||
@@ -1,3 +1,17 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}{{ .System }}
|
||||
|
||||
{{ end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}### Instruction
|
||||
{{ .Content }}
|
||||
|
||||
{{ else if eq .Role "assistant" }}### Response
|
||||
{{ .Content }}<|endoftext|>
|
||||
|
||||
{{ end }}
|
||||
{{- end }}### Response
|
||||
{{ else }}
|
||||
{{ if .System }}{{ .System }}
|
||||
|
||||
{{ end }}{{ if .Prompt }}### Instruction
|
||||
@@ -7,3 +21,4 @@
|
||||
{{ end }}### Response
|
||||
{{ .Response }}<|endoftext|>
|
||||
|
||||
{{- end }}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"slices"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
"text/template/parse"
|
||||
|
||||
"github.com/agnivade/levenshtein"
|
||||
"github.com/ollama/ollama/api"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
@@ -74,30 +76,59 @@ func Named(s string) (*named, error) {
|
||||
return nil, errors.New("no matching template found")
|
||||
}
|
||||
|
||||
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
|
||||
|
||||
type Template struct {
|
||||
*template.Template
|
||||
raw string
|
||||
}
|
||||
|
||||
// response is a template node that can be added to templates that don't already have one
|
||||
var response = parse.ActionNode{
|
||||
NodeType: parse.NodeAction,
|
||||
Pipe: &parse.PipeNode{
|
||||
NodeType: parse.NodePipe,
|
||||
Cmds: []*parse.CommandNode{
|
||||
{
|
||||
NodeType: parse.NodeCommand,
|
||||
Args: []parse.Node{
|
||||
&parse.FieldNode{
|
||||
NodeType: parse.NodeField,
|
||||
Ident: []string{"Response"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func Parse(s string) (*Template, error) {
|
||||
tmpl := template.New("").Option("missingkey=zero")
|
||||
|
||||
tmpl, err := tmpl.Parse(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := Template{Template: tmpl, raw: s}
|
||||
if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
|
||||
// touch up the template and append {{ .Response }}
|
||||
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
|
||||
}
|
||||
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
func (t *Template) String() string {
|
||||
return t.raw
|
||||
}
|
||||
|
||||
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
|
||||
|
||||
func Parse(s string) (*Template, error) {
|
||||
t, err := template.New("").Option("missingkey=zero").Parse(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Template{Template: t, raw: s}, nil
|
||||
}
|
||||
|
||||
func (t *Template) Vars() []string {
|
||||
var vars []string
|
||||
for _, n := range t.Tree.Root.Nodes {
|
||||
vars = append(vars, parseNode(n)...)
|
||||
for _, tt := range t.Templates() {
|
||||
for _, n := range tt.Root.Nodes {
|
||||
vars = append(vars, parseNode(n)...)
|
||||
}
|
||||
}
|
||||
|
||||
set := make(map[string]struct{})
|
||||
@@ -110,6 +141,103 @@ func (t *Template) Vars() []string {
|
||||
return vars
|
||||
}
|
||||
|
||||
type Values struct {
|
||||
Messages []api.Message
|
||||
}
|
||||
|
||||
func (t *Template) Execute(w io.Writer, v Values) error {
|
||||
system, collated := collate(v.Messages)
|
||||
if slices.Contains(t.Vars(), "messages") {
|
||||
return t.Template.Execute(w, map[string]any{
|
||||
"System": system,
|
||||
"Messages": collated,
|
||||
})
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
var prompt, response string
|
||||
for i, m := range collated {
|
||||
if m.Role == "user" {
|
||||
prompt = m.Content
|
||||
} else {
|
||||
response = m.Content
|
||||
}
|
||||
|
||||
if i != len(collated)-1 && prompt != "" && response != "" {
|
||||
if err := t.Template.Execute(&b, map[string]any{
|
||||
"System": "",
|
||||
"Prompt": prompt,
|
||||
"Response": response,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prompt = ""
|
||||
response = ""
|
||||
}
|
||||
}
|
||||
|
||||
var cut bool
|
||||
tree := t.Template.Copy()
|
||||
// for the last message, cut everything after "{{ .Response }}"
|
||||
tree.Root.Nodes = slices.DeleteFunc(tree.Root.Nodes, func(n parse.Node) bool {
|
||||
if slices.Contains(parseNode(n), "Response") {
|
||||
cut = true
|
||||
}
|
||||
|
||||
return cut
|
||||
})
|
||||
|
||||
if err := template.Must(template.New("").AddParseTree("", tree)).Execute(&b, map[string]any{
|
||||
"System": system,
|
||||
"Prompt": prompt,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := io.Copy(w, &b)
|
||||
return err
|
||||
}
|
||||
|
||||
type messages []*api.Message
|
||||
|
||||
// collate messages based on role. consecutive messages of the same role are merged
|
||||
// into a single message. collate also pulls out and merges messages with Role == "system"
|
||||
// which are templated separately. As a side effect, it mangles message content adding image
|
||||
// tags ([img-%d]) as needed
|
||||
func collate(msgs []api.Message) (system string, collated messages) {
|
||||
var n int
|
||||
for i := range msgs {
|
||||
msg := msgs[i]
|
||||
if msg.Role == "system" {
|
||||
if system != "" {
|
||||
system += "\n\n"
|
||||
}
|
||||
|
||||
system += msg.Content
|
||||
continue
|
||||
}
|
||||
|
||||
for range msg.Images {
|
||||
imageTag := fmt.Sprintf("[img-%d]", n)
|
||||
if !strings.Contains(msg.Content, "[img]") {
|
||||
msg.Content = strings.TrimSpace("[img] " + msg.Content)
|
||||
}
|
||||
|
||||
msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
|
||||
n++
|
||||
}
|
||||
|
||||
if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
|
||||
collated[len(collated)-1].Content += "\n\n" + msg.Content
|
||||
} else {
|
||||
collated = append(collated, &msg)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func parseNode(n parse.Node) []string {
|
||||
switch n := n.(type) {
|
||||
case *parse.ActionNode:
|
||||
@@ -152,6 +280,8 @@ func parseNode(n parse.Node) []string {
|
||||
return names
|
||||
case *parse.FieldNode:
|
||||
return n.Ident
|
||||
case *parse.TemplateNode:
|
||||
return parseNode(n.Pipe)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -8,9 +8,11 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"text/template"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
@@ -46,7 +48,7 @@ func TestNamed(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tmpl, err := template.New(s).Parse(b.String())
|
||||
tmpl, err := Parse(b.String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -59,18 +61,81 @@ func TestNamed(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplate(t *testing.T) {
|
||||
cases := make(map[string][]api.Message)
|
||||
for _, mm := range [][]api.Message{
|
||||
{
|
||||
{Role: "user", Content: "Hello, how are you?"},
|
||||
},
|
||||
{
|
||||
{Role: "user", Content: "Hello, how are you?"},
|
||||
{Role: "assistant", Content: "I'm doing great. How can I help you today?"},
|
||||
{Role: "user", Content: "I'd like to show off how chat templating works!"},
|
||||
},
|
||||
{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello, how are you?"},
|
||||
{Role: "assistant", Content: "I'm doing great. How can I help you today?"},
|
||||
{Role: "user", Content: "I'd like to show off how chat templating works!"},
|
||||
},
|
||||
} {
|
||||
var roles []string
|
||||
for _, m := range mm {
|
||||
roles = append(roles, m.Role)
|
||||
}
|
||||
|
||||
cases[strings.Join(roles, "-")] = mm
|
||||
}
|
||||
|
||||
matches, err := filepath.Glob("*.gotmpl")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, match := range matches {
|
||||
t.Run(match, func(t *testing.T) {
|
||||
bts, err := os.ReadFile(match)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tmpl, err := Parse(string(bts))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for n, tt := range cases {
|
||||
t.Run(n, func(t *testing.T) {
|
||||
var actual bytes.Buffer
|
||||
if err := tmpl.Execute(&actual, Values{Messages: tt}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expect, err := os.ReadFile(filepath.Join("testdata", match, n))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(actual.Bytes(), expect); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
cases := []struct {
|
||||
template string
|
||||
vars []string
|
||||
}{
|
||||
{"{{ .Prompt }}", []string{"prompt"}},
|
||||
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}},
|
||||
{"{{ .Prompt }}", []string{"prompt", "response"}},
|
||||
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}},
|
||||
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
|
||||
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "system", "tools"}},
|
||||
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
|
||||
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
|
||||
{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
|
||||
{"{{ .Prompt }} {{ .Suffix }}", []string{"prompt", "suffix"}},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
@@ -87,3 +152,159 @@ func TestParse(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteWithMessages(t *testing.T) {
|
||||
type template struct {
|
||||
name string
|
||||
template string
|
||||
}
|
||||
cases := []struct {
|
||||
name string
|
||||
templates []template
|
||||
values Values
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"mistral",
|
||||
[]template{
|
||||
{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
|
||||
{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
||||
{"messages", `{{- range $index, $_ := .Messages }}
|
||||
{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
|
||||
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
|
||||
{{- end }}
|
||||
{{- end }}`},
|
||||
},
|
||||
Values{
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello friend!"},
|
||||
{Role: "assistant", Content: "Hello human!"},
|
||||
{Role: "user", Content: "What is your name?"},
|
||||
},
|
||||
},
|
||||
`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
|
||||
},
|
||||
{
|
||||
"mistral system",
|
||||
[]template{
|
||||
{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
|
||||
{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
||||
{"messages", `
|
||||
{{- range $index, $_ := .Messages }}
|
||||
{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
|
||||
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
|
||||
{{- end }}
|
||||
{{- end }}`},
|
||||
},
|
||||
Values{
|
||||
Messages: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant!"},
|
||||
{Role: "user", Content: "Hello friend!"},
|
||||
{Role: "assistant", Content: "Hello human!"},
|
||||
{Role: "user", Content: "What is your name?"},
|
||||
},
|
||||
},
|
||||
`[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
|
||||
|
||||
What is your name?[/INST] `,
|
||||
},
|
||||
{
|
||||
"chatml",
|
||||
[]template{
|
||||
// this does not have a "no response" test because it's impossible to render the same output
|
||||
{"response", `{{ if .System }}<|im_start|>system
|
||||
{{ .System }}<|im_end|>
|
||||
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||||
{{ .Prompt }}<|im_end|>
|
||||
{{ end }}<|im_start|>assistant
|
||||
{{ .Response }}<|im_end|>
|
||||
`},
|
||||
{"messages", `
|
||||
{{- range $index, $_ := .Messages }}
|
||||
{{- if and (eq .Role "user") (eq (len (slice $.Messages $index)) 1) $.System }}<|im_start|>system
|
||||
{{ $.System }}<|im_end|>{{ "\n" }}
|
||||
{{- end }}<|im_start|>{{ .Role }}
|
||||
{{ .Content }}<|im_end|>{{ "\n" }}
|
||||
{{- end }}<|im_start|>assistant
|
||||
`},
|
||||
},
|
||||
Values{
|
||||
Messages: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant!"},
|
||||
{Role: "user", Content: "Hello friend!"},
|
||||
{Role: "assistant", Content: "Hello human!"},
|
||||
{Role: "user", Content: "What is your name?"},
|
||||
},
|
||||
},
|
||||
`<|im_start|>user
|
||||
Hello friend!<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Hello human!<|im_end|>
|
||||
<|im_start|>system
|
||||
You are a helpful assistant!<|im_end|>
|
||||
<|im_start|>user
|
||||
What is your name?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
`,
|
||||
},
|
||||
{
|
||||
"moondream",
|
||||
[]template{
|
||||
// this does not have a "no response" test because it's impossible to render the same output
|
||||
{"response", `{{ if .Prompt }}Question: {{ .Prompt }}
|
||||
|
||||
{{ end }}Answer: {{ .Response }}
|
||||
|
||||
`},
|
||||
{"messages", `
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
|
||||
{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
|
||||
{{- end }}
|
||||
{{- end }}Answer: `},
|
||||
},
|
||||
Values{
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "What's in this image?", Images: []api.ImageData{[]byte("")}},
|
||||
{Role: "assistant", Content: "It's a hot dog."},
|
||||
{Role: "user", Content: "What's in _this_ image?"},
|
||||
{Role: "user", Images: []api.ImageData{[]byte("")}},
|
||||
{Role: "user", Content: "Is it a hot dog?"},
|
||||
},
|
||||
},
|
||||
`Question: [img-0] What's in this image?
|
||||
|
||||
Answer: It's a hot dog.
|
||||
|
||||
Question: What's in _this_ image?
|
||||
|
||||
[img-1]
|
||||
|
||||
Is it a hot dog?
|
||||
|
||||
Answer: `,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
for _, ttt := range tt.templates {
|
||||
t.Run(ttt.name, func(t *testing.T) {
|
||||
tmpl, err := Parse(ttt.template)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := tmpl.Execute(&b, tt.values); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if b.String() != tt.expected {
|
||||
t.Errorf("expected\n%s,\ngot\n%s", tt.expected, b.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
1
template/testdata/alfred.gotmpl/system-user-assistant-user
vendored
Normal file
1
template/testdata/alfred.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1 @@
|
||||
<start_system>You are a helpful assistant.<end_message><start_user>Hello, how are you?<end_message><start_assistant>I'm doing great. How can I help you today?<end_message><start_user>I'd like to show off how chat templating works!<end_message><start_assistant>
|
||||
1
template/testdata/alfred.gotmpl/user
vendored
Normal file
1
template/testdata/alfred.gotmpl/user
vendored
Normal file
@@ -0,0 +1 @@
|
||||
<start_user>Hello, how are you?<end_message><start_assistant>
|
||||
1
template/testdata/alfred.gotmpl/user-assistant-user
vendored
Normal file
1
template/testdata/alfred.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1 @@
|
||||
<start_user>Hello, how are you?<end_message><start_assistant>I'm doing great. How can I help you today?<end_message><start_user>I'd like to show off how chat templating works!<end_message><start_assistant>
|
||||
10
template/testdata/alpaca.gotmpl/system-user-assistant-user
vendored
Normal file
10
template/testdata/alpaca.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
You are a helpful assistant.### Instruction:
|
||||
Hello, how are you?
|
||||
|
||||
### Response:
|
||||
I'm doing great. How can I help you today?
|
||||
|
||||
### Instruction:
|
||||
I'd like to show off how chat templating works!
|
||||
|
||||
### Response:
|
||||
4
template/testdata/alpaca.gotmpl/user
vendored
Normal file
4
template/testdata/alpaca.gotmpl/user
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
### Instruction:
|
||||
Hello, how are you?
|
||||
|
||||
### Response:
|
||||
10
template/testdata/alpaca.gotmpl/user-assistant-user
vendored
Normal file
10
template/testdata/alpaca.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
### Instruction:
|
||||
Hello, how are you?
|
||||
|
||||
### Response:
|
||||
I'm doing great. How can I help you today?
|
||||
|
||||
### Instruction:
|
||||
I'd like to show off how chat templating works!
|
||||
|
||||
### Response:
|
||||
9
template/testdata/chatml.gotmpl/system-user-assistant-user
vendored
Normal file
9
template/testdata/chatml.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
<|im_start|>system
|
||||
You are a helpful assistant.<|im_end|>
|
||||
<|im_start|>user
|
||||
Hello, how are you?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
I'm doing great. How can I help you today?<|im_end|>
|
||||
<|im_start|>user
|
||||
I'd like to show off how chat templating works!<|im_end|>
|
||||
<|im_start|>assistant
|
||||
3
template/testdata/chatml.gotmpl/user
vendored
Normal file
3
template/testdata/chatml.gotmpl/user
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
<|im_start|>user
|
||||
Hello, how are you?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
7
template/testdata/chatml.gotmpl/user-assistant-user
vendored
Normal file
7
template/testdata/chatml.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
<|im_start|>user
|
||||
Hello, how are you?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
I'm doing great. How can I help you today?<|im_end|>
|
||||
<|im_start|>user
|
||||
I'd like to show off how chat templating works!<|im_end|>
|
||||
<|im_start|>assistant
|
||||
9
template/testdata/chatqa.gotmpl/system-user-assistant-user
vendored
Normal file
9
template/testdata/chatqa.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
System: You are a helpful assistant.
|
||||
|
||||
User: Hello, how are you?
|
||||
|
||||
Assistant: I'm doing great. How can I help you today?
|
||||
|
||||
User: I'd like to show off how chat templating works!
|
||||
|
||||
Assistant:
|
||||
3
template/testdata/chatqa.gotmpl/user
vendored
Normal file
3
template/testdata/chatqa.gotmpl/user
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
User: Hello, how are you?
|
||||
|
||||
Assistant:
|
||||
7
template/testdata/chatqa.gotmpl/user-assistant-user
vendored
Normal file
7
template/testdata/chatqa.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
User: Hello, how are you?
|
||||
|
||||
Assistant: I'm doing great. How can I help you today?
|
||||
|
||||
User: I'd like to show off how chat templating works!
|
||||
|
||||
Assistant:
|
||||
11
template/testdata/codellama-70b-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
11
template/testdata/codellama-70b-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
Source: system
|
||||
|
||||
You are a helpful assistant. <step> Source: user
|
||||
|
||||
Hello, how are you? <step> Source: assistant
|
||||
|
||||
I'm doing great. How can I help you today? <step> Source: user
|
||||
|
||||
I'd like to show off how chat templating works! <step> Source: assistant
|
||||
Destination: user
|
||||
|
||||
5
template/testdata/codellama-70b-instruct.gotmpl/user
vendored
Normal file
5
template/testdata/codellama-70b-instruct.gotmpl/user
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
Source: user
|
||||
|
||||
Hello, how are you? <step> Source: assistant
|
||||
Destination: user
|
||||
|
||||
9
template/testdata/codellama-70b-instruct.gotmpl/user-assistant-user
vendored
Normal file
9
template/testdata/codellama-70b-instruct.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
Source: user
|
||||
|
||||
Hello, how are you? <step> Source: assistant
|
||||
|
||||
I'm doing great. How can I help you today? <step> Source: user
|
||||
|
||||
I'd like to show off how chat templating works! <step> Source: assistant
|
||||
Destination: user
|
||||
|
||||
8
template/testdata/falcon-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
8
template/testdata/falcon-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
System: You are a helpful assistant.
|
||||
User:
|
||||
Hello, how are you?
|
||||
Falcon:
|
||||
I'm doing great. How can I help you today?
|
||||
User:
|
||||
I'd like to show off how chat templating works!
|
||||
Falcon:
|
||||
3
template/testdata/falcon-instruct.gotmpl/user
vendored
Normal file
3
template/testdata/falcon-instruct.gotmpl/user
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
User:
|
||||
Hello, how are you?
|
||||
Falcon:
|
||||
7
template/testdata/falcon-instruct.gotmpl/user-assistant-user
vendored
Normal file
7
template/testdata/falcon-instruct.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
User:
|
||||
Hello, how are you?
|
||||
Falcon:
|
||||
I'm doing great. How can I help you today?
|
||||
User:
|
||||
I'd like to show off how chat templating works!
|
||||
Falcon:
|
||||
8
template/testdata/gemma-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
8
template/testdata/gemma-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
<start_of_turn>user
|
||||
You are a helpful assistant.
|
||||
Hello, how are you?<end_of_turn>
|
||||
<start_of_turn>model
|
||||
I'm doing great. How can I help you today?<end_of_turn>
|
||||
<start_of_turn>user
|
||||
I'd like to show off how chat templating works!<end_of_turn>
|
||||
<start_of_turn>model
|
||||
3
template/testdata/gemma-instruct.gotmpl/user
vendored
Normal file
3
template/testdata/gemma-instruct.gotmpl/user
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
<start_of_turn>user
|
||||
Hello, how are you?<end_of_turn>
|
||||
<start_of_turn>model
|
||||
7
template/testdata/gemma-instruct.gotmpl/user-assistant-user
vendored
Normal file
7
template/testdata/gemma-instruct.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
<start_of_turn>user
|
||||
Hello, how are you?<end_of_turn>
|
||||
<start_of_turn>model
|
||||
I'm doing great. How can I help you today?<end_of_turn>
|
||||
<start_of_turn>user
|
||||
I'd like to show off how chat templating works!<end_of_turn>
|
||||
<start_of_turn>model
|
||||
13
template/testdata/granite-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
13
template/testdata/granite-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
System:
|
||||
You are a helpful assistant.
|
||||
|
||||
Question:
|
||||
Hello, how are you?
|
||||
|
||||
Answer:
|
||||
I'm doing great. How can I help you today?
|
||||
|
||||
Question:
|
||||
I'd like to show off how chat templating works!
|
||||
|
||||
Answer:
|
||||
4
template/testdata/granite-instruct.gotmpl/user
vendored
Normal file
4
template/testdata/granite-instruct.gotmpl/user
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
Question:
|
||||
Hello, how are you?
|
||||
|
||||
Answer:
|
||||
10
template/testdata/granite-instruct.gotmpl/user-assistant-user
vendored
Normal file
10
template/testdata/granite-instruct.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
Question:
|
||||
Hello, how are you?
|
||||
|
||||
Answer:
|
||||
I'm doing great. How can I help you today?
|
||||
|
||||
Question:
|
||||
I'd like to show off how chat templating works!
|
||||
|
||||
Answer:
|
||||
5
template/testdata/llama2-chat.gotmpl/system-user-assistant-user
vendored
Normal file
5
template/testdata/llama2-chat.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
[INST] <<SYS>>
|
||||
You are a helpful assistant.
|
||||
<</SYS>>
|
||||
|
||||
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] I'd like to show off how chat templating works! [/INST]
|
||||
3
template/testdata/llama2-chat.gotmpl/user
vendored
Normal file
3
template/testdata/llama2-chat.gotmpl/user
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
[INST] <<SYS>><</SYS>>
|
||||
|
||||
Hello, how are you? [/INST]
|
||||
3
template/testdata/llama2-chat.gotmpl/user-assistant-user
vendored
Normal file
3
template/testdata/llama2-chat.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
[INST] <<SYS>><</SYS>>
|
||||
|
||||
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] I'd like to show off how chat templating works! [/INST]
|
||||
10
template/testdata/llama3-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
10
template/testdata/llama3-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Hello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
I'm doing great. How can I help you today?<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
I'd like to show off how chat templating works!<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
4
template/testdata/llama3-instruct.gotmpl/user
vendored
Normal file
4
template/testdata/llama3-instruct.gotmpl/user
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
<|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Hello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
8
template/testdata/llama3-instruct.gotmpl/user-assistant-user
vendored
Normal file
8
template/testdata/llama3-instruct.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
<|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Hello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
I'm doing great. How can I help you today?<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
I'd like to show off how chat templating works!<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
12
template/testdata/magicoder.gotmpl/system-user-assistant-user
vendored
Normal file
12
template/testdata/magicoder.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
You are a helpful assistant.
|
||||
|
||||
@@ Instruction
|
||||
Hello, how are you?
|
||||
|
||||
@@ Response
|
||||
I'm doing great. How can I help you today?
|
||||
|
||||
@@ Instruction
|
||||
I'd like to show off how chat templating works!
|
||||
|
||||
@@ Response
|
||||
4
template/testdata/magicoder.gotmpl/user
vendored
Normal file
4
template/testdata/magicoder.gotmpl/user
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
@@ Instruction
|
||||
Hello, how are you?
|
||||
|
||||
@@ Response
|
||||
10
template/testdata/magicoder.gotmpl/user-assistant-user
vendored
Normal file
10
template/testdata/magicoder.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
@@ Instruction
|
||||
Hello, how are you?
|
||||
|
||||
@@ Response
|
||||
I'm doing great. How can I help you today?
|
||||
|
||||
@@ Instruction
|
||||
I'd like to show off how chat templating works!
|
||||
|
||||
@@ Response
|
||||
2
template/testdata/mistral-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
2
template/testdata/mistral-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
[INST] Hello, how are you?[/INST] I'm doing great. How can I help you today?</s>[INST] You are a helpful assistant.
|
||||
I'd like to show off how chat templating works![/INST]
|
||||
1
template/testdata/mistral-instruct.gotmpl/user
vendored
Normal file
1
template/testdata/mistral-instruct.gotmpl/user
vendored
Normal file
@@ -0,0 +1 @@
|
||||
[INST] Hello, how are you?[/INST]
|
||||
1
template/testdata/mistral-instruct.gotmpl/user-assistant-user
vendored
Normal file
1
template/testdata/mistral-instruct.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1 @@
|
||||
[INST] Hello, how are you?[/INST] I'm doing great. How can I help you today?</s>[INST] I'd like to show off how chat templating works![/INST]
|
||||
1
template/testdata/openchat.gotmpl/system-user-assistant-user
vendored
Normal file
1
template/testdata/openchat.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1 @@
|
||||
GPT Correct System: You are a helpful assistant.<|end_of_turn|>GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT Correct Assistant:
|
||||
1
template/testdata/openchat.gotmpl/user
vendored
Normal file
1
template/testdata/openchat.gotmpl/user
vendored
Normal file
@@ -0,0 +1 @@
|
||||
GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant:
|
||||
1
template/testdata/openchat.gotmpl/user-assistant-user
vendored
Normal file
1
template/testdata/openchat.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1 @@
|
||||
GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT Correct Assistant:
|
||||
9
template/testdata/phi-3.gotmpl/system-user-assistant-user
vendored
Normal file
9
template/testdata/phi-3.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
<|system|>
|
||||
You are a helpful assistant.<|end|>
|
||||
<|user|>
|
||||
Hello, how are you?<|end|>
|
||||
<|assistant|>
|
||||
I'm doing great. How can I help you today?<|end|>
|
||||
<|user|>
|
||||
I'd like to show off how chat templating works!<|end|>
|
||||
<|assistant|>
|
||||
3
template/testdata/phi-3.gotmpl/user
vendored
Normal file
3
template/testdata/phi-3.gotmpl/user
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
<|user|>
|
||||
Hello, how are you?<|end|>
|
||||
<|assistant|>
|
||||
7
template/testdata/phi-3.gotmpl/user-assistant-user
vendored
Normal file
7
template/testdata/phi-3.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
<|user|>
|
||||
Hello, how are you?<|end|>
|
||||
<|assistant|>
|
||||
I'm doing great. How can I help you today?<|end|>
|
||||
<|user|>
|
||||
I'd like to show off how chat templating works!<|end|>
|
||||
<|assistant|>
|
||||
13
template/testdata/solar-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
13
template/testdata/solar-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
### System:
|
||||
You are a helpful assistant.
|
||||
|
||||
### User:
|
||||
Hello, how are you?
|
||||
|
||||
### Assistant:
|
||||
I'm doing great. How can I help you today?</s>
|
||||
|
||||
### User:
|
||||
I'd like to show off how chat templating works!
|
||||
|
||||
### Assistant:
|
||||
4
template/testdata/solar-instruct.gotmpl/user
vendored
Normal file
4
template/testdata/solar-instruct.gotmpl/user
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
### User:
|
||||
Hello, how are you?
|
||||
|
||||
### Assistant:
|
||||
10
template/testdata/solar-instruct.gotmpl/user-assistant-user
vendored
Normal file
10
template/testdata/solar-instruct.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
### User:
|
||||
Hello, how are you?
|
||||
|
||||
### Assistant:
|
||||
I'm doing great. How can I help you today?</s>
|
||||
|
||||
### User:
|
||||
I'd like to show off how chat templating works!
|
||||
|
||||
### Assistant:
|
||||
12
template/testdata/starcoder2-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
12
template/testdata/starcoder2-instruct.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
You are a helpful assistant.
|
||||
|
||||
### Instruction
|
||||
Hello, how are you?
|
||||
|
||||
### Response
|
||||
I'm doing great. How can I help you today?<|endoftext|>
|
||||
|
||||
### Instruction
|
||||
I'd like to show off how chat templating works!
|
||||
|
||||
### Response
|
||||
4
template/testdata/starcoder2-instruct.gotmpl/user
vendored
Normal file
4
template/testdata/starcoder2-instruct.gotmpl/user
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
### Instruction
|
||||
Hello, how are you?
|
||||
|
||||
### Response
|
||||
10
template/testdata/starcoder2-instruct.gotmpl/user-assistant-user
vendored
Normal file
10
template/testdata/starcoder2-instruct.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
### Instruction
|
||||
Hello, how are you?
|
||||
|
||||
### Response
|
||||
I'm doing great. How can I help you today?<|endoftext|>
|
||||
|
||||
### Instruction
|
||||
I'd like to show off how chat templating works!
|
||||
|
||||
### Response
|
||||
6
template/testdata/vicuna.gotmpl/system-user-assistant-user
vendored
Normal file
6
template/testdata/vicuna.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
You are a helpful assistant.
|
||||
|
||||
USER: Hello, how are you?
|
||||
ASSISTANT: I'm doing great. How can I help you today?</s>
|
||||
USER: I'd like to show off how chat templating works!
|
||||
ASSISTANT:
|
||||
2
template/testdata/vicuna.gotmpl/user
vendored
Normal file
2
template/testdata/vicuna.gotmpl/user
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
USER: Hello, how are you?
|
||||
ASSISTANT:
|
||||
4
template/testdata/vicuna.gotmpl/user-assistant-user
vendored
Normal file
4
template/testdata/vicuna.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
USER: Hello, how are you?
|
||||
ASSISTANT: I'm doing great. How can I help you today?</s>
|
||||
USER: I'd like to show off how chat templating works!
|
||||
ASSISTANT:
|
||||
9
template/testdata/zephyr.gotmpl/system-user-assistant-user
vendored
Normal file
9
template/testdata/zephyr.gotmpl/system-user-assistant-user
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
<|system|>
|
||||
You are a helpful assistant.</s>
|
||||
<|user|>
|
||||
Hello, how are you?</s>
|
||||
<|assistant|>
|
||||
I'm doing great. How can I help you today?</s>
|
||||
<|user|>
|
||||
I'd like to show off how chat templating works!</s>
|
||||
<|assistant|>
|
||||
3
template/testdata/zephyr.gotmpl/user
vendored
Normal file
3
template/testdata/zephyr.gotmpl/user
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
<|user|>
|
||||
Hello, how are you?</s>
|
||||
<|assistant|>
|
||||
7
template/testdata/zephyr.gotmpl/user-assistant-user
vendored
Normal file
7
template/testdata/zephyr.gotmpl/user-assistant-user
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
<|user|>
|
||||
Hello, how are you?</s>
|
||||
<|assistant|>
|
||||
I'm doing great. How can I help you today?</s>
|
||||
<|user|>
|
||||
I'd like to show off how chat templating works!</s>
|
||||
<|assistant|>
|
||||
@@ -1,3 +1,14 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}{{ .System }}
|
||||
|
||||
{{ end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}USER: {{ .Content }}
|
||||
{{ else if eq .Role "assistant" }}ASSISTANT: {{ .Content }}</s>
|
||||
{{ end }}
|
||||
{{- end }}ASSISTANT:
|
||||
{{- else }}
|
||||
{{ if .System }}{{ .System }}
|
||||
{{ end }}{{ if .Prompt }}USER: {{ .Prompt }}
|
||||
{{ end }}ASSISTANT: {{ .Response }}
|
||||
{{ end }}ASSISTANT: {{ .Response }}
|
||||
{{- end }}
|
||||
@@ -1,6 +1,15 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}<|system|>
|
||||
{{ .System }}</s>
|
||||
{{ end }}
|
||||
{{- range .Messages }}<|{{ .Role }}|>
|
||||
{{ .Content }}</s>
|
||||
{{ end }}<|assistant|>
|
||||
{{ else }}
|
||||
{{ if .System }}<|system|>
|
||||
{{ .System }}</s>
|
||||
{{ end }}{{ if .Prompt }}<|user|>
|
||||
{{ .Prompt }}</s>
|
||||
{{ end }}<|assistant|>
|
||||
{{ .Response }}</s>
|
||||
{{ .Response }}</s>
|
||||
{{- end }}
|
||||
Reference in New Issue
Block a user