Compare commits

...

53 Commits

Author SHA1 Message Date
Josh Yan
b6c7d01af3 more cmt rmv 2024-07-11 17:21:36 -07:00
Josh Yan
9d517cf556 rm comment 2024-07-11 17:20:09 -07:00
Josh Yan
6bab0e2368 lint 2024-07-10 12:36:32 -07:00
Josh Yan
c4cccaf936 remove rebase err 2024-07-10 11:37:55 -07:00
Josh Yan
9fe5c393e4 hi 2024-07-10 11:35:01 -07:00
Josh Yan
007c988dba rmv double msg 2024-07-10 11:34:31 -07:00
Josh Yan
91d21e7c7b rmv double msg 2024-07-10 11:34:31 -07:00
Josh Yan
3e64284f69 percent 2024-07-10 11:34:31 -07:00
Josh Yan
39910f2ab2 percent 2024-07-10 11:34:29 -07:00
Josh Yan
96d0cd92f2 rebase 2024-07-10 11:31:53 -07:00
Josh Yan
3a724a7c80 isLocal firstdraft 2024-07-10 11:31:12 -07:00
Josh Yan
f520f0056e rm config 2024-07-10 11:29:51 -07:00
Josh Yan
d25f85ede4 on disk copy 2024-07-10 11:29:49 -07:00
Josh Yan
b48420b74b percent 2024-07-10 11:27:33 -07:00
Josh Yan
784958a1cb transfer data 2024-07-10 11:24:50 -07:00
Josh Yan
ae65cc8dea progress 2024-07-10 11:24:48 -07:00
Josh Yan
a037528bba lint 2024-07-10 11:20:02 -07:00
Josh Yan
04bf41deb5 clean 2024-07-10 11:20:02 -07:00
Josh Yan
c23cec9547 removed cmt and prints 2024-07-10 11:20:02 -07:00
Josh Yan
8377dc48d0 removed client isLocal() 2024-07-10 11:20:02 -07:00
Josh Yan
3aee405dfa lint 2024-07-10 11:20:02 -07:00
Josh Yan
9b3f47b674 lint 2024-07-10 11:20:02 -07:00
Josh Yan
f5441f01a2 lint 2024-07-10 11:20:02 -07:00
Josh Yan
ab165df43a syscopy windows 2024-07-10 11:20:02 -07:00
Josh Yan
79cc4c9585 os copy 2024-07-10 11:20:02 -07:00
Josh Yan
bc3f59a6ad rmv prints 2024-07-10 11:20:02 -07:00
Josh Yan
1a85cb904c local copy 2024-07-10 11:20:02 -07:00
Josh Yan
10ea0987e9 isLocal firstdraft 2024-07-10 11:19:50 -07:00
Josh Yan
413d368a6a clean 2024-07-10 11:19:32 -07:00
Josh Yan
cabf375059 rm bench 2024-07-10 11:19:32 -07:00
Josh Yan
ca0ee1d4fe rm config 2024-07-10 11:19:32 -07:00
Josh Yan
1142999aab rm config 2024-07-10 11:19:32 -07:00
Josh Yan
0d5a72aba9 clean 2024-07-10 11:19:32 -07:00
Josh Yan
ea837412c2 local path 2024-07-10 11:19:32 -07:00
Josh Yan
736ad6f438 still works 2024-07-10 11:19:32 -07:00
Josh Yan
64607d16a5 working 2024-07-10 11:19:32 -07:00
Josh Yan
a6cfe7f00b benchmark 2024-07-10 11:19:32 -07:00
Josh Yan
c3b411a515 on disk copy 2024-07-10 11:19:32 -07:00
Josh Yan
928f37e3ae start tests 2024-07-10 11:19:32 -07:00
Daniel Hiltgen
2d1e3c3229 Merge pull request #5503 from dhiltgen/dual_rocm
Workaround broken ROCm p2p copy
2024-07-09 15:44:16 -07:00
royjhan
4918fae535 OpenAI v1/completions: allow stop token list (#5551)
* stop token parsing fix

* add stop test
2024-07-09 14:01:26 -07:00
royjhan
0aff67877e separate request tests (#5578) 2024-07-09 13:48:31 -07:00
Daniel Hiltgen
9544a57ee4 Merge pull request #5579 from dhiltgen/win_static_deps
Statically link c++ and thread lib on windows
2024-07-09 12:21:13 -07:00
Daniel Hiltgen
b51e3b63ac Statically link c++ and thread lib
This makes sure we statically link the c++ and thread library on windows
to avoid unnecessary runtime dependencies on non-standard DLLs
2024-07-09 11:34:30 -07:00
Michael Yang
6bbbc50f10 Merge pull request #5440 from ollama/mxyng/messages-templates
update named templates
2024-07-09 09:36:32 -07:00
Michael Yang
9bbddc37a7 Merge pull request #5126 from ollama/mxyng/messages
update message processing
2024-07-09 09:20:44 -07:00
Jeffrey Morgan
e4ff73297d server: fix model reloads when setting OLLAMA_NUM_PARALLEL (#5560)
* server: fix unneeded model reloads when setting `OLLAMA_NUM_PARALLEL`

* remove whitespace change

* undo some changes
2024-07-08 22:32:15 -07:00
Daniel Hiltgen
0bacb30007 Workaround broken ROCm p2p copy
Enable the build flag for llama.cpp to use CPU copy for multi-GPU scenarios.
2024-07-08 09:40:52 -07:00
Michael Yang
fb6cbc02fb update named templates 2024-07-05 16:29:32 -07:00
Michael Yang
326363b3a7 no funcs 2024-07-05 13:17:25 -07:00
Michael Yang
ac7a842e55 fix model reloading
ensure runtime model changes (template, system prompt, messages,
options) are captured on model updates without needing to reload the
server
2024-07-05 13:17:25 -07:00
Michael Yang
2c3fe1fd97 comments 2024-07-05 13:17:24 -07:00
Michael Yang
269ed6e6a2 update message processing 2024-07-05 13:16:58 -07:00
95 changed files with 1806 additions and 953 deletions

View File

@@ -304,11 +304,6 @@ jobs:
write-host "Installing plugin"
& "${env:RUNNER_TEMP}\plugin\*\kmscng.msi" /quiet
write-host "plugin installed"
- name: remove unwanted mingw dll.a files
run: |
Get-ChildItem -Path "C:\mingw64" -Recurse -Filter "libpthread.dll.a" -File | Remove-Item -Force
Get-ChildItem -Path "C:\mingw64" -Recurse -Filter "libwinpthread.dll.a" -File | Remove-Item -Force
Get-ChildItem -Path "C:\mingw64" -Recurse -Filter "libstdc++.dll.a" -File | Remove-Item -Force
- uses: actions/setup-go@v5
with:
go-version-file: go.mod

View File

@@ -17,14 +17,20 @@ import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/version"
@@ -374,3 +380,27 @@ func (c *Client) Version(ctx context.Context) (string, error) {
return version.Version, nil
}
func Authorization(ctx context.Context, request *http.Request) (string, error) {
data := []byte(fmt.Sprintf("%s,%s,%d", request.Method, request.URL.RequestURI(), time.Now().Unix()))
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
knownHostsFile, err := os.OpenFile(filepath.Join(home, ".ollama", "known_hosts"), os.O_CREATE|os.O_RDWR|os.O_APPEND, 0600)
if err != nil {
return "", err
}
defer knownHostsFile.Close()
token, err := auth.Sign(ctx, data)
if err != nil {
return "", err
}
// interleave request data into the token
key, sig, _ := strings.Cut(token, ":")
return fmt.Sprintf("%s:%s:%s", key, base64.StdEncoding.EncodeToString(data), sig), nil
}

View File

@@ -10,42 +10,37 @@ import (
"log/slog"
"os"
"path/filepath"
"strings"
"golang.org/x/crypto/ssh"
)
const defaultPrivateKey = "id_ed25519"
func keyPath() (string, error) {
func keyPath() (ssh.Signer, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
}
func GetPublicKey() (string, error) {
keyPath, err := keyPath()
if err != nil {
return "", err
return nil, err
}
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
privateKeyFile, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return "", err
return nil, err
}
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
return ssh.ParsePrivateKey(privateKeyFile)
}
func GetPublicKey() (ssh.PublicKey, error) {
privateKey, err := keyPath()
// if privateKey, try public key directly
if err != nil {
return "", err
return nil, err
}
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
return strings.TrimSpace(string(publicKey)), nil
return privateKey.PublicKey(), nil
}
func NewNonce(r io.Reader, length int) (string, error) {
@@ -58,25 +53,20 @@ func NewNonce(r io.Reader, length int) (string, error) {
}
func Sign(ctx context.Context, bts []byte) (string, error) {
keyPath, err := keyPath()
if err != nil {
return "", err
}
privateKeyFile, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return "", err
}
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
privateKey, err := keyPath()
if err != nil {
return "", err
}
// get the pubkey, but remove the type
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
parts := bytes.Split(publicKey, []byte(" "))
publicKey, err := GetPublicKey()
if err != nil {
return "", err
}
publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
parts := bytes.Split(publicKeyBytes, []byte(" "))
if len(parts) < 2 {
return "", fmt.Errorf("malformed public key")
}

View File

@@ -7,6 +7,7 @@ import (
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
@@ -15,6 +16,7 @@ import (
"math"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"path/filepath"
@@ -78,6 +80,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
status := "transferring model data"
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
defer p.Stop()
for i := range modelfile.Commands {
switch modelfile.Commands[i].Name {
@@ -112,11 +115,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
path = tempfile
}
digest, err := createBlob(cmd, client, path)
digest, err := createBlob(cmd, client, path, spinner)
if err != nil {
return err
}
modelfile.Commands[i].Args = "@" + digest
}
}
@@ -138,7 +140,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
spinner.Stop()
status = resp.Status
spinner = progress.NewSpinner(status)
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
}
@@ -263,13 +265,22 @@ func tempZipFiles(path string) (string, error) {
return tempfile.Name(), nil
}
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
var ErrBlobExists = errors.New("blob exists")
func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) {
bin, err := os.Open(path)
if err != nil {
return "", err
}
defer bin.Close()
// Get file info to retrieve the size
fileInfo, err := bin.Stat()
if err != nil {
return "", err
}
fileSize := fileInfo.Size()
hash := sha256.New()
if _, err := io.Copy(hash, bin); err != nil {
return "", err
@@ -279,13 +290,151 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
return "", err
}
var pw progressWriter
status := "transferring model data 0%"
spinner.SetMessage(status)
ticker := time.NewTicker(60 * time.Millisecond)
done := make(chan struct{})
defer close(done)
go func() {
defer ticker.Stop()
for {
select {
case <-ticker.C:
spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n/fileSize)))
case <-done:
spinner.SetMessage("transferring model data 100%")
return
}
}
}()
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
// We check if we can find the models directory locally
// If we can, we return the path to the directory
// If we can't, we return an error
// If the blob exists already, we return the digest
dest, err := getLocalPath(cmd.Context(), digest)
if errors.Is(err, ErrBlobExists) {
return digest, nil
}
// Successfully found the model directory
if err == nil {
// Copy blob in via OS specific copy
// Linux errors out to use io.copy
err = localCopy(path, dest)
if err == nil {
return digest, nil
}
// Default copy using io.copy
err = defaultCopy(path, dest)
if err == nil {
return digest, nil
}
}
// If at any point copying the blob over locally fails, we default to the copy through the server
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
return "", err
}
return digest, nil
}
type progressWriter struct {
n int64
}
func (w *progressWriter) Write(p []byte) (n int, err error) {
w.n += int64(len(p))
return len(p), nil
}
func getLocalPath(ctx context.Context, digest string) (string, error) {
ollamaHost := envconfig.Host
client := http.DefaultClient
base := &url.URL{
Scheme: ollamaHost.Scheme,
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
}
data, err := json.Marshal(digest)
if err != nil {
return "", err
}
reqBody := bytes.NewReader(data)
path := fmt.Sprintf("/api/blobs/%s", digest)
requestURL := base.JoinPath(path)
request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), reqBody)
if err != nil {
return "", err
}
authz, err := api.Authorization(ctx, request)
if err != nil {
return "", err
}
request.Header.Set("Authorization", authz)
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
request.Header.Set("X-Redirect-Create", "1")
resp, err := client.Do(request)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusTemporaryRedirect {
dest := resp.Header.Get("LocalLocation")
return dest, nil
}
return "", ErrBlobExists
}
func defaultCopy(path string, dest string) error {
// This function should be called if the server is local
// It should find the model directory, copy the blob over, and return the digest
dirPath := filepath.Dir(dest)
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return err
}
// Copy blob over
sourceFile, err := os.Open(path)
if err != nil {
return fmt.Errorf("could not open source file: %v", err)
}
defer sourceFile.Close()
destFile, err := os.Create(dest)
if err != nil {
return fmt.Errorf("could not create destination file: %v", err)
}
defer destFile.Close()
_, err = io.CopyBuffer(destFile, sourceFile, make([]byte, 4*1024*1024))
if err != nil {
return fmt.Errorf("error copying file: %v", err)
}
err = destFile.Sync()
if err != nil {
return fmt.Errorf("error flushing file: %v", err)
}
return nil
}
func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true
@@ -379,11 +528,13 @@ func errFromUnknownKey(unknownKeyErr error) error {
if len(matches) > 0 {
serverPubKey := matches[0]
localPubKey, err := auth.GetPublicKey()
publicKey, err := auth.GetPublicKey()
if err != nil {
return unknownKeyErr
}
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(publicKey)))
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
// try the ollama service public key
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")

23
cmd/copy_darwin.go Normal file
View File

@@ -0,0 +1,23 @@
package cmd
import (
"os"
"path/filepath"
"golang.org/x/sys/unix"
)
func localCopy(src, target string) error {
dirPath := filepath.Dir(target)
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return err
}
err := unix.Clonefile(src, target, 0)
if err != nil {
return err
}
return nil
}

7
cmd/copy_linux.go Normal file
View File

@@ -0,0 +1,7 @@
package cmd
import "errors"
func localCopy(src, target string) error {
return errors.New("no local copy implementation for linux")
}

67
cmd/copy_windows.go Normal file
View File

@@ -0,0 +1,67 @@
//go:build windows
// +build windows
package cmd
import (
"os"
"path/filepath"
"syscall"
"unsafe"
)
func localCopy(src, target string) error {
// Create target directory if it doesn't exist
dirPath := filepath.Dir(target)
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return err
}
// Open source file
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
// Create target file
targetFile, err := os.Create(target)
if err != nil {
return err
}
defer targetFile.Close()
// Use CopyFileExW to copy the file
err = copyFileEx(src, target)
if err != nil {
return err
}
return nil
}
func copyFileEx(src, dst string) error {
kernel32 := syscall.NewLazyDLL("kernel32.dll")
copyFileEx := kernel32.NewProc("CopyFileExW")
srcPtr, err := syscall.UTF16PtrFromString(src)
if err != nil {
return err
}
dstPtr, err := syscall.UTF16PtrFromString(dst)
if err != nil {
return err
}
r1, _, err := copyFileEx.Call(
uintptr(unsafe.Pointer(srcPtr)),
uintptr(unsafe.Pointer(dstPtr)),
0, 0, 0, 0)
if r1 == 0 {
return err
}
return nil
}

3
go.mod
View File

@@ -18,6 +18,7 @@ require (
require (
github.com/agnivade/levenshtein v1.1.1
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/google/go-cmp v0.6.0
github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
@@ -71,7 +72,7 @@ require (
golang.org/x/net v0.25.0 // indirect
golang.org/x/sys v0.20.0
golang.org/x/term v0.20.0
golang.org/x/text v0.15.0 // indirect
golang.org/x/text v0.15.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -254,7 +254,7 @@ if [ -z "${OLLAMA_SKIP_ROCM_GENERATE}" -a -d "${ROCM_PATH}" ]; then
ROCM_VARIANT=_v$(ls ${ROCM_PATH}/lib/librocblas.so.*.*.????? | cut -f5 -d. || true)
fi
init_vars
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DLLAMA_CUDA_NO_PEER_COPY=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then
echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\""

View File

@@ -366,6 +366,7 @@ function build_rocm() {
"-DCMAKE_C_COMPILER=clang.exe",
"-DCMAKE_CXX_COMPILER=clang++.exe",
"-DGGML_HIPBLAS=on",
"-DLLAMA_CUDA_NO_PEER_COPY=on",
"-DHIP_PLATFORM=amd",
"-DGGML_AVX=on",
"-DGGML_AVX2=off",

View File

@@ -4,8 +4,8 @@ package llm
// #cgo LDFLAGS: -lllama -lggml -lstdc++ -lpthread
// #cgo darwin,arm64 LDFLAGS: -L${SRCDIR}/build/darwin/arm64_static -L${SRCDIR}/build/darwin/arm64_static/src -L${SRCDIR}/build/darwin/arm64_static/ggml/src -framework Accelerate -framework Metal
// #cgo darwin,amd64 LDFLAGS: -L${SRCDIR}/build/darwin/x86_64_static -L${SRCDIR}/build/darwin/x86_64_static/src -L${SRCDIR}/build/darwin/x86_64_static/ggml/src
// #cgo windows,amd64 LDFLAGS: -L${SRCDIR}/build/windows/amd64_static -L${SRCDIR}/build/windows/amd64_static/src -L${SRCDIR}/build/windows/amd64_static/ggml/src
// #cgo windows,arm64 LDFLAGS: -L${SRCDIR}/build/windows/arm64_static -L${SRCDIR}/build/windows/arm64_static/src -L${SRCDIR}/build/windows/arm64_static/ggml/src
// #cgo windows,amd64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/amd64_static -L${SRCDIR}/build/windows/amd64_static/src -L${SRCDIR}/build/windows/amd64_static/ggml/src
// #cgo windows,arm64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/arm64_static -L${SRCDIR}/build/windows/arm64_static/src -L${SRCDIR}/build/windows/arm64_static/ggml/src
// #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src
// #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
// #include <stdlib.h>

View File

@@ -679,7 +679,7 @@ type CompletionRequest struct {
Prompt string
Format string
Images []ImageData
Options api.Options
Options *api.Options
}
type CompletionResponse struct {

View File

@@ -338,13 +338,17 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
switch stop := r.Stop.(type) {
case string:
options["stop"] = []string{stop}
case []string:
options["stop"] = stop
default:
if r.Stop != nil {
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", r.Stop)
case []any:
var stops []string
for _, s := range stop {
if str, ok := s.(string); ok {
stops = append(stops, str)
} else {
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
}
}
options["stop"] = stops
}
if r.MaxTokens != nil {
options["num_predict"] = *r.MaxTokens

View File

@@ -3,7 +3,6 @@ package openai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
@@ -16,7 +15,133 @@ import (
"github.com/stretchr/testify/assert"
)
func TestMiddleware(t *testing.T) {
func TestMiddlewareRequests(t *testing.T) {
type testCase struct {
Name string
Method string
Path string
Handler func() gin.HandlerFunc
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *http.Request)
}
var capturedRequest *http.Request
captureRequestMiddleware := func() gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
capturedRequest = c.Request
c.Next()
}
}
testCases := []testCase{
{
Name: "chat handler",
Method: http.MethodPost,
Path: "/api/chat",
Handler: ChatMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var chatReq api.ChatRequest
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
t.Fatal(err)
}
if chatReq.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
}
if chatReq.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
}
},
},
{
Name: "completions handler",
Method: http.MethodPost,
Path: "/api/generate",
Handler: CompletionsMiddleware,
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
Stop: []string{"\n", "stop"},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var genReq api.GenerateRequest
if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
t.Fatal(err)
}
if genReq.Prompt != "Hello" {
t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
}
if genReq.Options["temperature"] != 1.6 {
t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
}
stopTokens, ok := genReq.Options["stop"].([]any)
if !ok {
t.Fatalf("expected stop tokens to be a list")
}
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
}
},
},
}
gin.SetMode(gin.TestMode)
router := gin.New()
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
router = gin.New()
router.Use(captureRequestMiddleware())
router.Use(tc.Handler())
router.Handle(tc.Method, tc.Path, endpoint)
req, _ := http.NewRequest(tc.Method, tc.Path, nil)
if tc.Setup != nil {
tc.Setup(t, req)
}
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest)
})
}
}
func TestMiddlewareResponses(t *testing.T) {
type testCase struct {
Name string
Method string
@@ -30,159 +155,7 @@ func TestMiddleware(t *testing.T) {
testCases := []testCase{
{
Name: "chat handler",
Method: http.MethodPost,
Path: "/api/chat",
TestPath: "/api/chat",
Handler: ChatMiddleware,
Endpoint: func(c *gin.Context) {
var chatReq api.ChatRequest
if err := c.ShouldBindJSON(&chatReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
userMessage := chatReq.Messages[0].Content
var assistantMessage string
switch userMessage {
case "Hello":
assistantMessage = "Hello!"
default:
assistantMessage = "I'm not sure how to respond to that."
}
c.JSON(http.StatusOK, api.ChatResponse{
Message: api.Message{
Role: "assistant",
Content: assistantMessage,
},
})
},
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var chatResp ChatCompletion
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
t.Fatal(err)
}
if chatResp.Object != "chat.completion" {
t.Fatalf("expected chat.completion, got %s", chatResp.Object)
}
if chatResp.Choices[0].Message.Content != "Hello!" {
t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
}
},
},
{
Name: "completions handler",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.GenerateResponse{
Response: "Hello!",
})
},
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var completionResp Completion
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
t.Fatal(err)
}
if completionResp.Object != "text_completion" {
t.Fatalf("expected text_completion, got %s", completionResp.Object)
}
if completionResp.Choices[0].Text != "Hello!" {
t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text)
}
},
},
{
Name: "completions handler with params",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
var generateReq api.GenerateRequest
if err := c.ShouldBindJSON(&generateReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
temperature := generateReq.Options["temperature"].(float64)
var assistantMessage string
switch temperature {
case 1.6:
assistantMessage = "Received temperature of 1.6"
default:
assistantMessage = fmt.Sprintf("Received temperature of %f", temperature)
}
c.JSON(http.StatusOK, api.GenerateResponse{
Response: assistantMessage,
})
},
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var completionResp Completion
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
t.Fatal(err)
}
if completionResp.Object != "text_completion" {
t.Fatalf("expected text_completion, got %s", completionResp.Object)
}
if completionResp.Choices[0].Text != "Received temperature of 1.6" {
t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text)
}
},
},
{
Name: "completions handler with error",
Name: "completions handler error forwarding",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",

View File

@@ -31,6 +31,10 @@ func NewSpinner(message string) *Spinner {
return s
}
func (s *Spinner) SetMessage(message string) {
s.message = message
}
func (s *Spinner) String() string {
var sb strings.Builder
if len(s.message) > 0 {

View File

@@ -32,8 +32,11 @@ import (
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"golang.org/x/crypto/ssh"
)
var errCapabilityCompletion = errors.New("completion")
type Capability string
const CapabilityCompletion = Capability("completion")
@@ -62,7 +65,10 @@ type Model struct {
Template *template.Template
}
func (m *Model) Has(caps ...Capability) bool {
// CheckCapabilities checks if the model has the specified capabilities returning an error describing
// any missing or unknown capabilities
func (m *Model) CheckCapabilities(caps ...Capability) error {
var errs []error
for _, cap := range caps {
switch cap {
case CapabilityCompletion:
@@ -81,15 +87,19 @@ func (m *Model) Has(caps ...Capability) bool {
}
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
return false
errs = append(errs, errCapabilityCompletion)
}
default:
slog.Error("unknown capability", "capability", cap)
return false
return fmt.Errorf("unknown capability: %s", cap)
}
}
return true
if err := errors.Join(errs...); err != nil {
return fmt.Errorf("missing capabilities: %w", errors.Join(errs...))
}
return nil
}
func (m *Model) String() string {
@@ -1055,11 +1065,12 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
if anonymous {
// no user is associated with the public key, and the request requires non-anonymous access
pubKey, nestedErr := auth.GetPublicKey()
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubKey)))
if nestedErr != nil {
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
return nil, errUnauthorized
}
return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
return nil, &errtypes.UnknownOllamaKey{Key: localPubKey}
}
// user is associated with the public key, but is not authorized to make the request
return nil, errUnauthorized

View File

@@ -1,217 +1,83 @@
package server
import (
"fmt"
"bytes"
"context"
"log/slog"
"strings"
"text/template/parse"
"slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/template"
)
// isResponseNode checks if the node contains .Response
func isResponseNode(node *parse.ActionNode) bool {
for _, cmd := range node.Pipe.Cmds {
for _, arg := range cmd.Args {
if fieldNode, ok := arg.(*parse.FieldNode); ok && len(fieldNode.Ident) > 0 {
if fieldNode.Ident[0] == "Response" {
type tokenizeFunc func(context.Context, string) ([]int, error)
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
// pull out any system messages which should always be included in the prompt
var system []api.Message
msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
if m.Role == "system" {
system = append(system, m)
return true
}
}
}
}
return false
}
})
// formatTemplateForResponse formats the template AST to:
// 1. remove all nodes after the first .Response (if generate=true)
// 2. add a .Response node to the end if it doesn't exist
// TODO(jmorganca): this should recursively cut the template before the first .Response
func formatTemplateForResponse(tmpl *template.Template, generate bool) {
var found bool
for i, node := range tmpl.Tree.Root.Nodes {
if actionNode, ok := node.(*parse.ActionNode); ok {
if isResponseNode(actionNode) {
found = true
if generate {
tmpl.Tree.Root.Nodes = tmpl.Tree.Root.Nodes[:i+1]
if len(system) == 0 && m.System != "" {
// add model system prompt since it wasn't provided
system = append(system, api.Message{Role: "system", Content: m.System})
}
// always include the last message
n := len(msgs) - 1
// in reverse, find all messages that fit into context window
for i := n - 1; i >= 0; i-- {
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
return "", nil, err
}
s, err := tokenize(ctx, b.String())
if err != nil {
return "", nil, err
}
c := len(s)
if m.ProjectorPaths != nil {
for _, m := range msgs[i:] {
// images are represented as 768 sized embeddings
// TODO: get embedding length from project metadata
c += 768 * len(m.Images)
}
}
if c > opts.NumCtx {
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
break
}
}
} else {
n = i
}
}
if !found {
// add the response node if it doesn't exist
responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}}
responsePipeNode := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}}}}
responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode}
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode)
// truncate any messages that do not fit into the context window
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
return "", nil, err
}
}
// Prompt renders a prompt from a template. If generate is set to true,
// the response and parts of the template following it are not rendered
func Prompt(tmpl *template.Template, system, prompt, response string, generate bool) (string, error) {
formatTemplateForResponse(tmpl, generate)
vars := map[string]any{
"System": system,
"Prompt": prompt,
"Response": response,
}
var sb strings.Builder
if err := tmpl.Execute(&sb, vars); err != nil {
return "", err
}
return sb.String(), nil
}
func countTokens(tmpl *template.Template, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
rendered, err := Prompt(tmpl, system, prompt, response, false)
if err != nil {
return 0, err
}
tokens, err := encode(rendered)
if err != nil {
slog.Error("failed to encode prompt", "err", err)
return 0, err
}
return len(tokens), err
}
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
func ChatPrompt(tmpl *template.Template, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
type prompt struct {
System string
Prompt string
Response string
images []int
tokens int
}
var p prompt
// iterate through messages to build up {system,user,response} prompts
var imgId int
var prompts []prompt
for _, msg := range messages {
switch strings.ToLower(msg.Role) {
case "system":
if p.System != "" || p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
p.System = msg.Content
case "user":
if p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
var sb strings.Builder
for range msg.Images {
fmt.Fprintf(&sb, "[img-%d] ", imgId)
p.images = append(p.images, imgId)
imgId += 1
}
sb.WriteString(msg.Content)
p.Prompt = sb.String()
case "assistant":
if p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
p.Response = msg.Content
default:
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
}
}
// add final prompt
if p.System != "" || p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
}
// calculate token lengths for each prompt, estimating 768 tokens per images
for i, p := range prompts {
tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
if err != nil {
return "", err
}
prompts[i].tokens = tokens + len(prompts[i].images)*768
}
// truncate images and prompts starting from the beginning of the list
// until either one prompt remains or the total tokens fits the context window
// TODO (jmorganca): this doesn't account for the context window room required for the response
for {
var required int
for _, p := range prompts {
required += p.tokens
}
required += 1 // for bos token
if required <= window {
slog.Debug("prompt now fits in context window", "required", required, "window", window)
break
}
prompt := &prompts[0]
if len(prompt.images) > 1 {
img := prompt.images[0]
slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window)
prompt.images = prompt.images[1:]
prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1)
prompt.tokens -= 768
continue
}
if len(prompts) > 1 {
slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window)
system := prompt.System
prompts = prompts[1:]
if system != "" && prompts[0].System == "" {
prompts[0].System = system
tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
if err != nil {
return "", err
}
prompts[0].tokens = tokens + len(prompts[0].images)*768
}
continue
}
// stop truncating if there's only one prompt left
break
}
var sb strings.Builder
for i, p := range prompts {
// last prompt should leave the response unrendered (for completion)
rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
if err != nil {
return "", err
}
sb.WriteString(rendered)
}
return sb.String(), nil
for _, m := range msgs[n:] {
for _, i := range m.Images {
images = append(images, llm.ImageData{
ID: len(images),
Data: i,
})
}
}
return b.String(), images, nil
}

View File

@@ -1,6 +1,8 @@
package server
import (
"bytes"
"context"
"strings"
"testing"
@@ -8,208 +10,195 @@ import (
"github.com/ollama/ollama/template"
)
func TestPrompt(t *testing.T) {
tests := []struct {
name string
template string
system string
prompt string
response string
generate bool
want string
}{
{
name: "simple prompt",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
},
{
name: "implicit response",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
},
{
name: "response",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
},
{
name: "cut",
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
generate: true,
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.",
},
{
name: "nocut",
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.</assistant>",
},
func tokenize(_ context.Context, s string) (tokens []int, err error) {
for range strings.Fields(s) {
tokens = append(tokens, len(tokens))
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.Parse(tc.template)
if err != nil {
t.Fatal(err)
}
got, err := Prompt(tmpl, tc.system, tc.prompt, tc.response, tc.generate)
if err != nil {
t.Errorf("error = %v", err)
}
if got != tc.want {
t.Errorf("got = %v, want %v", got, tc.want)
}
})
}
return
}
func TestChatPrompt(t *testing.T) {
tests := []struct {
type expect struct {
prompt string
images [][]byte
}
cases := []struct {
name string
template string
messages []api.Message
window int
want string
limit int
msgs []api.Message
expect
}{
{
name: "simple prompt",
template: "[INST] {{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "user", Content: "Hello"},
name: "messages",
limit: 64,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
},
window: 1024,
want: "[INST] Hello [/INST]",
},
{
name: "with system message",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
name: "truncate messages",
limit: 1,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "A test. And a thumping good one at that, I'd wager. ",
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
},
{
name: "with response",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "I am?"},
name: "truncate messages with image",
limit: 64,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
},
expect: expect{
prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST] I am?",
},
{
name: "with implicit response",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "I am?"},
name: "truncate messages with images",
limit: 64,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
},
expect: expect{
prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("somethingelse"),
},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]I am?",
},
{
name: "with conversation",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "What are the potion ingredients?"},
{Role: "assistant", Content: "sugar"},
{Role: "user", Content: "Anything else?"},
name: "messages with images",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
},
expect: expect{
prompt: "[img-0] You're a test, Harry! I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
[]byte("somethingelse"),
},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
},
{
name: "with truncation",
template: "{{ .System }} {{ .Prompt }} {{ .Response }} ",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "I am?"},
{Role: "user", Content: "Why is the sky blue?"},
{Role: "assistant", Content: "The sky is blue from rayleigh scattering"},
name: "message with image tag",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
},
expect: expect{
prompt: "You're a test, Harry! [img-0] I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
[]byte("somethingelse"),
},
},
window: 10,
want: "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
},
{
name: "images",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
name: "messages with interleaved images",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "user", Images: []api.ImageData{[]byte("something")}},
{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry!\n\n[img-0]\n\n[img-1] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
[]byte("somethingelse"),
},
},
window: 1024,
want: "You are a Wizard. [img-0] Hello",
},
{
name: "images truncated",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
name: "truncate message with interleaved images",
limit: 1024,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "user", Images: []api.ImageData{[]byte("something")}},
{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("somethingelse"),
},
},
window: 1024,
want: "You are a Wizard. [img-0] [img-1] Hello",
},
{
name: "empty list",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{},
window: 1024,
want: "",
name: "message with system prompt",
limit: 2048,
msgs: []api.Message{
{Role: "system", Content: "You are the Test Who Lived."},
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
{
name: "empty prompt",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
messages: []api.Message{
{Role: "user", Content: ""},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
},
window: 1024,
want: "",
},
}
encode := func(s string) ([]int, error) {
words := strings.Fields(s)
return make([]int, len(words)), nil
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.Parse(tc.template)
tmpl, err := template.Parse(`
{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}{{ .Prompt }} {{ end }}
{{- if .Response }}{{ .Response }} {{ end }}`)
if err != nil {
t.Fatal(err)
}
got, err := ChatPrompt(tmpl, tc.messages, tc.window, encode)
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs)
if err != nil {
t.Errorf("error = %v", err)
t.Fatal(err)
}
if got != tc.want {
t.Errorf("got: %q, want: %q", got, tc.want)
if tt.prompt != prompt {
t.Errorf("expected %q, got %q", tt.prompt, prompt)
}
if len(images) != len(tt.images) {
t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
}
for i := range images {
if images[i].ID != i {
t.Errorf("expected ID %d, got %d", i, images[i].ID)
}
if !bytes.Equal(images[i].Data, tt.images[i]) {
t.Errorf("expected %q, got %q", tt.images[i], images[i])
}
}
})
}

View File

@@ -1,13 +1,15 @@
package server
import (
"bytes"
"cmp"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log"
"log/slog"
"net"
"net/http"
@@ -22,8 +24,10 @@ import (
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"golang.org/x/crypto/ssh"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
@@ -54,6 +58,8 @@ func init() {
gin.SetMode(mode)
}
var errRequired = errors.New("is required")
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
@@ -67,163 +73,140 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
return opts, nil
}
func isSupportedImageType(image []byte) bool {
contentType := http.DetectContentType(image)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
return slices.Contains(allowedTypes, contentType)
// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
if name == "" {
return nil, nil, nil, fmt.Errorf("model %w", errRequired)
}
model, err := GetModel(name)
if err != nil {
return nil, nil, nil, err
}
if err := model.CheckCapabilities(caps...); err != nil {
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
}
opts, err := modelOptions(model, requestOpts)
if err != nil {
return nil, nil, nil, err
}
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err = <-errCh:
return nil, nil, nil, err
}
return runner.llama, model, &opts, nil
}
func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// validate the request
switch {
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
if req.Format != "" && req.Format != "json" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
return
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
} else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
return
}
for _, img := range req.Images {
if !isSupportedImageType(img) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
caps := []Capability{CapabilityCompletion}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
return
}
}
model, err := GetModel(req.Model)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
} else if err != nil {
handleScheduleError(c, req.Model, err)
return
}
if !model.Has(CapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)})
return
}
opts, err := modelOptions(model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
handleErrorResponse(c, err)
return
}
// an empty request loads the model
// note: for a short while template was used in lieu
// of `raw` mode so we need to check for it too
if req.Prompt == "" && req.Template == "" && req.System == "" {
if req.Prompt == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: true,
DoneReason: "load",
})
return
}
tmpl, err := template.Parse(req.Template)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
checkpointLoaded := time.Now()
var prompt string
switch {
case req.Raw:
prompt = req.Prompt
case req.Prompt != "":
if req.Template == "" {
tmpl = model.Template
}
if req.System == "" {
req.System = model.System
}
slog.Debug("generate handler", "prompt", req.Prompt)
slog.Debug("generate handler", "template", req.Template)
slog.Debug("generate handler", "system", req.System)
var sb strings.Builder
images := make([]llm.ImageData, len(req.Images))
for i := range req.Images {
fmt.Fprintf(&sb, "[img-%d] ", i)
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
}
sb.WriteString(req.Prompt)
prompt := req.Prompt
if !req.Raw {
var msgs []api.Message
if req.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
} else if m.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
}
p, err := Prompt(tmpl, req.System, sb.String(), "", true)
for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
}
msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
tmpl := m.Template
if req.Template != "" {
tmpl, err = template.Parse(req.Template)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
sb.Reset()
var b bytes.Buffer
if req.Context != nil {
prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
s, err := r.Detokenize(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sb.WriteString(prev)
b.WriteString(s)
}
sb.WriteString(p)
prompt = sb.String()
if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
slog.Debug("generate handler", "prompt", prompt)
prompt = b.String()
}
slog.Debug("generate request", "prompt", prompt, "images", images)
ch := make(chan any)
var generated strings.Builder
go func() {
defer close(ch)
fn := func(r llm.CompletionResponse) {
// Build up the full response
if _, err := generated.WriteString(r.Content); err != nil {
ch <- gin.H{"error": err.Error()}
return
}
resp := api.GenerateResponse{
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
}, func(r llm.CompletionResponse) {
ch <- api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: r.Done,
Response: r.Content,
Done: r.Done,
DoneReason: r.DoneReason,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
@@ -232,77 +215,35 @@ func (s *Server) GenerateHandler(c *gin.Context) {
EvalDuration: r.EvalDuration,
},
}
if r.Done {
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
p, err := Prompt(tmpl, req.System, req.Prompt, generated.String(), false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// TODO (jmorganca): encode() should not strip special tokens
tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
resp.Context = append(req.Context, tokens...)
}
}
ch <- resp
}
var images []llm.ImageData
for i := range req.Images {
images = append(images, llm.ImageData{
ID: i,
Data: req.Images[i],
})
}
// Start prediction
req := llm.CompletionRequest{
Prompt: prompt,
Format: req.Format,
Images: images,
Options: opts,
}
if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
}); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
// Accumulate responses into the final response
var final api.GenerateResponse
var r api.GenerateResponse
var sb strings.Builder
for resp := range ch {
switch r := resp.(type) {
for rr := range ch {
switch t := rr.(type) {
case api.GenerateResponse:
sb.WriteString(r.Response)
final = r
sb.WriteString(t.Response)
r = t
case gin.H:
if errorMsg, ok := r["error"].(string); ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
return
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
return
msg, ok := t["error"].(string)
if !ok {
msg = "unexpected error format in response"
}
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
return
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
return
}
}
final.Response = sb.String()
c.JSON(http.StatusOK, final)
r.Response = sb.String()
c.JSON(http.StatusOK, r)
return
}
@@ -311,44 +252,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
model, err := GetModel(req.Model)
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
opts, err := modelOptions(model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
handleErrorResponse(c, err)
handleScheduleError(c, req.Model, err)
return
}
@@ -358,17 +272,14 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
resp := api.EmbeddingResponse{
Embedding: embedding,
}
c.JSON(http.StatusOK, resp)
c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
}
func (s *Server) PullModelHandler(c *gin.Context) {
@@ -649,9 +560,9 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
}
msgs := make([]api.Message, 0)
for _, msg := range m.Messages {
msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
msgs := make([]api.Message, len(m.Messages))
for i, msg := range m.Messages {
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
}
n := model.ParseName(req.Model)
@@ -863,7 +774,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
_, err = os.Stat(path)
switch {
case errors.Is(err, os.ErrNotExist):
@@ -876,6 +786,12 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
return
}
if c.GetHeader("X-Redirect-Create") == "1" && s.IsLocal(c) {
c.Header("LocalLocation", path)
c.Status(http.StatusTemporaryRedirect)
return
}
layer, err := NewLayer(c.Request.Body, "")
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -890,6 +806,54 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
c.Status(http.StatusCreated)
}
func (s *Server) IsLocal(c *gin.Context) bool {
if authz := c.GetHeader("Authorization"); authz != "" {
parts := strings.Split(authz, ":")
if len(parts) != 3 {
return false
}
clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
if err != nil {
return false
}
// partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
requestData, err := base64.StdEncoding.DecodeString(parts[1])
if err != nil {
return false
}
partialRequestDataParts := strings.Split(string(requestData), ",")
if len(partialRequestDataParts) != 3 {
return false
}
signature, err := base64.StdEncoding.DecodeString(parts[2])
if err != nil {
return false
}
if err := clientPublicKey.Verify(requestData, &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
return false
}
serverPublicKey, err := auth.GetPublicKey()
if err != nil {
log.Fatal(err)
}
if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
return true
}
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return false
}
return false
}
func isLocalIP(ip netip.Addr) bool {
if interfaces, err := net.Interfaces(); err == nil {
for _, iface := range interfaces {
@@ -1214,132 +1178,55 @@ func (s *Server) ProcessHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
}
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
func chatPrompt(ctx context.Context, runner *runnerRef, template *template.Template, messages []api.Message, numCtx int) (string, error) {
encode := func(s string) ([]int, error) {
return runner.llama.Tokenize(ctx, s)
}
prompt, err := ChatPrompt(template, messages, numCtx, encode)
if err != nil {
return "", err
}
return prompt, nil
}
func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.ChatRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// validate the request
switch {
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
caps := []Capability{CapabilityCompletion}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
return
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
} else if err != nil {
handleScheduleError(c, req.Model, err)
return
}
model, err := GetModel(req.Model)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if !model.Has(CapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)})
return
}
opts, err := modelOptions(model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
handleErrorResponse(c, err)
return
}
checkpointLoaded := time.Now()
// if the first message is not a system message, then add the model's default system message
if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
req.Messages = append([]api.Message{
{
Role: "system",
Content: model.System,
},
}, req.Messages...)
}
prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// an empty request loads the model
if len(req.Messages) == 0 || prompt == "" {
resp := api.ChatResponse{
CreatedAt: time.Now().UTC(),
if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant"},
Done: true,
DoneReason: "load",
Message: api.Message{Role: "assistant"},
}
c.JSON(http.StatusOK, resp)
})
return
}
// only send images that are in the prompt
var i int
var images []llm.ImageData
for _, m := range req.Messages {
for _, img := range m.Images {
if !isSupportedImageType(img) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
images = append(images, llm.ImageData{Data: img, ID: i})
}
i += 1
}
}
slog.Debug("chat handler", "prompt", prompt, "images", len(images))
slog.Debug("chat request", "images", len(images), "prompt", prompt)
ch := make(chan any)
go func() {
defer close(ch)
fn := func(r llm.CompletionResponse) {
resp := api.ChatResponse{
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
}, func(r llm.CompletionResponse) {
ch <- api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
@@ -1352,64 +1239,52 @@ func (s *Server) ChatHandler(c *gin.Context) {
EvalDuration: r.EvalDuration,
},
}
if r.Done {
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
ch <- resp
}
if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Format: req.Format,
Images: images,
Options: opts,
}, fn); err != nil {
}); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
// Accumulate responses into the final response
var final api.ChatResponse
var r api.ChatResponse
var sb strings.Builder
for resp := range ch {
switch r := resp.(type) {
for rr := range ch {
switch t := rr.(type) {
case api.ChatResponse:
sb.WriteString(r.Message.Content)
final = r
sb.WriteString(t.Message.Content)
r = t
case gin.H:
if errorMsg, ok := r["error"].(string); ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
return
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
return
msg, ok := t["error"].(string)
if !ok {
msg = "unexpected error format in response"
}
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
return
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
return
}
}
final.Message = api.Message{Role: "assistant", Content: sb.String()}
c.JSON(http.StatusOK, final)
r.Message.Content = sb.String()
c.JSON(http.StatusOK, r)
return
}
streamResponse(c, ch)
}
func handleErrorResponse(c *gin.Context, err error) {
if errors.Is(err, context.Canceled) {
func handleScheduleError(c *gin.Context, name string, err error) {
switch {
case errors.Is(err, errRequired):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
case errors.Is(err, context.Canceled):
c.JSON(499, gin.H{"error": "request canceled"})
return
}
if errors.Is(err, ErrMaxQueue) {
case errors.Is(err, ErrMaxQueue):
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
return
}
case errors.Is(err, os.ErrNotExist):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
}

View File

@@ -545,9 +545,9 @@ func TestCreateDetectTemplate(t *testing.T) {
}
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-2f8e594e6f34b1b4d36a246628eeb3365ce442303d656f1fcc69e821722acea0"),
filepath.Join(p, "blobs", "sha256-542b217f179c7825eeb5bca3c77d2b75ed05bafbd3451d9188891a60a85337c6"),
filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
filepath.Join(p, "blobs", "sha256-9512c372dfc7d84d6065b8dd2b601aeed8cc1a78e7a7aa784a42fff37f5524b7"),
filepath.Join(p, "blobs", "sha256-b8b78cb8c6eefd14c06f1af042e6161255bf87bbf2dd14fce57cdac893db8139"),
})
})

View File

@@ -133,10 +133,6 @@ func (s *Scheduler) processPending(ctx context.Context) {
numParallel = 1
slog.Warn("multimodal models don't support parallel requests yet")
}
// Keep NumCtx and numParallel in sync
if numParallel > 1 {
pending.opts.NumCtx = pending.origNumCtx * numParallel
}
for {
cpus := s.getCpuFn()
@@ -234,9 +230,10 @@ func (s *Scheduler) processPending(ctx context.Context) {
// simplifying assumption of defaultParallel when in CPU mode
if numParallel <= 0 {
numParallel = defaultParallel
pending.opts.NumCtx = pending.origNumCtx * numParallel
}
pending.opts.NumCtx = pending.origNumCtx * numParallel
if loadedCount == 0 {
slog.Debug("cpu mode with first model, loading")
s.loadFn(pending, ggml, gpus, numParallel)

View File

@@ -1 +1,8 @@
{{- if .Messages }}
{{- if .System }}<start_system>{{ .System }}<end_message>
{{- end }}
{{- range .Messages }}<start_{{ .Role }}>{{ .Content }}<end_message>
{{- end }}<start_assistant>
{{- else }}
{{ if .System }}<start_system>{{ .System }}<end_message>{{ end }}{{ if .Prompt }}<start_user>{{ .Prompt }}<end_message>{{ end }}<start_assistant>{{ .Response }}<end_message>
{{- end }}

View File

@@ -1,3 +1,14 @@
{{- if .Messages }}
{{- if .System }}{{ .System }}
{{- end }}
{{- range .Messages }}
{{- if eq .Role "user" }}### Instruction:
{{- else if eq .Role "assistant" }}### Response:
{{- end }}
{{ .Content }}
{{ end }}### Response:
{{ else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}### Instruction:
@@ -5,3 +16,4 @@
{{ end }}### Response:
{{ .Response }}
{{- end }}

View File

@@ -1,6 +1,15 @@
{{- if .Messages }}
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ else }}
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
{{- end }}

View File

@@ -1,5 +1,17 @@
{{- if .Messages }}
{{- if .System }}System: {{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}User:
{{- else if eq .Role "assistant" }}Assistant:
{{- end }} {{ .Content }}
{{ end }}Assistant:
{{- else }}
{{ if .System }}System: {{ .System }}
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
{{ end }}Assistant: <|begin_of_text|>{{ .Response }}
{{- end }}

View File

@@ -1,3 +1,13 @@
{{- if .Messages }}
{{- if .System }}Source: system
{{ .System }} <step> {{ end }}
{{- range .Messages }}Source: {{ .Role }}
{{ .Content }} <step> {{ end }}Source: assistant
Destination: user
{{ else }}
{{ if .System }} Source: system
{{ .System }} <step>{{ end }} Source: user
@@ -6,3 +16,4 @@
Destination: user
{{ .Response }}<step>
{{- end }}

View File

@@ -1,3 +1,13 @@
{{- if .Messages }}
{{- if .System }}System: {{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}User:
{{ else if eq .Role "assistant" }}Falcon:
{{ end }}{{ .Content }}
{{ end }}Falcon:
{{ else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
{{ end }}Assistant: {{ .Response }}
{{- end }}

View File

@@ -1,4 +1,16 @@
{{- if .Messages }}
{{- range $index, $_ := .Messages }}<start_of_turn>
{{- if eq .Role "user" }}user
{{- if and $.System (eq $index 0) }}
{{ $.System }}
{{- end }}
{{- else if eq .Role "assistant" }}model
{{- end }}
{{ .Content }}<end_of_turn>
{{ end }}<start_of_turn>model
{{ else }}
<start_of_turn>user
{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn>
<start_of_turn>model
{{ .Response }}<end_of_turn>
{{- end }}

View File

@@ -1,3 +1,16 @@
{{- if .Messages }}
{{- if .System }}System:
{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}Question:
{{- else if eq .Role "assistant" }}Answer:
{{- end }}
{{ .Content }}
{{ end }}Answer:
{{ else }}
{{ if .System }}
System:
{{ .System }}
@@ -7,3 +20,4 @@ System:
{{ end }}Answer:
{{ .Response }}
{{- end }}

View File

@@ -1,3 +1,16 @@
{{- if .Messages }}
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if eq $index 0 }}<<SYS>>
{{- if $.System }}
{{ $.System }}
{{ end }}<</SYS>>
{{ end }}{{ .Content }}
{{- else }} [/INST] {{ .Content }}</s><s>
{{- end }}
{{- end }} [/INST]
{{- else }}
[INST] <<SYS>>{{ .System }}<</SYS>>
{{ .Prompt }} [/INST] {{ .Response }}
{{- end }}

View File

@@ -1,3 +1,14 @@
{{- if .Messages }}
{{- if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>
{{- end }}
{{- range .Messages }}<|start_header_id|>{{ .Role }}<|end_header_id|>
{{ .Content }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|>
{{ else }}
{{ if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
@@ -5,3 +16,4 @@
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
{{ .Response }}<|eot_id|>
{{- end }}

View File

@@ -1,3 +1,15 @@
{{- if .Messages }}
{{- if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}@@ Instruction
{{- else if eq .Role "assistant" }}@@ Response
{{- end }}
{{ .Content }}
{{ end }}@@ Response
{{ else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}@@ Instruction
@@ -5,3 +17,4 @@
{{ end }}@@ Response
{{ .Response }}
{{- end }}

View File

@@ -1,6 +1,9 @@
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
{{- if .Messages }}
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and $.System (eq (len (slice $.Messages $index)) 1) }}{{ $.System }}
{{ end }}{{ .Content }}
{{- else if eq .Role "assistant" }}[/INST] {{ .Content }}</s>
{{- end }}
{{- end }}[/INST]
{{- else }}[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST] {{ .Response }}
{{- end }}

View File

@@ -1 +1,11 @@
{{- if .Messages }}
{{- if .System }}GPT Correct System: {{ .System }}<|end_of_turn|>
{{- end }}
{{- range .Messages }}GPT Correct
{{- if eq .Role "user" }} User:
{{- else if eq .Role "assistant" }} Assistant:
{{- end }} {{ .Content }}<|end_of_turn|>
{{- end }}GPT Correct Assistant:
{{- else }}
{{ .System }}<|end_of_turn|>GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>
{{- end }}

View File

@@ -1,6 +1,15 @@
{{- if .Messages }}
{{- if .System }}<|system|>
{{ .System }}<|end|>
{{ end }}
{{- range .Messages }}<|{{ .Role }}|>
{{ .Content }}<|end|>
{{ end }}<|assistant|>
{{ else }}
{{ if .System }}<|system|>
{{ .System }}<|end|>
{{ end }}{{ if .Prompt }}<|user|>
{{ .Prompt }}<|end|>
{{ end }}<|assistant|>
{{ .Response }}<|end|>
{{- end }}

View File

@@ -1,3 +1,16 @@
{{- if .Messages }}
{{- if .System }}### System:
{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}### User:
{{ .Content }}
{{ else if eq .Role "assistant" }}### Assistant:
{{ .Content }}</s>
{{ end }}
{{ end }}### Assistant:
{{ else }}
{{ if .System }}### System:
{{ .System }}
@@ -6,3 +19,4 @@
{{ end }}### Assistant:
{{ .Response }}
{{- end }}

View File

@@ -1,3 +1,17 @@
{{- if .Messages }}
{{- if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}### Instruction
{{ .Content }}
{{ else if eq .Role "assistant" }}### Response
{{ .Content }}<|endoftext|>
{{ end }}
{{- end }}### Response
{{ else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}### Instruction
@@ -7,3 +21,4 @@
{{ end }}### Response
{{ .Response }}<|endoftext|>
{{- end }}

View File

@@ -5,6 +5,7 @@ import (
"embed"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"slices"
@@ -14,6 +15,7 @@ import (
"text/template/parse"
"github.com/agnivade/levenshtein"
"github.com/ollama/ollama/api"
"golang.org/x/exp/maps"
)
@@ -74,31 +76,60 @@ func Named(s string) (*named, error) {
return nil, errors.New("no matching template found")
}
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
type Template struct {
*template.Template
raw string
}
// response is a template node that can be added to templates that don't already have one
var response = parse.ActionNode{
NodeType: parse.NodeAction,
Pipe: &parse.PipeNode{
NodeType: parse.NodePipe,
Cmds: []*parse.CommandNode{
{
NodeType: parse.NodeCommand,
Args: []parse.Node{
&parse.FieldNode{
NodeType: parse.NodeField,
Ident: []string{"Response"},
},
},
},
},
},
}
func Parse(s string) (*Template, error) {
tmpl := template.New("").Option("missingkey=zero")
tmpl, err := tmpl.Parse(s)
if err != nil {
return nil, err
}
t := Template{Template: tmpl, raw: s}
if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
// touch up the template and append {{ .Response }}
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
}
return &t, nil
}
func (t *Template) String() string {
return t.raw
}
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
func Parse(s string) (*Template, error) {
t, err := template.New("").Option("missingkey=zero").Parse(s)
if err != nil {
return nil, err
}
return &Template{Template: t, raw: s}, nil
}
func (t *Template) Vars() []string {
var vars []string
for _, n := range t.Tree.Root.Nodes {
for _, tt := range t.Templates() {
for _, n := range tt.Root.Nodes {
vars = append(vars, parseNode(n)...)
}
}
set := make(map[string]struct{})
for _, n := range vars {
@@ -110,6 +141,103 @@ func (t *Template) Vars() []string {
return vars
}
type Values struct {
Messages []api.Message
}
func (t *Template) Execute(w io.Writer, v Values) error {
system, collated := collate(v.Messages)
if slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{
"System": system,
"Messages": collated,
})
}
var b bytes.Buffer
var prompt, response string
for i, m := range collated {
if m.Role == "user" {
prompt = m.Content
} else {
response = m.Content
}
if i != len(collated)-1 && prompt != "" && response != "" {
if err := t.Template.Execute(&b, map[string]any{
"System": "",
"Prompt": prompt,
"Response": response,
}); err != nil {
return err
}
prompt = ""
response = ""
}
}
var cut bool
tree := t.Template.Copy()
// for the last message, cut everything after "{{ .Response }}"
tree.Root.Nodes = slices.DeleteFunc(tree.Root.Nodes, func(n parse.Node) bool {
if slices.Contains(parseNode(n), "Response") {
cut = true
}
return cut
})
if err := template.Must(template.New("").AddParseTree("", tree)).Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
}); err != nil {
return err
}
_, err := io.Copy(w, &b)
return err
}
type messages []*api.Message
// collate messages based on role. consecutive messages of the same role are merged
// into a single message. collate also pulls out and merges messages with Role == "system"
// which are templated separately. As a side effect, it mangles message content adding image
// tags ([img-%d]) as needed
func collate(msgs []api.Message) (system string, collated messages) {
var n int
for i := range msgs {
msg := msgs[i]
if msg.Role == "system" {
if system != "" {
system += "\n\n"
}
system += msg.Content
continue
}
for range msg.Images {
imageTag := fmt.Sprintf("[img-%d]", n)
if !strings.Contains(msg.Content, "[img]") {
msg.Content = strings.TrimSpace("[img] " + msg.Content)
}
msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
n++
}
if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
collated[len(collated)-1].Content += "\n\n" + msg.Content
} else {
collated = append(collated, &msg)
}
}
return
}
func parseNode(n parse.Node) []string {
switch n := n.(type) {
case *parse.ActionNode:
@@ -152,6 +280,8 @@ func parseNode(n parse.Node) []string {
return names
case *parse.FieldNode:
return n.Ident
case *parse.TemplateNode:
return parseNode(n.Pipe)
}
return nil

View File

@@ -8,9 +8,11 @@ import (
"os"
"path/filepath"
"slices"
"strings"
"testing"
"text/template"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
)
@@ -46,7 +48,7 @@ func TestNamed(t *testing.T) {
t.Fatal(err)
}
tmpl, err := template.New(s).Parse(b.String())
tmpl, err := Parse(b.String())
if err != nil {
t.Fatal(err)
}
@@ -59,18 +61,81 @@ func TestNamed(t *testing.T) {
}
}
func TestTemplate(t *testing.T) {
cases := make(map[string][]api.Message)
for _, mm := range [][]api.Message{
{
{Role: "user", Content: "Hello, how are you?"},
},
{
{Role: "user", Content: "Hello, how are you?"},
{Role: "assistant", Content: "I'm doing great. How can I help you today?"},
{Role: "user", Content: "I'd like to show off how chat templating works!"},
},
{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello, how are you?"},
{Role: "assistant", Content: "I'm doing great. How can I help you today?"},
{Role: "user", Content: "I'd like to show off how chat templating works!"},
},
} {
var roles []string
for _, m := range mm {
roles = append(roles, m.Role)
}
cases[strings.Join(roles, "-")] = mm
}
matches, err := filepath.Glob("*.gotmpl")
if err != nil {
t.Fatal(err)
}
for _, match := range matches {
t.Run(match, func(t *testing.T) {
bts, err := os.ReadFile(match)
if err != nil {
t.Fatal(err)
}
tmpl, err := Parse(string(bts))
if err != nil {
t.Fatal(err)
}
for n, tt := range cases {
t.Run(n, func(t *testing.T) {
var actual bytes.Buffer
if err := tmpl.Execute(&actual, Values{Messages: tt}); err != nil {
t.Fatal(err)
}
expect, err := os.ReadFile(filepath.Join("testdata", match, n))
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(actual.Bytes(), expect); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
})
}
}
func TestParse(t *testing.T) {
cases := []struct {
template string
vars []string
}{
{"{{ .Prompt }}", []string{"prompt"}},
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}},
{"{{ .Prompt }}", []string{"prompt", "response"}},
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}},
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "system", "tools"}},
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
{"{{ .Prompt }} {{ .Suffix }}", []string{"prompt", "suffix"}},
}
for _, tt := range cases {
@@ -87,3 +152,159 @@ func TestParse(t *testing.T) {
})
}
}
func TestExecuteWithMessages(t *testing.T) {
type template struct {
name string
template string
}
cases := []struct {
name string
templates []template
values Values
expected string
}{
{
"mistral",
[]template{
{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}`},
},
Values{
Messages: []api.Message{
{Role: "user", Content: "Hello friend!"},
{Role: "assistant", Content: "Hello human!"},
{Role: "user", Content: "What is your name?"},
},
},
`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
},
{
"mistral system",
[]template{
{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}`},
},
Values{
Messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant!"},
{Role: "user", Content: "Hello friend!"},
{Role: "assistant", Content: "Hello human!"},
{Role: "user", Content: "What is your name?"},
},
},
`[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
What is your name?[/INST] `,
},
{
"chatml",
[]template{
// this does not have a "no response" test because it's impossible to render the same output
{"response", `{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
`},
{"messages", `
{{- range $index, $_ := .Messages }}
{{- if and (eq .Role "user") (eq (len (slice $.Messages $index)) 1) $.System }}<|im_start|>system
{{ $.System }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>assistant
`},
},
Values{
Messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant!"},
{Role: "user", Content: "Hello friend!"},
{Role: "assistant", Content: "Hello human!"},
{Role: "user", Content: "What is your name?"},
},
},
`<|im_start|>user
Hello friend!<|im_end|>
<|im_start|>assistant
Hello human!<|im_end|>
<|im_start|>system
You are a helpful assistant!<|im_end|>
<|im_start|>user
What is your name?<|im_end|>
<|im_start|>assistant
`,
},
{
"moondream",
[]template{
// this does not have a "no response" test because it's impossible to render the same output
{"response", `{{ if .Prompt }}Question: {{ .Prompt }}
{{ end }}Answer: {{ .Response }}
`},
{"messages", `
{{- range .Messages }}
{{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
{{- end }}
{{- end }}Answer: `},
},
Values{
Messages: []api.Message{
{Role: "user", Content: "What's in this image?", Images: []api.ImageData{[]byte("")}},
{Role: "assistant", Content: "It's a hot dog."},
{Role: "user", Content: "What's in _this_ image?"},
{Role: "user", Images: []api.ImageData{[]byte("")}},
{Role: "user", Content: "Is it a hot dog?"},
},
},
`Question: [img-0] What's in this image?
Answer: It's a hot dog.
Question: What's in _this_ image?
[img-1]
Is it a hot dog?
Answer: `,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
for _, ttt := range tt.templates {
t.Run(ttt.name, func(t *testing.T) {
tmpl, err := Parse(ttt.template)
if err != nil {
t.Fatal(err)
}
var b bytes.Buffer
if err := tmpl.Execute(&b, tt.values); err != nil {
t.Fatal(err)
}
if b.String() != tt.expected {
t.Errorf("expected\n%s,\ngot\n%s", tt.expected, b.String())
}
})
}
})
}
}

View File

@@ -0,0 +1 @@
<start_system>You are a helpful assistant.<end_message><start_user>Hello, how are you?<end_message><start_assistant>I'm doing great. How can I help you today?<end_message><start_user>I'd like to show off how chat templating works!<end_message><start_assistant>

1
template/testdata/alfred.gotmpl/user vendored Normal file
View File

@@ -0,0 +1 @@
<start_user>Hello, how are you?<end_message><start_assistant>

View File

@@ -0,0 +1 @@
<start_user>Hello, how are you?<end_message><start_assistant>I'm doing great. How can I help you today?<end_message><start_user>I'd like to show off how chat templating works!<end_message><start_assistant>

View File

@@ -0,0 +1,10 @@
You are a helpful assistant.### Instruction:
Hello, how are you?
### Response:
I'm doing great. How can I help you today?
### Instruction:
I'd like to show off how chat templating works!
### Response:

4
template/testdata/alpaca.gotmpl/user vendored Normal file
View File

@@ -0,0 +1,4 @@
### Instruction:
Hello, how are you?
### Response:

View File

@@ -0,0 +1,10 @@
### Instruction:
Hello, how are you?
### Response:
I'm doing great. How can I help you today?
### Instruction:
I'd like to show off how chat templating works!
### Response:

View File

@@ -0,0 +1,9 @@
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant
I'm doing great. How can I help you today?<|im_end|>
<|im_start|>user
I'd like to show off how chat templating works!<|im_end|>
<|im_start|>assistant

3
template/testdata/chatml.gotmpl/user vendored Normal file
View File

@@ -0,0 +1,3 @@
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant

View File

@@ -0,0 +1,7 @@
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant
I'm doing great. How can I help you today?<|im_end|>
<|im_start|>user
I'd like to show off how chat templating works!<|im_end|>
<|im_start|>assistant

View File

@@ -0,0 +1,9 @@
System: You are a helpful assistant.
User: Hello, how are you?
Assistant: I'm doing great. How can I help you today?
User: I'd like to show off how chat templating works!
Assistant:

3
template/testdata/chatqa.gotmpl/user vendored Normal file
View File

@@ -0,0 +1,3 @@
User: Hello, how are you?
Assistant:

View File

@@ -0,0 +1,7 @@
User: Hello, how are you?
Assistant: I'm doing great. How can I help you today?
User: I'd like to show off how chat templating works!
Assistant:

View File

@@ -0,0 +1,11 @@
Source: system
You are a helpful assistant. <step> Source: user
Hello, how are you? <step> Source: assistant
I'm doing great. How can I help you today? <step> Source: user
I'd like to show off how chat templating works! <step> Source: assistant
Destination: user

View File

@@ -0,0 +1,5 @@
Source: user
Hello, how are you? <step> Source: assistant
Destination: user

View File

@@ -0,0 +1,9 @@
Source: user
Hello, how are you? <step> Source: assistant
I'm doing great. How can I help you today? <step> Source: user
I'd like to show off how chat templating works! <step> Source: assistant
Destination: user

View File

@@ -0,0 +1,8 @@
System: You are a helpful assistant.
User:
Hello, how are you?
Falcon:
I'm doing great. How can I help you today?
User:
I'd like to show off how chat templating works!
Falcon:

View File

@@ -0,0 +1,3 @@
User:
Hello, how are you?
Falcon:

View File

@@ -0,0 +1,7 @@
User:
Hello, how are you?
Falcon:
I'm doing great. How can I help you today?
User:
I'd like to show off how chat templating works!
Falcon:

View File

@@ -0,0 +1,8 @@
<start_of_turn>user
You are a helpful assistant.
Hello, how are you?<end_of_turn>
<start_of_turn>model
I'm doing great. How can I help you today?<end_of_turn>
<start_of_turn>user
I'd like to show off how chat templating works!<end_of_turn>
<start_of_turn>model

View File

@@ -0,0 +1,3 @@
<start_of_turn>user
Hello, how are you?<end_of_turn>
<start_of_turn>model

View File

@@ -0,0 +1,7 @@
<start_of_turn>user
Hello, how are you?<end_of_turn>
<start_of_turn>model
I'm doing great. How can I help you today?<end_of_turn>
<start_of_turn>user
I'd like to show off how chat templating works!<end_of_turn>
<start_of_turn>model

View File

@@ -0,0 +1,13 @@
System:
You are a helpful assistant.
Question:
Hello, how are you?
Answer:
I'm doing great. How can I help you today?
Question:
I'd like to show off how chat templating works!
Answer:

View File

@@ -0,0 +1,4 @@
Question:
Hello, how are you?
Answer:

View File

@@ -0,0 +1,10 @@
Question:
Hello, how are you?
Answer:
I'm doing great. How can I help you today?
Question:
I'd like to show off how chat templating works!
Answer:

View File

@@ -0,0 +1,5 @@
[INST] <<SYS>>
You are a helpful assistant.
<</SYS>>
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] I'd like to show off how chat templating works! [/INST]

View File

@@ -0,0 +1,3 @@
[INST] <<SYS>><</SYS>>
Hello, how are you? [/INST]

View File

@@ -0,0 +1,3 @@
[INST] <<SYS>><</SYS>>
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] I'd like to show off how chat templating works! [/INST]

View File

@@ -0,0 +1,10 @@
<|start_header_id|>system<|end_header_id|>
You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
Hello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
I'm doing great. How can I help you today?<|eot_id|><|start_header_id|>user<|end_header_id|>
I'd like to show off how chat templating works!<|eot_id|><|start_header_id|>assistant<|end_header_id|>

View File

@@ -0,0 +1,4 @@
<|start_header_id|>user<|end_header_id|>
Hello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

View File

@@ -0,0 +1,8 @@
<|start_header_id|>user<|end_header_id|>
Hello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
I'm doing great. How can I help you today?<|eot_id|><|start_header_id|>user<|end_header_id|>
I'd like to show off how chat templating works!<|eot_id|><|start_header_id|>assistant<|end_header_id|>

View File

@@ -0,0 +1,12 @@
You are a helpful assistant.
@@ Instruction
Hello, how are you?
@@ Response
I'm doing great. How can I help you today?
@@ Instruction
I'd like to show off how chat templating works!
@@ Response

View File

@@ -0,0 +1,4 @@
@@ Instruction
Hello, how are you?
@@ Response

View File

@@ -0,0 +1,10 @@
@@ Instruction
Hello, how are you?
@@ Response
I'm doing great. How can I help you today?
@@ Instruction
I'd like to show off how chat templating works!
@@ Response

View File

@@ -0,0 +1,2 @@
[INST] Hello, how are you?[/INST] I'm doing great. How can I help you today?</s>[INST] You are a helpful assistant.
I'd like to show off how chat templating works![/INST]

View File

@@ -0,0 +1 @@
[INST] Hello, how are you?[/INST]

View File

@@ -0,0 +1 @@
[INST] Hello, how are you?[/INST] I'm doing great. How can I help you today?</s>[INST] I'd like to show off how chat templating works![/INST]

View File

@@ -0,0 +1 @@
GPT Correct System: You are a helpful assistant.<|end_of_turn|>GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT Correct Assistant:

View File

@@ -0,0 +1 @@
GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant:

View File

@@ -0,0 +1 @@
GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT Correct Assistant:

View File

@@ -0,0 +1,9 @@
<|system|>
You are a helpful assistant.<|end|>
<|user|>
Hello, how are you?<|end|>
<|assistant|>
I'm doing great. How can I help you today?<|end|>
<|user|>
I'd like to show off how chat templating works!<|end|>
<|assistant|>

3
template/testdata/phi-3.gotmpl/user vendored Normal file
View File

@@ -0,0 +1,3 @@
<|user|>
Hello, how are you?<|end|>
<|assistant|>

View File

@@ -0,0 +1,7 @@
<|user|>
Hello, how are you?<|end|>
<|assistant|>
I'm doing great. How can I help you today?<|end|>
<|user|>
I'd like to show off how chat templating works!<|end|>
<|assistant|>

View File

@@ -0,0 +1,13 @@
### System:
You are a helpful assistant.
### User:
Hello, how are you?
### Assistant:
I'm doing great. How can I help you today?</s>
### User:
I'd like to show off how chat templating works!
### Assistant:

View File

@@ -0,0 +1,4 @@
### User:
Hello, how are you?
### Assistant:

View File

@@ -0,0 +1,10 @@
### User:
Hello, how are you?
### Assistant:
I'm doing great. How can I help you today?</s>
### User:
I'd like to show off how chat templating works!
### Assistant:

View File

@@ -0,0 +1,12 @@
You are a helpful assistant.
### Instruction
Hello, how are you?
### Response
I'm doing great. How can I help you today?<|endoftext|>
### Instruction
I'd like to show off how chat templating works!
### Response

View File

@@ -0,0 +1,4 @@
### Instruction
Hello, how are you?
### Response

View File

@@ -0,0 +1,10 @@
### Instruction
Hello, how are you?
### Response
I'm doing great. How can I help you today?<|endoftext|>
### Instruction
I'd like to show off how chat templating works!
### Response

View File

@@ -0,0 +1,6 @@
You are a helpful assistant.
USER: Hello, how are you?
ASSISTANT: I'm doing great. How can I help you today?</s>
USER: I'd like to show off how chat templating works!
ASSISTANT:

2
template/testdata/vicuna.gotmpl/user vendored Normal file
View File

@@ -0,0 +1,2 @@
USER: Hello, how are you?
ASSISTANT:

View File

@@ -0,0 +1,4 @@
USER: Hello, how are you?
ASSISTANT: I'm doing great. How can I help you today?</s>
USER: I'd like to show off how chat templating works!
ASSISTANT:

View File

@@ -0,0 +1,9 @@
<|system|>
You are a helpful assistant.</s>
<|user|>
Hello, how are you?</s>
<|assistant|>
I'm doing great. How can I help you today?</s>
<|user|>
I'd like to show off how chat templating works!</s>
<|assistant|>

3
template/testdata/zephyr.gotmpl/user vendored Normal file
View File

@@ -0,0 +1,3 @@
<|user|>
Hello, how are you?</s>
<|assistant|>

View File

@@ -0,0 +1,7 @@
<|user|>
Hello, how are you?</s>
<|assistant|>
I'm doing great. How can I help you today?</s>
<|user|>
I'd like to show off how chat templating works!</s>
<|assistant|>

View File

@@ -1,3 +1,14 @@
{{- if .Messages }}
{{- if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}USER: {{ .Content }}
{{ else if eq .Role "assistant" }}ASSISTANT: {{ .Content }}</s>
{{ end }}
{{- end }}ASSISTANT:
{{- else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}USER: {{ .Prompt }}
{{ end }}ASSISTANT: {{ .Response }}
{{- end }}

View File

@@ -1,6 +1,15 @@
{{- if .Messages }}
{{- if .System }}<|system|>
{{ .System }}</s>
{{ end }}
{{- range .Messages }}<|{{ .Role }}|>
{{ .Content }}</s>
{{ end }}<|assistant|>
{{ else }}
{{ if .System }}<|system|>
{{ .System }}</s>
{{ end }}{{ if .Prompt }}<|user|>
{{ .Prompt }}</s>
{{ end }}<|assistant|>
{{ .Response }}</s>
{{- end }}