Compare commits

...

82 Commits

Author SHA1 Message Date
Roy Han
568416ba17 add suffix 2024-07-16 16:51:27 -07:00
Roy Han
80cba42ab2 Update docs 2024-07-16 16:51:27 -07:00
royjhan
6477a7aca4 Merge branch 'royh-completions-docs' of https://github.com/ollama/ollama into royh-completions-docs 2024-07-16 16:51:11 -07:00
royjhan
51214ddef5 Update docs/openai.md 2024-07-16 16:34:31 -07:00
royjhan
b950d749a9 Update docs/openai.md 2024-07-16 16:34:31 -07:00
Roy Han
3702ed7532 token bug corrected 2024-07-16 16:34:31 -07:00
Roy Han
6266603b17 Update docs 2024-07-16 16:34:31 -07:00
Michael Yang
499e87c9ba Merge pull request #5730 from ollama/mxyng/cleanup
remove unneeded tool calls
2024-07-16 14:42:13 -07:00
Michael Yang
cd0853f2d5 Merge pull request #5207 from ollama/mxyng/suffix
add insert support to generate endpoint
2024-07-16 14:37:32 -07:00
Michael Yang
d290e87513 add suffix support to generate endpoint
this change is triggered by the presence of "suffix", particularly
useful for code completion tasks
2024-07-16 14:31:35 -07:00
Thorsten Sommer
97c20ede33 README: Added AI Studio to the list of UIs (#5721)
* Added AI Studio to the list of UIs
2024-07-16 14:24:27 -07:00
Michael Yang
5a83f79afd remove unneeded tool calls 2024-07-16 13:48:45 -07:00
royjhan
987dbab0b0 OpenAI: /v1/embeddings compatibility (#5285)
* OpenAI v1 models

* Empty List Testing

* Add back envconfig

* v1/models docs

* Remove Docs

* OpenAI batch embed compatibility

* merge conflicts

* integrate with api/embed

* ep

* merge conflicts

* request tests

* rm resp test

* merge conflict

* merge conflict

* test fixes

* test fn renaming

* input validation for empty string

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2024-07-16 13:36:08 -07:00
Michael Yang
a8388beb94 Merge pull request #5726 from ollama/mxyng/tools-templates
fix unmarshal type errors
2024-07-16 12:12:10 -07:00
Michael Yang
5afbb60fc4 fix unmarshal type errors 2024-07-16 11:39:34 -07:00
Jeffrey Morgan
4cb5d7decc server: omit model system prompt if empty (#5717) 2024-07-16 11:09:00 -07:00
Michael Yang
8eac50dd4f Merge pull request #5684 from ollama/mxyng/tests
add chat and generate tests with mock runner
2024-07-16 09:44:45 -07:00
Michael Yang
4a565cbf94 add chat and generate tests with mock runner 2024-07-16 09:39:31 -07:00
Michael Yang
64039df6d7 Merge pull request #5284 from ollama/mxyng/tools
tools
2024-07-15 18:03:37 -07:00
Jeffrey Morgan
7ac6d462ec server: return empty slice on empty /api/embed request (#5713)
* server: return empty slice on empty `/api/embed` request

* fix tests
2024-07-15 17:39:44 -07:00
Michael Yang
ef5136a745 tools test 2024-07-15 17:18:21 -07:00
Daniel Hiltgen
8288ec8824 Merge pull request #5710 from dhiltgen/rocm_bump
Bump linux ROCm to 6.1.2
2024-07-15 15:32:18 -07:00
Michael Yang
d02bbebb11 tools 2024-07-15 15:26:16 -07:00
Daniel Hiltgen
224337b32f Bump linux ROCm to 6.1.2 2024-07-15 15:10:22 -07:00
Jeffrey Morgan
9e35d9bbee server: lowercase roles for compatibility with clients (#5695) 2024-07-15 13:55:57 -07:00
royjhan
b9f5e16c80 Introduce /api/embed endpoint supporting batch embedding (#5127)
* Initial Batch Embedding

* Revert "Initial Batch Embedding"

This reverts commit c22d54895a.

* Initial Draft

* mock up notes

* api/embed draft

* add server function

* check normalization

* clean up

* normalization

* playing around with truncate stuff

* Truncation

* Truncation

* move normalization to go

* Integration Test Template

* Truncation Integration Tests

* Clean up

* use float32

* move normalize

* move normalize test

* refactoring

* integration float32

* input handling and handler testing

* Refactoring of legacy and new

* clear comments

* merge conflicts

* touches

* embedding type 64

* merge conflicts

* fix hanging on single string

* refactoring

* test values

* set context length

* clean up

* testing clean up

* testing clean up

* remove function closure

* Revert "remove function closure"

This reverts commit 55d48c6ed1.

* remove function closure

* remove redundant error check

* clean up

* more clean up

* clean up
2024-07-15 12:14:24 -07:00
royjhan
e9f7f36029 Support image input for OpenAI chat compatibility (#5208)
* OpenAI v1 models

* Refactor Writers

* Add Test

Co-Authored-By: Attila Kerekes

* Credit Co-Author

Co-Authored-By: Attila Kerekes <439392+keriati@users.noreply.github.com>

* Empty List Testing

* Use Namespace for Ownedby

* Update Test

* Add back envconfig

* v1/models docs

* Use ModelName Parser

* Test Names

* Remove Docs

* Clean Up

* Test name

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>

* Add Middleware for Chat and List

* Testing Cleanup

* Test with Fatal

* Add functionality to chat test

* Support image input for OpenAI chat

* Decoding

* Fix message processing logic

* openai vision test

* type errors

* clean up

* redundant check

* merge conflicts

* merge conflicts

* merge conflicts

* flattening and smaller image

* add test

* support python and js SDKs and mandate prefixing

* clean up

---------

Co-authored-by: Attila Kerekes <439392+keriati@users.noreply.github.com>
Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
2024-07-13 22:07:45 -07:00
Patrick Devine
057d31861e remove template (#5655) 2024-07-13 20:56:24 -07:00
jmorganca
f7ee012300 server: prepend system message in chat handler 2024-07-13 15:08:00 -07:00
Jeffrey Morgan
1ed0aa8fea server: fix context, load_duration and total_duration fields (#5676)
* server: fix `contet`, `load_duration` and `total_duration` fields

* Update server/routes.go
2024-07-13 09:25:31 -07:00
Jeffrey Morgan
ef98803d63 llm: looser checks for minimum memory (#5677) 2024-07-13 09:20:05 -07:00
Jarek
02fea420e5 Add Kerlig AI, an app for macOS (#5675) 2024-07-13 08:33:46 -07:00
Michael Yang
22c5451fc2 fix system prompt (#5662)
* fix system prompt

* execute template when hitting previous roles

* fix tests

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2024-07-12 21:04:44 -07:00
Patrick Devine
23ebbaa46e Revert "remove template from tests"
This reverts commit 9ac0a7a50b.
2024-07-12 15:47:17 -07:00
Patrick Devine
9ac0a7a50b remove template from tests 2024-07-12 15:41:31 -07:00
Michael Yang
e5c65a85df Merge pull request #5653 from ollama/mxyng/collect-system
template: preprocess message and collect system
2024-07-12 12:32:34 -07:00
Jeffrey Morgan
33627331a3 app: also clean up tempdir runners on install (#5646) 2024-07-12 12:29:23 -07:00
Michael Yang
36c87c433b template: preprocess message and collect system 2024-07-12 12:26:43 -07:00
Jeffrey Morgan
179737feb7 Clean up old files when installing on Windows (#5645)
* app: always clean up install dir; force close applications

* remove wildcard

* revert `CloseApplications`

* whitespace

* update `LOCALAPPDATA` var
2024-07-11 22:53:46 -07:00
Michael Yang
47353f5ee4 Merge pull request #5639 from ollama/mxyng/unaggregated-system 2024-07-11 17:48:50 -07:00
Josh
10e768826c fix: quant err message (#5616) 2024-07-11 17:24:29 -07:00
Michael Yang
5056bb9c01 rename aggregate to contents 2024-07-11 17:00:26 -07:00
Jeffrey Morgan
c4cf8ad559 llm: avoid loading model if system memory is too small (#5637)
* llm: avoid loading model if system memory is too small

* update log

* Instrument swap free space

On linux and windows, expose how much swap space is available
so we can take that into consideration when scheduling models

* use `systemSwapFreeMemory` in check

---------

Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
2024-07-11 16:42:57 -07:00
Michael Yang
57ec6901eb revert embedded templates to use prompt/response
This reverts commit 19753c18c0.

for compat. messages will be added at a later date
2024-07-11 14:49:35 -07:00
Michael Yang
e64f9ebb44 do no automatically aggregate system messages 2024-07-11 14:49:35 -07:00
Jeffrey Morgan
791650ddef sched: only error when over-allocating system memory (#5626) 2024-07-11 00:53:12 -07:00
Jeffrey Morgan
efbf41ed81 llm: dont link cuda with compat libs (#5621) 2024-07-10 20:01:52 -07:00
Michael Yang
cf15589851 Merge pull request #5620 from ollama/mxyng/templates
update embedded templates
2024-07-10 17:16:24 -07:00
Michael Yang
19753c18c0 update embedded templates 2024-07-10 17:03:08 -07:00
Michael Yang
41be28096a add system prompt to first legacy template 2024-07-10 17:03:08 -07:00
Michael Yang
37a570f962 Merge pull request #5612 from ollama/mxyng/mem
chatglm graph
2024-07-10 14:18:33 -07:00
Michael Yang
5a739ff4cb chatglm graph 2024-07-10 13:43:47 -07:00
Jeffrey Morgan
4e262eb2a8 remove GGML_CUDA_FORCE_MMQ=on from build (#5588) 2024-07-10 13:17:13 -07:00
Daniel Hiltgen
4cfcbc328f Merge pull request #5124 from dhiltgen/amd_windows
Wire up windows AMD driver reporting
2024-07-10 12:50:23 -07:00
Daniel Hiltgen
79292ff3e0 Merge pull request #5555 from dhiltgen/msvc_deps
Bundle missing CRT libraries
2024-07-10 12:50:02 -07:00
Daniel Hiltgen
8ea500441d Merge pull request #5580 from dhiltgen/cuda_overhead
Detect CUDA OS overhead
2024-07-10 12:47:31 -07:00
Daniel Hiltgen
b50c818623 Merge pull request #5607 from dhiltgen/win_rocm_v6
Bump ROCm on windows to 6.1.2
2024-07-10 12:47:10 -07:00
Daniel Hiltgen
b99e750b62 Merge pull request #5605 from dhiltgen/merge_glitch
Remove duplicate merge glitch
2024-07-10 11:47:08 -07:00
Daniel Hiltgen
1f50356e8e Bump ROCm on windows to 6.1.2
This also adjusts our algorithm to favor our bundled ROCm.
I've confirmed VRAM reporting still doesn't work properly so we
can't yet enable concurrency by default.
2024-07-10 11:01:22 -07:00
Daniel Hiltgen
22c81f62ec Remove duplicate merge glitch 2024-07-10 09:01:33 -07:00
Daniel Hiltgen
2d1e3c3229 Merge pull request #5503 from dhiltgen/dual_rocm
Workaround broken ROCm p2p copy
2024-07-09 15:44:16 -07:00
royjhan
4918fae535 OpenAI v1/completions: allow stop token list (#5551)
* stop token parsing fix

* add stop test
2024-07-09 14:01:26 -07:00
royjhan
0aff67877e separate request tests (#5578) 2024-07-09 13:48:31 -07:00
Daniel Hiltgen
f6f759fc5f Detect CUDA OS Overhead
This adds logic to detect skew between the driver and
management library which can be attributed to OS overhead
and records that so we can adjust subsequent management
library free VRAM updates and avoid OOM scenarios.
2024-07-09 12:21:50 -07:00
Daniel Hiltgen
9544a57ee4 Merge pull request #5579 from dhiltgen/win_static_deps
Statically link c++ and thread lib on windows
2024-07-09 12:21:13 -07:00
Daniel Hiltgen
b51e3b63ac Statically link c++ and thread lib
This makes sure we statically link the c++ and thread library on windows
to avoid unnecessary runtime dependencies on non-standard DLLs
2024-07-09 11:34:30 -07:00
Michael Yang
6bbbc50f10 Merge pull request #5440 from ollama/mxyng/messages-templates
update named templates
2024-07-09 09:36:32 -07:00
Michael Yang
9bbddc37a7 Merge pull request #5126 from ollama/mxyng/messages
update message processing
2024-07-09 09:20:44 -07:00
Jeffrey Morgan
e4ff73297d server: fix model reloads when setting OLLAMA_NUM_PARALLEL (#5560)
* server: fix unneeded model reloads when setting `OLLAMA_NUM_PARALLEL`

* remove whitespace change

* undo some changes
2024-07-08 22:32:15 -07:00
Daniel Hiltgen
b44320db13 Bundle missing CRT libraries
Some users are experienging runner startup errors due
to not having these msvc redist libraries on their host
2024-07-08 18:24:21 -07:00
royjhan
2644c4e682 Update docs/openai.md 2024-07-08 14:46:05 -07:00
royjhan
04cde43b2a Update docs/openai.md 2024-07-08 14:44:16 -07:00
Daniel Hiltgen
0bacb30007 Workaround broken ROCm p2p copy
Enable the build flag for llama.cpp to use CPU copy for multi-GPU scenarios.
2024-07-08 09:40:52 -07:00
Michael Yang
fb6cbc02fb update named templates 2024-07-05 16:29:32 -07:00
Michael Yang
326363b3a7 no funcs 2024-07-05 13:17:25 -07:00
Michael Yang
ac7a842e55 fix model reloading
ensure runtime model changes (template, system prompt, messages,
options) are captured on model updates without needing to reload the
server
2024-07-05 13:17:25 -07:00
Michael Yang
2c3fe1fd97 comments 2024-07-05 13:17:24 -07:00
Michael Yang
269ed6e6a2 update message processing 2024-07-05 13:16:58 -07:00
Roy Han
105e36765d token bug corrected 2024-07-03 15:03:54 -07:00
royjhan
fa7be5aab4 Merge branch 'main' into royh-completions-docs 2024-07-02 14:52:56 -07:00
Roy Han
02169f3e60 Update docs 2024-06-26 14:30:28 -07:00
Daniel Hiltgen
784bf88b0d Wire up windows AMD driver reporting
This seems to be ROCm version, not actually driver version, but
it may be useful for toggling logic for VRAM reporting in the future
2024-06-18 16:22:47 -07:00
126 changed files with 3822 additions and 1118 deletions

View File

@@ -147,7 +147,7 @@ jobs:
run: | run: |
$ErrorActionPreference = "Stop" $ErrorActionPreference = "Stop"
write-host "downloading AMD HIP Installer" write-host "downloading AMD HIP Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP" write-host "Installing AMD HIP"
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
write-host "Completed AMD HIP" write-host "Completed AMD HIP"
@@ -304,11 +304,6 @@ jobs:
write-host "Installing plugin" write-host "Installing plugin"
& "${env:RUNNER_TEMP}\plugin\*\kmscng.msi" /quiet & "${env:RUNNER_TEMP}\plugin\*\kmscng.msi" /quiet
write-host "plugin installed" write-host "plugin installed"
- name: remove unwanted mingw dll.a files
run: |
Get-ChildItem -Path "C:\mingw64" -Recurse -Filter "libpthread.dll.a" -File | Remove-Item -Force
Get-ChildItem -Path "C:\mingw64" -Recurse -Filter "libwinpthread.dll.a" -File | Remove-Item -Force
Get-ChildItem -Path "C:\mingw64" -Recurse -Filter "libstdc++.dll.a" -File | Remove-Item -Force
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version-file: go.mod go-version-file: go.mod

View File

@@ -126,7 +126,7 @@ jobs:
strategy: strategy:
matrix: matrix:
rocm-version: rocm-version:
- '6.1.1' - '6.1.2'
runs-on: linux runs-on: linux
container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }} container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
steps: steps:
@@ -169,7 +169,7 @@ jobs:
run: | run: |
$ErrorActionPreference = "Stop" $ErrorActionPreference = "Stop"
write-host "downloading AMD HIP Installer" write-host "downloading AMD HIP Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP" write-host "Installing AMD HIP"
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
write-host "Completed AMD HIP" write-host "Completed AMD HIP"

View File

@@ -2,7 +2,7 @@ ARG GOLANG_VERSION=1.22.1
ARG CMAKE_VERSION=3.22.1 ARG CMAKE_VERSION=3.22.1
# this CUDA_VERSION corresponds with the one specified in docs/gpu.md # this CUDA_VERSION corresponds with the one specified in docs/gpu.md
ARG CUDA_VERSION=11.3.1 ARG CUDA_VERSION=11.3.1
ARG ROCM_VERSION=6.1.1 ARG ROCM_VERSION=6.1.2
# Copy the minimal context we need to run the generate scripts # Copy the minimal context we need to run the generate scripts
FROM scratch AS llm-code FROM scratch AS llm-code

View File

@@ -293,6 +293,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS) - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama) - [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama) - [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
- [AI Studio](https://github.com/MindWorkAI/AI-Studio)
### Terminal ### Terminal

View File

@@ -347,7 +347,16 @@ func (c *Client) Heartbeat(ctx context.Context) error {
return nil return nil
} }
// Embeddings generates embeddings from a model. // Embed generates embeddings from a model.
func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
var resp EmbedResponse
if err := c.do(ctx, http.MethodPost, "/api/embed", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// Embeddings generates an embedding from a model.
func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) { func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
var resp EmbeddingResponse var resp EmbeddingResponse
if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil { if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {

View File

@@ -47,6 +47,9 @@ type GenerateRequest struct {
// Prompt is the textual prompt to send to the model. // Prompt is the textual prompt to send to the model.
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
// Suffix is the text that comes after the inserted text.
Suffix string `json:"suffix"`
// System overrides the model's default system message/prompt. // System overrides the model's default system message/prompt.
System string `json:"system"` System string `json:"system"`
@@ -97,6 +100,9 @@ type ChatRequest struct {
// followin the request. // followin the request.
KeepAlive *Duration `json:"keep_alive,omitempty"` KeepAlive *Duration `json:"keep_alive,omitempty"`
// Tools is an optional list of tools the model has access to.
Tools []Tool `json:"tools,omitempty"`
// Options lists model-specific options. // Options lists model-specific options.
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
} }
@@ -105,9 +111,46 @@ type ChatRequest struct {
// role ("system", "user", or "assistant"), the content and an optional list // role ("system", "user", or "assistant"), the content and an optional list
// of images. // of images.
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content,omitempty"`
Images []ImageData `json:"images,omitempty"` Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
type ToolCall struct {
Function struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
} `json:"function"`
}
type Tool struct {
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
} `json:"parameters"`
} `json:"function"`
}
func (m *Message) UnmarshalJSON(b []byte) error {
type Alias Message
var a Alias
if err := json.Unmarshal(b, &a); err != nil {
return err
}
*m = Message(a)
m.Role = strings.ToLower(m.Role)
return nil
} }
// ChatResponse is the response returned by [Client.Chat]. Its fields are // ChatResponse is the response returned by [Client.Chat]. Its fields are
@@ -173,6 +216,30 @@ type Runner struct {
NumThread int `json:"num_thread,omitempty"` NumThread int `json:"num_thread,omitempty"`
} }
// EmbedRequest is the request passed to [Client.Embed].
type EmbedRequest struct {
// Model is the model name.
Model string `json:"model"`
// Input is the input to embed.
Input any `json:"input"`
// KeepAlive controls how long the model will stay loaded in memory following
// this request.
KeepAlive *Duration `json:"keep_alive,omitempty"`
Truncate *bool `json:"truncate,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
}
// EmbedResponse is the response from [Client.Embed].
type EmbedResponse struct {
Model string `json:"model"`
Embeddings [][]float32 `json:"embeddings"`
}
// EmbeddingRequest is the request passed to [Client.Embeddings]. // EmbeddingRequest is the request passed to [Client.Embeddings].
type EmbeddingRequest struct { type EmbeddingRequest struct {
// Model is the model name. // Model is the model name.
@@ -219,8 +286,10 @@ type DeleteRequest struct {
// ShowRequest is the request passed to [Client.Show]. // ShowRequest is the request passed to [Client.Show].
type ShowRequest struct { type ShowRequest struct {
Model string `json:"model"` Model string `json:"model"`
System string `json:"system"` System string `json:"system"`
// Template is deprecated
Template string `json:"template"` Template string `json:"template"`
Verbose bool `json:"verbose"` Verbose bool `json:"verbose"`
@@ -336,6 +405,9 @@ type GenerateResponse struct {
// Response is the textual response itself. // Response is the textual response itself.
Response string `json:"response"` Response string `json:"response"`
// ToolCalls is the list of tools the model wants to call
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
// Done specifies if the response is complete. // Done specifies if the response is complete.
Done bool `json:"done"` Done bool `json:"done"`

View File

@@ -208,3 +208,26 @@ func TestUseMmapFormatParams(t *testing.T) {
}) })
} }
} }
func TestMessage_UnmarshalJSON(t *testing.T) {
tests := []struct {
input string
expected string
}{
{`{"role": "USER", "content": "Hello!"}`, "user"},
{`{"role": "System", "content": "Initialization complete."}`, "system"},
{`{"role": "assistant", "content": "How can I help you?"}`, "assistant"},
{`{"role": "TOOl", "content": "Access granted."}`, "tool"},
}
for _, test := range tests {
var msg Message
if err := json.Unmarshal([]byte(test.input), &msg); err != nil {
t.Errorf("Unexpected error: %v", err)
}
if msg.Role != test.expected {
t.Errorf("role not lowercased: got %v, expected %v", msg.Role, test.expected)
}
}
}

View File

@@ -127,6 +127,10 @@ Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\models"
Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history" Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history"
; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved ; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved
[InstallDelete]
Type: filesandordirs; Name: "{%TEMP}\ollama*"
Type: filesandordirs; Name: "{%LOCALAPPDATA}\Programs\Ollama"
[Messages] [Messages]
WizardReady=Ollama Windows Preview WizardReady=Ollama Windows Preview
ReadyLabel1=%nLet's get you up and running with your own large language models. ReadyLabel1=%nLet's get you up and running with your own large language models.

View File

@@ -843,7 +843,6 @@ type runOptions struct {
WordWrap bool WordWrap bool
Format string Format string
System string System string
Template string
Images []api.ImageData Images []api.ImageData
Options map[string]interface{} Options map[string]interface{}
MultiModal bool MultiModal bool
@@ -1037,7 +1036,6 @@ func generate(cmd *cobra.Command, opts runOptions) error {
Images: opts.Images, Images: opts.Images,
Format: opts.Format, Format: opts.Format,
System: opts.System, System: opts.System,
Template: opts.Template,
Options: opts.Options, Options: opts.Options,
KeepAlive: opts.KeepAlive, KeepAlive: opts.KeepAlive,
} }

View File

@@ -27,7 +27,6 @@ const (
MultilineNone MultilineState = iota MultilineNone MultilineState = iota
MultilinePrompt MultilinePrompt
MultilineSystem MultilineSystem
MultilineTemplate
) )
func loadModel(cmd *cobra.Command, opts *runOptions) error { func loadModel(cmd *cobra.Command, opts *runOptions) error {
@@ -94,7 +93,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Available Commands:") fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter") fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter")
fmt.Fprintln(os.Stderr, " /set system <string> Set system message") fmt.Fprintln(os.Stderr, " /set system <string> Set system message")
fmt.Fprintln(os.Stderr, " /set template <string> Set prompt template")
fmt.Fprintln(os.Stderr, " /set history Enable history") fmt.Fprintln(os.Stderr, " /set history Enable history")
fmt.Fprintln(os.Stderr, " /set nohistory Disable history") fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap") fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
@@ -204,10 +202,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System}) opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
fmt.Println("Set system message.") fmt.Println("Set system message.")
sb.Reset() sb.Reset()
case MultilineTemplate:
opts.Template = sb.String()
fmt.Println("Set prompt template.")
sb.Reset()
} }
multiline = MultilineNone multiline = MultilineNone
@@ -326,17 +320,13 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} }
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", ")) fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
opts.Options[args[2]] = fp[args[2]] opts.Options[args[2]] = fp[args[2]]
case "system", "template": case "system":
if len(args) < 3 { if len(args) < 3 {
usageSet() usageSet()
continue continue
} }
if args[1] == "system" { multiline = MultilineSystem
multiline = MultilineSystem
} else if args[1] == "template" {
multiline = MultilineTemplate
}
line := strings.Join(args[2:], " ") line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`) line, ok := strings.CutPrefix(line, `"""`)
@@ -356,23 +346,17 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
continue continue
} }
if args[1] == "system" { opts.System = sb.String() // for display in modelfile
opts.System = sb.String() // for display in modelfile newMessage := api.Message{Role: "system", Content: sb.String()}
newMessage := api.Message{Role: "system", Content: sb.String()} // Check if the slice is not empty and the last message is from 'system'
// Check if the slice is not empty and the last message is from 'system' if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" { // Replace the last message
// Replace the last message opts.Messages[len(opts.Messages)-1] = newMessage
opts.Messages[len(opts.Messages)-1] = newMessage } else {
} else { opts.Messages = append(opts.Messages, newMessage)
opts.Messages = append(opts.Messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
} else if args[1] == "template" {
opts.Template = sb.String()
fmt.Println("Set prompt template.")
sb.Reset()
} }
fmt.Println("Set system message.")
sb.Reset()
sb.Reset() sb.Reset()
continue continue
@@ -393,7 +377,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
req := &api.ShowRequest{ req := &api.ShowRequest{
Name: opts.Model, Name: opts.Model,
System: opts.System, System: opts.System,
Template: opts.Template,
Options: opts.Options, Options: opts.Options,
} }
resp, err := client.Show(cmd.Context(), req) resp, err := client.Show(cmd.Context(), req)
@@ -437,12 +420,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Println("No system message was specified for this model.") fmt.Println("No system message was specified for this model.")
} }
case "template": case "template":
switch { if resp.Template != "" {
case opts.Template != "":
fmt.Println(opts.Template + "\n")
case resp.Template != "":
fmt.Println(resp.Template) fmt.Println(resp.Template)
default: } else {
fmt.Println("No prompt template was specified for this model.") fmt.Println("No prompt template was specified for this model.")
} }
default: default:
@@ -536,10 +516,6 @@ func buildModelfile(opts runOptions) string {
fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System) fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System)
} }
if opts.Template != "" {
fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template)
}
keys := make([]string, 0) keys := make([]string, 0)
for k := range opts.Options { for k := range opts.Options {
keys = append(keys, k) keys = append(keys, k)

View File

@@ -59,7 +59,6 @@ func TestModelfileBuilder(t *testing.T) {
opts := runOptions{ opts := runOptions{
Model: "hork", Model: "hork",
System: "You are part horse and part shark, but all hork. Do horklike things", System: "You are part horse and part shark, but all hork. Do horklike things",
Template: "This is a template.",
Messages: []api.Message{ Messages: []api.Message{
{Role: "user", Content: "Hey there hork!"}, {Role: "user", Content: "Hey there hork!"},
{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."}, {Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
@@ -75,7 +74,6 @@ func TestModelfileBuilder(t *testing.T) {
mf := buildModelfile(opts) mf := buildModelfile(opts)
expectedModelfile := `FROM {{.Model}} expectedModelfile := `FROM {{.Model}}
SYSTEM """{{.System}}""" SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false PARAMETER penalize_newline false
PARAMETER seed 42 PARAMETER seed 42
PARAMETER stop [hi there] PARAMETER stop [hi there]
@@ -97,7 +95,6 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark."""
mf = buildModelfile(opts) mf = buildModelfile(opts)
expectedModelfile = `FROM {{.ParentModel}} expectedModelfile = `FROM {{.ParentModel}}
SYSTEM """{{.System}}""" SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false PARAMETER penalize_newline false
PARAMETER seed 42 PARAMETER seed 42
PARAMETER stop [hi there] PARAMETER stop [hi there]

View File

@@ -272,4 +272,4 @@ The following server settings may be used to adjust how Ollama handles concurren
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory. - `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512 - `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM. Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.

View File

@@ -27,6 +27,11 @@ chat_completion = client.chat.completions.create(
], ],
model='llama3', model='llama3',
) )
completion = client.completions.create(
model="llama3",
prompt="Say this is a test"
)
``` ```
### OpenAI JavaScript library ### OpenAI JavaScript library
@@ -45,6 +50,11 @@ const chatCompletion = await openai.chat.completions.create({
messages: [{ role: 'user', content: 'Say this is a test' }], messages: [{ role: 'user', content: 'Say this is a test' }],
model: 'llama3', model: 'llama3',
}) })
const completion = await openai.completions.create({
model: "llama3",
prompt: "Say this is a test.",
})
``` ```
### `curl` ### `curl`
@@ -66,6 +76,12 @@ curl http://localhost:11434/v1/chat/completions \
] ]
}' }'
curl http://localhost:11434/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama3",
"prompt": "Say this is a test"
}'
``` ```
## Endpoints ## Endpoints
@@ -103,8 +119,71 @@ curl http://localhost:11434/v1/chat/completions \
- [ ] `user` - [ ] `user`
- [ ] `n` - [ ] `n`
### `/v1/completions`
#### Supported features
- [x] Completions
- [x] Streaming
- [x] JSON mode
- [x] Reproducible outputs
- [ ] Logprobs
#### Supported request fields
- [x] `model`
- [x] `prompt`
- [x] `frequency_penalty`
- [x] `presence_penalty`
- [x] `seed`
- [x] `stop`
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `max_tokens`
- [x] `suffix`
- [ ] `best_of`
- [ ] `echo`
- [ ] `logit_bias`
- [ ] `user`
- [ ] `n`
#### Notes #### Notes
- `prompt` currently only accepts a string
### `/v1/completions`
#### Supported features
- [x] Completions
- [x] Streaming
- [x] JSON mode
- [x] Reproducible outputs
- [ ] Logprobs
#### Supported request fields
- [x] `model`
- [x] `prompt`
- [x] `frequency_penalty`
- [x] `presence_penalty`
- [x] `seed`
- [x] `stop`
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `max_tokens`
- [ ] `best_of`
- [ ] `echo`
- [ ] `suffix`
- [ ] `logit_bias`
- [ ] `user`
- [ ] `n`
#### Notes
- `prompt` currently only accepts a string
- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached - `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
## Models ## Models

3
go.mod
View File

@@ -18,6 +18,7 @@ require (
require ( require (
github.com/agnivade/levenshtein v1.1.1 github.com/agnivade/levenshtein v1.1.1
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/google/go-cmp v0.6.0
github.com/mattn/go-runewidth v0.0.14 github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0 github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
@@ -71,7 +72,7 @@ require (
golang.org/x/net v0.25.0 // indirect golang.org/x/net v0.25.0 // indirect
golang.org/x/sys v0.20.0 golang.org/x/sys v0.20.0
golang.org/x/term v0.20.0 golang.org/x/term v0.20.0
golang.org/x/text v0.15.0 // indirect golang.org/x/text v0.15.0
google.golang.org/protobuf v1.34.1 google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

View File

@@ -49,9 +49,17 @@ func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
} }
func commonAMDValidateLibDir() (string, error) { func commonAMDValidateLibDir() (string, error) {
// We try to favor system paths first, so that we can wire up the subprocess to use // Favor our bundled version
// the system version. Only use our bundled version if the system version doesn't work
// This gives users a more recovery options if versions have subtle problems at runtime // Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
// Prefer explicit HIP env var // Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH") hipPath := os.Getenv("HIP_PATH")
@@ -87,14 +95,5 @@ func commonAMDValidateLibDir() (string, error) {
} }
} }
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
return "", fmt.Errorf("no suitable rocm found, falling back to CPU") return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
} }

View File

@@ -84,9 +84,8 @@ func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
} }
slog.Debug("hipDriverGetVersion", "version", version) slog.Debug("hipDriverGetVersion", "version", version)
// TODO - this isn't actually right, but the docs claim hipDriverGetVersion isn't accurate anyway... driverMajor = version / 10000000
driverMajor = version / 1000 driverMinor = (version - (driverMajor * 10000000)) / 100000
driverMinor = (version - (driverMajor * 1000)) / 10
return driverMajor, driverMinor, nil return driverMajor, driverMinor, nil
} }

View File

@@ -22,8 +22,8 @@ const (
var ( var (
// Used to validate if the given ROCm lib is usable // Used to validate if the given ROCm lib is usable
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here... ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // This is not sufficient to discern v5 vs v6
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob? RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\6.1\\bin"} // TODO glob?
) )
func AMDGetGPUInfo() []RocmGPUInfo { func AMDGetGPUInfo() []RocmGPUInfo {
@@ -35,12 +35,11 @@ func AMDGetGPUInfo() []RocmGPUInfo {
} }
defer hl.Release() defer hl.Release()
// TODO - this reports incorrect version information, so omitting for now driverMajor, driverMinor, err := hl.AMDDriverVersion()
// driverMajor, driverMinor, err := hl.AMDDriverVersion() if err != nil {
// if err != nil { // For now this is benign, but we may eventually need to fail compatibility checks
// // For now this is benign, but we may eventually need to fail compatibility checks slog.Debug("error looking up amd driver version", "error", err)
// slog.Debug("error looking up amd driver version", "error", err) }
// }
// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified // Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
count := hl.HipGetDeviceCount() count := hl.HipGetDeviceCount()
@@ -132,10 +131,8 @@ func AMDGetGPUInfo() []RocmGPUInfo {
MinimumMemory: rocmMinimumMemory, MinimumMemory: rocmMinimumMemory,
Name: name, Name: name,
Compute: gfx, Compute: gfx,
DriverMajor: driverMajor,
// TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve DriverMinor: driverMinor,
// DriverMajor: driverMajor,
// DriverMinor: driverMinor,
}, },
index: i, index: i,
} }

View File

@@ -274,6 +274,28 @@ func GetGPUInfo() GpuInfoList {
gpuInfo.DriverMajor = driverMajor gpuInfo.DriverMajor = driverMajor
gpuInfo.DriverMinor = driverMinor gpuInfo.DriverMinor = driverMinor
// query the management library as well so we can record any skew between the two
// which represents overhead on the GPU we must set aside on subsequent updates
if cHandles.nvml != nil {
C.nvml_get_free(*cHandles.nvml, C.int(gpuInfo.index), &memInfo.free, &memInfo.total, &memInfo.used)
if memInfo.err != nil {
slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err))
} else {
if memInfo.free != 0 && uint64(memInfo.free) > gpuInfo.FreeMemory {
gpuInfo.OSOverhead = uint64(memInfo.free) - gpuInfo.FreeMemory
slog.Info("detected OS VRAM overhead",
"id", gpuInfo.ID,
"library", gpuInfo.Library,
"compute", gpuInfo.Compute,
"driver", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor),
"name", gpuInfo.Name,
"overhead", format.HumanBytes2(gpuInfo.OSOverhead),
)
}
}
}
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does... // TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
cudaGPUs = append(cudaGPUs, gpuInfo) cudaGPUs = append(cudaGPUs, gpuInfo)
} }
@@ -338,14 +360,17 @@ func GetGPUInfo() GpuInfoList {
"before", "before",
"total", format.HumanBytes2(cpus[0].TotalMemory), "total", format.HumanBytes2(cpus[0].TotalMemory),
"free", format.HumanBytes2(cpus[0].FreeMemory), "free", format.HumanBytes2(cpus[0].FreeMemory),
"free_swap", format.HumanBytes2(cpus[0].FreeSwap),
), ),
slog.Group( slog.Group(
"now", "now",
"total", format.HumanBytes2(mem.TotalMemory), "total", format.HumanBytes2(mem.TotalMemory),
"free", format.HumanBytes2(mem.FreeMemory), "free", format.HumanBytes2(mem.FreeMemory),
"free_swap", format.HumanBytes2(mem.FreeSwap),
), ),
) )
cpus[0].FreeMemory = mem.FreeMemory cpus[0].FreeMemory = mem.FreeMemory
cpus[0].FreeSwap = mem.FreeSwap
} }
var memInfo C.mem_info_t var memInfo C.mem_info_t
@@ -374,9 +399,14 @@ func GetGPUInfo() GpuInfoList {
slog.Warn("error looking up nvidia GPU memory") slog.Warn("error looking up nvidia GPU memory")
continue continue
} }
if cHandles.nvml != nil && gpu.OSOverhead > 0 {
// When using the management library update based on recorded overhead
memInfo.free -= C.uint64_t(gpu.OSOverhead)
}
slog.Debug("updating cuda memory data", slog.Debug("updating cuda memory data",
"gpu", gpu.ID, "gpu", gpu.ID,
"name", gpu.Name, "name", gpu.Name,
"overhead", format.HumanBytes2(gpu.OSOverhead),
slog.Group( slog.Group(
"before", "before",
"total", format.HumanBytes2(gpu.TotalMemory), "total", format.HumanBytes2(gpu.TotalMemory),

View File

@@ -57,6 +57,7 @@ func GetCPUMem() (memInfo, error) {
return memInfo{ return memInfo{
TotalMemory: uint64(C.getPhysicalMemory()), TotalMemory: uint64(C.getPhysicalMemory()),
FreeMemory: uint64(C.getFreeMemory()), FreeMemory: uint64(C.getFreeMemory()),
// FreeSwap omitted as Darwin uses dynamic paging
}, nil }, nil
} }

View File

@@ -50,7 +50,7 @@ var OneapiMgmtName = "libze_intel_gpu.so"
func GetCPUMem() (memInfo, error) { func GetCPUMem() (memInfo, error) {
var mem memInfo var mem memInfo
var total, available, free, buffers, cached uint64 var total, available, free, buffers, cached, freeSwap uint64
f, err := os.Open("/proc/meminfo") f, err := os.Open("/proc/meminfo")
if err != nil { if err != nil {
return mem, err return mem, err
@@ -70,20 +70,21 @@ func GetCPUMem() (memInfo, error) {
_, err = fmt.Sscanf(line, "Buffers:%d", &buffers) _, err = fmt.Sscanf(line, "Buffers:%d", &buffers)
case strings.HasPrefix(line, "Cached:"): case strings.HasPrefix(line, "Cached:"):
_, err = fmt.Sscanf(line, "Cached:%d", &cached) _, err = fmt.Sscanf(line, "Cached:%d", &cached)
case strings.HasPrefix(line, "SwapFree:"):
_, err = fmt.Sscanf(line, "SwapFree:%d", &freeSwap)
default: default:
continue continue
} }
if err != nil { if err != nil {
return mem, err return mem, err
} }
if total > 0 && available > 0 {
mem.TotalMemory = total * format.KibiByte
mem.FreeMemory = available * format.KibiByte
return mem, nil
}
} }
mem.TotalMemory = total * format.KibiByte mem.TotalMemory = total * format.KibiByte
mem.FreeMemory = (free + buffers + cached) * format.KibiByte mem.FreeSwap = freeSwap * format.KibiByte
if available > 0 {
mem.FreeMemory = available * format.KibiByte
} else {
mem.FreeMemory = (free + buffers + cached) * format.KibiByte
}
return mem, nil return mem, nil
} }

View File

@@ -51,5 +51,5 @@ func GetCPUMem() (memInfo, error) {
if r1 == 0 { if r1 == 0 {
return memInfo{}, fmt.Errorf("GlobalMemoryStatusEx failed: %w", err) return memInfo{}, fmt.Errorf("GlobalMemoryStatusEx failed: %w", err)
} }
return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys}, nil return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys, FreeSwap: memStatus.AvailPageFile}, nil
} }

View File

@@ -10,6 +10,7 @@ import (
type memInfo struct { type memInfo struct {
TotalMemory uint64 `json:"total_memory,omitempty"` TotalMemory uint64 `json:"total_memory,omitempty"`
FreeMemory uint64 `json:"free_memory,omitempty"` FreeMemory uint64 `json:"free_memory,omitempty"`
FreeSwap uint64 `json:"free_swap,omitempty"`
} }
// Beginning of an `ollama info` command // Beginning of an `ollama info` command
@@ -52,7 +53,8 @@ type CPUInfo struct {
type CudaGPUInfo struct { type CudaGPUInfo struct {
GpuInfo GpuInfo
index int //nolint:unused,nolintlint OSOverhead uint64 // Memory overhead between the driver library and management library
index int //nolint:unused,nolintlint
} }
type CudaGPUInfoList []CudaGPUInfo type CudaGPUInfoList []CudaGPUInfo

152
integration/embed_test.go Normal file
View File

@@ -0,0 +1,152 @@
//go:build integration
package integration
import (
"context"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func TestAllMiniLMEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
req := api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
}
res, err := embedTestHelper(ctx, t, req)
if err != nil {
t.Fatalf("error: %v", err)
}
if len(res.Embeddings) != 1 {
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
}
if len(res.Embeddings[0]) != 384 {
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
}
if res.Embeddings[0][0] != 0.010071031 {
t.Fatalf("expected 0.010071031, got %f", res.Embeddings[0][0])
}
}
func TestAllMiniLMBatchEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
req := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"why is the sky blue?", "why is the grass green?"},
}
res, err := embedTestHelper(ctx, t, req)
if err != nil {
t.Fatalf("error: %v", err)
}
if len(res.Embeddings) != 2 {
t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings))
}
if len(res.Embeddings[0]) != 384 {
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
}
if res.Embeddings[0][0] != 0.010071031 || res.Embeddings[1][0] != -0.009802706 {
t.Fatalf("expected 0.010071031 and -0.009802706, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0])
}
}
func TestAllMiniLmEmbedTruncate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
truncTrue, truncFalse := true, false
type testReq struct {
Name string
Request api.EmbedRequest
}
reqs := []testReq{
{
Name: "Target Truncation",
Request: api.EmbedRequest{
Model: "all-minilm",
Input: "why",
},
},
{
Name: "Default Truncate",
Request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Options: map[string]any{"num_ctx": 1},
},
},
{
Name: "Explicit Truncate",
Request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 1},
},
},
}
res := make(map[string]*api.EmbedResponse)
for _, req := range reqs {
response, err := embedTestHelper(ctx, t, req.Request)
if err != nil {
t.Fatalf("error: %v", err)
}
res[req.Name] = response
}
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
t.Fatal("expected default request to truncate correctly")
}
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
t.Fatal("expected default request and truncate true request to be the same")
}
// check that truncate set to false returns an error if context length is exceeded
_, err := embedTestHelper(ctx, t, api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 1},
})
if err == nil {
t.Fatal("expected error, got nil")
}
}
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err)
}
response, err := client.Embed(ctx, &req)
if err != nil {
return nil, err
}
return response, nil
}

View File

@@ -3188,26 +3188,33 @@ int main(int argc, char **argv) {
prompt = ""; prompt = "";
} }
json image_data; if (prompt.size() == 1) {
if (body.count("image_data") != 0) { prompt = prompt[0];
image_data = body["image_data"];
}
else
{
image_data = "";
} }
// create and queue the task // create and queue the task
const int task_id = llama.queue_tasks.get_new_id(); json responses;
llama.queue_results.add_waiting_task_id(task_id); {
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1); const int id_task = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(id_task);
llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
// get the result // get the result
task_result result = llama.queue_results.recv(task_id); task_result result = llama.queue_results.recv(id_task);
llama.queue_results.remove_waiting_task_id(task_id); llama.queue_results.remove_waiting_task_id(id_task);
if (result.error) {
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
}
// send the result responses = result.result_json.value("results", std::vector<json>{result.result_json});
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); json embeddings = json::array();
for (auto & elem : responses) {
embeddings.push_back(elem.at("embedding"));
}
// send the result
json embedding_res = json{{"embedding", embeddings}};
return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
}
}); });
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!? // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?

View File

@@ -178,7 +178,7 @@ if [ -z "${OLLAMA_SKIP_CUDA_GENERATE}" -a -d "${CUDA_LIB_DIR}" ]; then
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}" CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
echo "Building custom CUDA GPU" echo "Building custom CUDA GPU"
else else
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DGGML_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} -DCMAKE_LIBRARY_PATH=/usr/local/cuda/compat" CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}"
fi fi
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}" CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}"
BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}" BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
@@ -254,7 +254,7 @@ if [ -z "${OLLAMA_SKIP_ROCM_GENERATE}" -a -d "${ROCM_PATH}" ]; then
ROCM_VARIANT=_v$(ls ${ROCM_PATH}/lib/librocblas.so.*.*.????? | cut -f5 -d. || true) ROCM_VARIANT=_v$(ls ${ROCM_PATH}/lib/librocblas.so.*.*.????? | cut -f5 -d. || true)
fi fi
init_vars init_vars
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)" CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DLLAMA_CUDA_NO_PEER_COPY=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp # Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then
echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\"" echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\""

View File

@@ -6,18 +6,9 @@ function amdGPUs {
if ($env:AMDGPU_TARGETS) { if ($env:AMDGPU_TARGETS) {
return $env:AMDGPU_TARGETS return $env:AMDGPU_TARGETS
} }
# TODO - load from some common data file for linux + windows build consistency # Current supported rocblas list from ROCm v6.1.2 on windows
$GPU_LIST = @( $GPU_LIST = @(
"gfx900"
"gfx906:xnack-" "gfx906:xnack-"
"gfx908:xnack-"
"gfx90a:xnack+"
"gfx90a:xnack-"
"gfx940"
"gfx941"
"gfx942"
"gfx1010"
"gfx1012"
"gfx1030" "gfx1030"
"gfx1100" "gfx1100"
"gfx1101" "gfx1101"
@@ -366,6 +357,7 @@ function build_rocm() {
"-DCMAKE_C_COMPILER=clang.exe", "-DCMAKE_C_COMPILER=clang.exe",
"-DCMAKE_CXX_COMPILER=clang++.exe", "-DCMAKE_CXX_COMPILER=clang++.exe",
"-DGGML_HIPBLAS=on", "-DGGML_HIPBLAS=on",
"-DLLAMA_CUDA_NO_PEER_COPY=on",
"-DHIP_PLATFORM=amd", "-DHIP_PLATFORM=amd",
"-DGGML_AVX=on", "-DGGML_AVX=on",
"-DGGML_AVX2=off", "-DGGML_AVX2=off",
@@ -394,7 +386,6 @@ function build_rocm() {
sign sign
install install
# Assumes v5.7, may need adjustments for v6
rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\" rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\" cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"

View File

@@ -424,6 +424,32 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
4*batch*(3*embedding+vocab)+embedding*vocab*105/128, 4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16, 4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
) )
case "chatglm":
fullOffload = 4 * batch * (embedding + vocab)
partialOffload = 4*batch*(embedding+vocab) + embedding*vocab*105/128
if qkvBias, ok := layers["blk.0"]["attn_qkv.bias"]; ok {
fullOffload = max(
fullOffload,
4*batch*(2+
2*embedding+
context+
context*heads+
embeddingHeadsK*heads+
qkvBias.Shape[0]),
)
partialOffload = max(
partialOffload,
4*batch*(1+
2*embedding+
embeddingHeadsK*heads+
context+
context*heads)+
4*embeddingHeadsK*context+
4*context*embeddingHeadsK+
4*qkvBias.Shape[0],
)
}
} }
return return

View File

@@ -537,6 +537,7 @@ var ggufKVOrder = map[string][]string{
"tokenizer.ggml.add_bos_token", "tokenizer.ggml.add_bos_token",
"tokenizer.ggml.add_eos_token", "tokenizer.ggml.add_eos_token",
"tokenizer.chat_template", "tokenizer.chat_template",
"bert.pooling_type",
}, },
} }

View File

@@ -4,8 +4,8 @@ package llm
// #cgo LDFLAGS: -lllama -lggml -lstdc++ -lpthread // #cgo LDFLAGS: -lllama -lggml -lstdc++ -lpthread
// #cgo darwin,arm64 LDFLAGS: -L${SRCDIR}/build/darwin/arm64_static -L${SRCDIR}/build/darwin/arm64_static/src -L${SRCDIR}/build/darwin/arm64_static/ggml/src -framework Accelerate -framework Metal // #cgo darwin,arm64 LDFLAGS: -L${SRCDIR}/build/darwin/arm64_static -L${SRCDIR}/build/darwin/arm64_static/src -L${SRCDIR}/build/darwin/arm64_static/ggml/src -framework Accelerate -framework Metal
// #cgo darwin,amd64 LDFLAGS: -L${SRCDIR}/build/darwin/x86_64_static -L${SRCDIR}/build/darwin/x86_64_static/src -L${SRCDIR}/build/darwin/x86_64_static/ggml/src // #cgo darwin,amd64 LDFLAGS: -L${SRCDIR}/build/darwin/x86_64_static -L${SRCDIR}/build/darwin/x86_64_static/src -L${SRCDIR}/build/darwin/x86_64_static/ggml/src
// #cgo windows,amd64 LDFLAGS: -L${SRCDIR}/build/windows/amd64_static -L${SRCDIR}/build/windows/amd64_static/src -L${SRCDIR}/build/windows/amd64_static/ggml/src // #cgo windows,amd64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/amd64_static -L${SRCDIR}/build/windows/amd64_static/src -L${SRCDIR}/build/windows/amd64_static/ggml/src
// #cgo windows,arm64 LDFLAGS: -L${SRCDIR}/build/windows/arm64_static -L${SRCDIR}/build/windows/arm64_static/src -L${SRCDIR}/build/windows/arm64_static/ggml/src // #cgo windows,arm64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/arm64_static -L${SRCDIR}/build/windows/arm64_static/src -L${SRCDIR}/build/windows/arm64_static/ggml/src
// #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src // #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src
// #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src // #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
// #include <stdlib.h> // #include <stdlib.h>
@@ -33,7 +33,7 @@ func Quantize(infile, outfile string, ftype fileType) error {
params.ftype = ftype.Value() params.ftype = ftype.Value()
if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 { if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
return fmt.Errorf("llama_model_quantize: %d", rc) return fmt.Errorf("failed to quantize model. This model architecture may not be supported, or you may need to upgrade Ollama to the latest version")
} }
return nil return nil

View File

@@ -33,7 +33,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, prompt string) ([]float64, error) Embed(ctx context.Context, input []string) ([][]float32, error)
Tokenize(ctx context.Context, content string) ([]int, error) Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error) Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error Close() error
@@ -88,6 +88,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
var estimate MemoryEstimate var estimate MemoryEstimate
var systemTotalMemory uint64 var systemTotalMemory uint64
var systemFreeMemory uint64 var systemFreeMemory uint64
var systemSwapFreeMemory uint64
systemMemInfo, err := gpu.GetCPUMem() systemMemInfo, err := gpu.GetCPUMem()
if err != nil { if err != nil {
@@ -95,7 +96,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
} else { } else {
systemTotalMemory = systemMemInfo.TotalMemory systemTotalMemory = systemMemInfo.TotalMemory
systemFreeMemory = systemMemInfo.FreeMemory systemFreeMemory = systemMemInfo.FreeMemory
slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", systemFreeMemory) systemSwapFreeMemory = systemMemInfo.FreeSwap
slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))
} }
// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info // If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
@@ -122,6 +124,16 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
} }
} }
// On linux, over-allocating CPU memory will almost always result in an error
if runtime.GOOS == "linux" {
systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
available := systemFreeMemory + systemSwapFreeMemory
if systemMemoryRequired > available {
slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", available, "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "swap", format.HumanBytes2(systemSwapFreeMemory))
return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))
}
}
estimate.log() estimate.log()
// Loop through potential servers // Loop through potential servers
@@ -254,10 +266,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--tensor-split", estimate.TensorSplit) params = append(params, "--tensor-split", estimate.TensorSplit)
} }
if estimate.TensorSplit != "" {
params = append(params, "--tensor-split", estimate.TensorSplit)
}
for i := range len(servers) { for i := range len(servers) {
dir := availableServers[servers[i]] dir := availableServers[servers[i]]
if dir == "" { if dir == "" {
@@ -679,7 +687,7 @@ type CompletionRequest struct {
Prompt string Prompt string
Format string Format string
Images []ImageData Images []ImageData
Options api.Options Options *api.Options
} }
type CompletionResponse struct { type CompletionResponse struct {
@@ -859,15 +867,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return nil return nil
} }
type EmbeddingRequest struct { type EmbedRequest struct {
Content string `json:"content"` Content []string `json:"content"`
} }
type EmbeddingResponse struct { type EmbedResponse struct {
Embedding []float64 `json:"embedding"` Embedding [][]float32 `json:"embedding"`
} }
func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) { func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
if err := s.sem.Acquire(ctx, 1); err != nil { if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err) slog.Error("Failed to acquire semaphore", "error", err)
return nil, err return nil, err
@@ -882,7 +890,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("unexpected server status: %s", status.ToString()) return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
} }
data, err := json.Marshal(TokenizeRequest{Content: prompt}) data, err := json.Marshal(EmbedRequest{Content: input})
if err != nil { if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err) return nil, fmt.Errorf("error marshaling embed data: %w", err)
} }
@@ -909,7 +917,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("%s", body) return nil, fmt.Errorf("%s", body)
} }
var embedding EmbeddingResponse var embedding EmbedResponse
if err := json.Unmarshal(body, &embedding); err != nil { if err := json.Unmarshal(body, &embedding); err != nil {
return nil, fmt.Errorf("unmarshal tokenize response: %w", err) return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
} }

View File

@@ -3,11 +3,13 @@ package openai
import ( import (
"bytes" "bytes"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"math/rand" "math/rand"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -28,7 +30,7 @@ type ErrorResponse struct {
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content any `json:"content"`
} }
type Choice struct { type Choice struct {
@@ -59,6 +61,11 @@ type ResponseFormat struct {
Type string `json:"type"` Type string `json:"type"`
} }
type EmbedRequest struct {
Input any `json:"input"`
Model string `json:"model"`
}
type ChatCompletionRequest struct { type ChatCompletionRequest struct {
Model string `json:"model"` Model string `json:"model"`
Messages []Message `json:"messages"` Messages []Message `json:"messages"`
@@ -132,11 +139,23 @@ type Model struct {
OwnedBy string `json:"owned_by"` OwnedBy string `json:"owned_by"`
} }
type Embedding struct {
Object string `json:"object"`
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
}
type ListCompletion struct { type ListCompletion struct {
Object string `json:"object"` Object string `json:"object"`
Data []Model `json:"data"` Data []Model `json:"data"`
} }
type EmbeddingList struct {
Object string `json:"object"`
Data []Embedding `json:"data"`
Model string `json:"model"`
}
func NewError(code int, message string) ErrorResponse { func NewError(code int, message string) ErrorResponse {
var etype string var etype string
switch code { switch code {
@@ -260,6 +279,27 @@ func toListCompletion(r api.ListResponse) ListCompletion {
} }
} }
func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
if r.Embeddings != nil {
var data []Embedding
for i, e := range r.Embeddings {
data = append(data, Embedding{
Object: "embedding",
Embedding: e,
Index: i,
})
}
return EmbeddingList{
Object: "list",
Data: data,
Model: model,
}
}
return EmbeddingList{}
}
func toModel(r api.ShowResponse, m string) Model { func toModel(r api.ShowResponse, m string) Model {
return Model{ return Model{
Id: m, Id: m,
@@ -269,10 +309,66 @@ func toModel(r api.ShowResponse, m string) Model {
} }
} }
func fromChatRequest(r ChatCompletionRequest) api.ChatRequest { func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
var messages []api.Message var messages []api.Message
for _, msg := range r.Messages { for _, msg := range r.Messages {
messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content}) switch content := msg.Content.(type) {
case string:
messages = append(messages, api.Message{Role: msg.Role, Content: content})
case []any:
message := api.Message{Role: msg.Role}
for _, c := range content {
data, ok := c.(map[string]any)
if !ok {
return nil, fmt.Errorf("invalid message format")
}
switch data["type"] {
case "text":
text, ok := data["text"].(string)
if !ok {
return nil, fmt.Errorf("invalid message format")
}
message.Content = text
case "image_url":
var url string
if urlMap, ok := data["image_url"].(map[string]any); ok {
if url, ok = urlMap["url"].(string); !ok {
return nil, fmt.Errorf("invalid message format")
}
} else {
if url, ok = data["image_url"].(string); !ok {
return nil, fmt.Errorf("invalid message format")
}
}
types := []string{"jpeg", "jpg", "png"}
valid := false
for _, t := range types {
prefix := "data:image/" + t + ";base64,"
if strings.HasPrefix(url, prefix) {
url = strings.TrimPrefix(url, prefix)
valid = true
break
}
}
if !valid {
return nil, fmt.Errorf("invalid image input")
}
img, err := base64.StdEncoding.DecodeString(url)
if err != nil {
return nil, fmt.Errorf("invalid message format")
}
message.Images = append(message.Images, img)
default:
return nil, fmt.Errorf("invalid message format")
}
}
messages = append(messages, message)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
}
} }
options := make(map[string]interface{}) options := make(map[string]interface{})
@@ -323,13 +419,13 @@ func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
format = "json" format = "json"
} }
return api.ChatRequest{ return &api.ChatRequest{
Model: r.Model, Model: r.Model,
Messages: messages, Messages: messages,
Format: format, Format: format,
Options: options, Options: options,
Stream: &r.Stream, Stream: &r.Stream,
} }, nil
} }
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
@@ -338,12 +434,16 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
switch stop := r.Stop.(type) { switch stop := r.Stop.(type) {
case string: case string:
options["stop"] = []string{stop} options["stop"] = []string{stop}
case []string: case []any:
options["stop"] = stop var stops []string
default: for _, s := range stop {
if r.Stop != nil { if str, ok := s.(string); ok {
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", r.Stop) stops = append(stops, str)
} else {
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
}
} }
options["stop"] = stops
} }
if r.MaxTokens != nil { if r.MaxTokens != nil {
@@ -403,6 +503,11 @@ type RetrieveWriter struct {
model string model string
} }
type EmbedWriter struct {
BaseWriter
model string
}
func (w *BaseWriter) writeError(code int, data []byte) (int, error) { func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
var serr api.StatusError var serr api.StatusError
err := json.Unmarshal(data, &serr) err := json.Unmarshal(data, &serr)
@@ -568,6 +673,33 @@ func (w *RetrieveWriter) Write(data []byte) (int, error) {
return w.writeResponse(data) return w.writeResponse(data)
} }
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
var embedResponse api.EmbedResponse
err := json.Unmarshal(data, &embedResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *EmbedWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
}
return w.writeResponse(data)
}
func ListMiddleware() gin.HandlerFunc { func ListMiddleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
w := &ListWriter{ w := &ListWriter{
@@ -631,6 +763,47 @@ func CompletionsMiddleware() gin.HandlerFunc {
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)), id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
} }
c.Writer = w
c.Next()
}
}
func EmbeddingsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req EmbedRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Input == "" {
req.Input = []string{""}
}
if req.Input == nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
return
}
if v, ok := req.Input.([]any); ok && len(v) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &EmbedWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
model: req.Model,
}
c.Writer = w c.Writer = w
c.Next() c.Next()
@@ -652,7 +825,13 @@ func ChatMiddleware() gin.HandlerFunc {
} }
var b bytes.Buffer var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
chatReq, err := fromChatRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
}
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return return
} }

View File

@@ -2,8 +2,8 @@ package openai
import ( import (
"bytes" "bytes"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -16,7 +16,253 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestMiddleware(t *testing.T) { const prefix = `data:image/jpeg;base64,`
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
const imageURL = prefix + image
func TestMiddlewareRequests(t *testing.T) {
type testCase struct {
Name string
Method string
Path string
Handler func() gin.HandlerFunc
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *http.Request)
}
var capturedRequest *http.Request
captureRequestMiddleware := func() gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
capturedRequest = c.Request
c.Next()
}
}
testCases := []testCase{
{
Name: "chat handler",
Method: http.MethodPost,
Path: "/api/chat",
Handler: ChatMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var chatReq api.ChatRequest
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
t.Fatal(err)
}
if chatReq.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
}
if chatReq.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
}
},
},
{
Name: "completions handler",
Method: http.MethodPost,
Path: "/api/generate",
Handler: CompletionsMiddleware,
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
Stop: []string{"\n", "stop"},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var genReq api.GenerateRequest
if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
t.Fatal(err)
}
if genReq.Prompt != "Hello" {
t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
}
if genReq.Options["temperature"] != 1.6 {
t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
}
stopTokens, ok := genReq.Options["stop"].([]any)
if !ok {
t.Fatalf("expected stop tokens to be a list")
}
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
}
},
},
{
Name: "chat handler with image content",
Method: http.MethodPost,
Path: "/api/chat",
Handler: ChatMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{
{
Role: "user", Content: []map[string]any{
{"type": "text", "text": "Hello"},
{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
},
},
},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var chatReq api.ChatRequest
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
t.Fatal(err)
}
if chatReq.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
}
if chatReq.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
}
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
if !bytes.Equal(chatReq.Messages[0].Images[0], img) {
t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0])
}
},
},
{
Name: "embed handler single input",
Method: http.MethodPost,
Path: "/api/embed",
Handler: EmbeddingsMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Input: "Hello",
Model: "test-model",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var embedReq api.EmbedRequest
if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
t.Fatal(err)
}
if embedReq.Input != "Hello" {
t.Fatalf("expected 'Hello', got %s", embedReq.Input)
}
if embedReq.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", embedReq.Model)
}
},
},
{
Name: "embed handler batch input",
Method: http.MethodPost,
Path: "/api/embed",
Handler: EmbeddingsMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Input: []string{"Hello", "World"},
Model: "test-model",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var embedReq api.EmbedRequest
if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
t.Fatal(err)
}
input, ok := embedReq.Input.([]any)
if !ok {
t.Fatalf("expected input to be a list")
}
if input[0].(string) != "Hello" {
t.Fatalf("expected 'Hello', got %s", input[0])
}
if input[1].(string) != "World" {
t.Fatalf("expected 'World', got %s", input[1])
}
if embedReq.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", embedReq.Model)
}
},
},
}
gin.SetMode(gin.TestMode)
router := gin.New()
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
router = gin.New()
router.Use(captureRequestMiddleware())
router.Use(tc.Handler())
router.Handle(tc.Method, tc.Path, endpoint)
req, _ := http.NewRequest(tc.Method, tc.Path, nil)
if tc.Setup != nil {
tc.Setup(t, req)
}
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest)
})
}
}
func TestMiddlewareResponses(t *testing.T) {
type testCase struct { type testCase struct {
Name string Name string
Method string Method string
@@ -30,159 +276,7 @@ func TestMiddleware(t *testing.T) {
testCases := []testCase{ testCases := []testCase{
{ {
Name: "chat handler", Name: "completions handler error forwarding",
Method: http.MethodPost,
Path: "/api/chat",
TestPath: "/api/chat",
Handler: ChatMiddleware,
Endpoint: func(c *gin.Context) {
var chatReq api.ChatRequest
if err := c.ShouldBindJSON(&chatReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
userMessage := chatReq.Messages[0].Content
var assistantMessage string
switch userMessage {
case "Hello":
assistantMessage = "Hello!"
default:
assistantMessage = "I'm not sure how to respond to that."
}
c.JSON(http.StatusOK, api.ChatResponse{
Message: api.Message{
Role: "assistant",
Content: assistantMessage,
},
})
},
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var chatResp ChatCompletion
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
t.Fatal(err)
}
if chatResp.Object != "chat.completion" {
t.Fatalf("expected chat.completion, got %s", chatResp.Object)
}
if chatResp.Choices[0].Message.Content != "Hello!" {
t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
}
},
},
{
Name: "completions handler",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.GenerateResponse{
Response: "Hello!",
})
},
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var completionResp Completion
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
t.Fatal(err)
}
if completionResp.Object != "text_completion" {
t.Fatalf("expected text_completion, got %s", completionResp.Object)
}
if completionResp.Choices[0].Text != "Hello!" {
t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text)
}
},
},
{
Name: "completions handler with params",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
var generateReq api.GenerateRequest
if err := c.ShouldBindJSON(&generateReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
temperature := generateReq.Options["temperature"].(float64)
var assistantMessage string
switch temperature {
case 1.6:
assistantMessage = "Received temperature of 1.6"
default:
assistantMessage = fmt.Sprintf("Received temperature of %f", temperature)
}
c.JSON(http.StatusOK, api.GenerateResponse{
Response: assistantMessage,
})
},
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var completionResp Completion
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
t.Fatal(err)
}
if completionResp.Object != "text_completion" {
t.Fatalf("expected text_completion, got %s", completionResp.Object)
}
if completionResp.Choices[0].Text != "Received temperature of 1.6" {
t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text)
}
},
},
{
Name: "completions handler with error",
Method: http.MethodPost, Method: http.MethodPost,
Path: "/api/generate", Path: "/api/generate",
TestPath: "/api/generate", TestPath: "/api/generate",

View File

@@ -107,9 +107,12 @@ function gatherDependencies() {
# TODO - this varies based on host build system and MSVC version - drive from dumpbin output # TODO - this varies based on host build system and MSVC version - drive from dumpbin output
# currently works for Win11 + MSVC 2019 + Cuda V11 # currently works for Win11 + MSVC 2019 + Cuda V11
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140.dll" "${script:DEPS_DIR}\ollama_runners\" cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140*.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\ollama_runners\" cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\ollama_runners\" cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\ollama_runners\"
foreach ($part in $("runtime", "stdio", "filesystem", "math", "convert", "heap", "string", "time", "locale", "environment")) {
cp "$env:VCToolsRedistDir\..\..\..\Tools\Llvm\x64\bin\api-ms-win-crt-${part}*.dll" "${script:DEPS_DIR}\ollama_runners\"
}
cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\" cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"

View File

@@ -34,9 +34,20 @@ import (
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
var (
errCapabilities = errors.New("does not support")
errCapabilityCompletion = errors.New("completion")
errCapabilityTools = errors.New("tools")
errCapabilityInsert = errors.New("insert")
)
type Capability string type Capability string
const CapabilityCompletion = Capability("completion") const (
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
)
type registryOptions struct { type registryOptions struct {
Insecure bool Insecure bool
@@ -62,7 +73,10 @@ type Model struct {
Template *template.Template Template *template.Template
} }
func (m *Model) Has(caps ...Capability) bool { // CheckCapabilities checks if the model has the specified capabilities returning an error describing
// any missing or unknown capabilities
func (m *Model) CheckCapabilities(caps ...Capability) error {
var errs []error
for _, cap := range caps { for _, cap := range caps {
switch cap { switch cap {
case CapabilityCompletion: case CapabilityCompletion:
@@ -81,15 +95,28 @@ func (m *Model) Has(caps ...Capability) bool {
} }
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok { if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
return false errs = append(errs, errCapabilityCompletion)
}
case CapabilityTools:
if !slices.Contains(m.Template.Vars(), "tools") {
errs = append(errs, errCapabilityTools)
}
case CapabilityInsert:
vars := m.Template.Vars()
if !slices.Contains(vars, "suffix") {
errs = append(errs, errCapabilityInsert)
} }
default: default:
slog.Error("unknown capability", "capability", cap) slog.Error("unknown capability", "capability", cap)
return false return fmt.Errorf("unknown capability: %s", cap)
} }
} }
return true if err := errors.Join(errs...); err != nil {
return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
}
return nil
} }
func (m *Model) String() string { func (m *Model) String() string {

View File

@@ -4,6 +4,7 @@ import (
"archive/zip" "archive/zip"
"bytes" "bytes"
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -11,6 +12,9 @@ import (
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"slices"
"strings"
"text/template/parse"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert" "github.com/ollama/ollama/convert"
@@ -289,3 +293,87 @@ func detectContentType(r io.Reader) (string, error) {
return "unknown", nil return "unknown", nil
} }
// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// mxyng: this only really works if the input contains tool calls in some JSON format
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
// create a subtree from the node that ranges over .ToolCalls
tmpl := m.Template.Subtree(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
}
return false
})
if tmpl == nil {
return nil, false
}
var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]map[string]any{
"ToolCalls": {
{
"Function": map[string]any{
"Name": "@@name@@",
"Arguments": "@@arguments@@",
},
},
},
}); err != nil {
return nil, false
}
var kv map[string]string
// execute the subtree with placeholders to identify the keys
// trim any commands that might exist in the template
if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
return nil, false
}
// find the keys that correspond to the name and arguments fields
var name, arguments string
for k, v := range kv {
switch v {
case "@@name@@":
name = k
case "@@arguments@@":
arguments = k
}
}
var objs []map[string]any
for offset := 0; offset < len(s); {
if err := json.NewDecoder(strings.NewReader(s[offset:])).Decode(&objs); errors.Is(err, io.EOF) {
break
} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
// skip over any syntax errors
offset += int(syntax.Offset)
} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
// skip over any unmarshalable types
offset += int(unmarshalType.Offset)
} else if err != nil {
return nil, false
} else {
// break when an object is decoded
break
}
}
var toolCalls []api.ToolCall
for _, kv := range objs {
var call api.ToolCall
for k, v := range kv {
switch k {
case name:
call.Function.Name = v.(string)
case arguments:
call.Function.Arguments = v.(map[string]any)
}
}
toolCalls = append(toolCalls, call)
}
return toolCalls, len(toolCalls) > 0
}

View File

@@ -3,7 +3,9 @@ package server
import ( import (
"archive/zip" "archive/zip"
"bytes" "bytes"
"encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
@@ -11,7 +13,9 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
) )
func createZipFile(t *testing.T, name string) *os.File { func createZipFile(t *testing.T, name string) *os.File {
@@ -110,3 +114,123 @@ func TestExtractFromZipFile(t *testing.T) {
}) })
} }
} }
type function struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
}
func readFile(t *testing.T, base, name string) *bytes.Buffer {
t.Helper()
bts, err := os.ReadFile(filepath.Join(base, name))
if err != nil {
t.Fatal(err)
}
return bytes.NewBuffer(bts)
}
func TestExecuteWithTools(t *testing.T) {
p := filepath.Join("testdata", "tools")
cases := []struct {
model string
output string
ok bool
}{
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"command-r-plus", "Action: ```json" + `
[
{
"tool_name": "get_current_weather",
"parameters": {
"format": "fahrenheit",
"location": "San Francisco, CA"
}
},
{
"tool_name": "get_current_weather",
"parameters": {
"format": "celsius",
"location": "Toronto, Canada"
}
}
]
` + "```", true},
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
}
var tools []api.Tool
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
t.Fatal(err)
}
var messages []api.Message
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
t.Fatal(err)
}
calls := []api.ToolCall{
{
Function: function{
Name: "get_current_weather",
Arguments: map[string]any{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
},
},
{
Function: function{
Name: "get_current_weather",
Arguments: map[string]any{
"format": "celsius",
"location": "Toronto, Canada",
},
},
},
}
for _, tt := range cases {
t.Run(tt.model, func(t *testing.T) {
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
if err != nil {
t.Fatal(err)
}
t.Run("template", func(t *testing.T) {
var actual bytes.Buffer
if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("parse", func(t *testing.T) {
m := &Model{Template: tmpl}
actual, ok := m.parseToolCalls(tt.output)
if ok != tt.ok {
t.Fatalf("expected %t, got %t", tt.ok, ok)
}
if tt.ok {
if diff := cmp.Diff(actual, calls); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
})
})
}
}

View File

@@ -1,217 +1,74 @@
package server package server
import ( import (
"fmt" "bytes"
"context"
"log/slog" "log/slog"
"strings"
"text/template/parse"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
) )
// isResponseNode checks if the node contains .Response type tokenizeFunc func(context.Context, string) ([]int, error)
func isResponseNode(node *parse.ActionNode) bool {
for _, cmd := range node.Pipe.Cmds { // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
for _, arg := range cmd.Args { // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
if fieldNode, ok := arg.(*parse.FieldNode); ok && len(fieldNode.Ident) > 0 { // latest message and 2) system messages
if fieldNode.Ident[0] == "Response" { func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
return true var system []api.Message
} // always include the last message
n := len(msgs) - 1
// in reverse, find all messages that fit into context window
for i := n - 1; i >= 0; i-- {
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
system = append(system, msgs[j])
} }
} }
}
return false
}
// formatTemplateForResponse formats the template AST to: var b bytes.Buffer
// 1. remove all nodes after the first .Response (if generate=true) if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
// 2. add a .Response node to the end if it doesn't exist return "", nil, err
// TODO(jmorganca): this should recursively cut the template before the first .Response
func formatTemplateForResponse(tmpl *template.Template, generate bool) {
var found bool
for i, node := range tmpl.Tree.Root.Nodes {
if actionNode, ok := node.(*parse.ActionNode); ok {
if isResponseNode(actionNode) {
found = true
if generate {
tmpl.Tree.Root.Nodes = tmpl.Tree.Root.Nodes[:i+1]
break
}
}
} }
}
if !found { s, err := tokenize(ctx, b.String())
// add the response node if it doesn't exist
responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}}
responsePipeNode := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}}}}
responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode}
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode)
}
}
// Prompt renders a prompt from a template. If generate is set to true,
// the response and parts of the template following it are not rendered
func Prompt(tmpl *template.Template, system, prompt, response string, generate bool) (string, error) {
formatTemplateForResponse(tmpl, generate)
vars := map[string]any{
"System": system,
"Prompt": prompt,
"Response": response,
}
var sb strings.Builder
if err := tmpl.Execute(&sb, vars); err != nil {
return "", err
}
return sb.String(), nil
}
func countTokens(tmpl *template.Template, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
rendered, err := Prompt(tmpl, system, prompt, response, false)
if err != nil {
return 0, err
}
tokens, err := encode(rendered)
if err != nil {
slog.Error("failed to encode prompt", "err", err)
return 0, err
}
return len(tokens), err
}
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
func ChatPrompt(tmpl *template.Template, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
type prompt struct {
System string
Prompt string
Response string
images []int
tokens int
}
var p prompt
// iterate through messages to build up {system,user,response} prompts
var imgId int
var prompts []prompt
for _, msg := range messages {
switch strings.ToLower(msg.Role) {
case "system":
if p.System != "" || p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
p.System = msg.Content
case "user":
if p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
var sb strings.Builder
for range msg.Images {
fmt.Fprintf(&sb, "[img-%d] ", imgId)
p.images = append(p.images, imgId)
imgId += 1
}
sb.WriteString(msg.Content)
p.Prompt = sb.String()
case "assistant":
if p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
p.Response = msg.Content
default:
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
}
}
// add final prompt
if p.System != "" || p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
}
// calculate token lengths for each prompt, estimating 768 tokens per images
for i, p := range prompts {
tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
if err != nil { if err != nil {
return "", err return "", nil, err
} }
prompts[i].tokens = tokens + len(prompts[i].images)*768 c := len(s)
} if m.ProjectorPaths != nil {
for _, m := range msgs[i:] {
// truncate images and prompts starting from the beginning of the list // images are represented as 768 sized embeddings
// until either one prompt remains or the total tokens fits the context window // TODO: get embedding length from project metadata
// TODO (jmorganca): this doesn't account for the context window room required for the response c += 768 * len(m.Images)
for { }
var required int
for _, p := range prompts {
required += p.tokens
} }
required += 1 // for bos token if c > opts.NumCtx {
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
if required <= window {
slog.Debug("prompt now fits in context window", "required", required, "window", window)
break break
} else {
n = i
} }
prompt := &prompts[0]
if len(prompt.images) > 1 {
img := prompt.images[0]
slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window)
prompt.images = prompt.images[1:]
prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1)
prompt.tokens -= 768
continue
}
if len(prompts) > 1 {
slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window)
system := prompt.System
prompts = prompts[1:]
if system != "" && prompts[0].System == "" {
prompts[0].System = system
tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
if err != nil {
return "", err
}
prompts[0].tokens = tokens + len(prompts[0].images)*768
}
continue
}
// stop truncating if there's only one prompt left
break
} }
var sb strings.Builder // truncate any messages that do not fit into the context window
for i, p := range prompts { var b bytes.Buffer
// last prompt should leave the response unrendered (for completion) if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1) return "", nil, err
if err != nil {
return "", err
}
sb.WriteString(rendered)
} }
return sb.String(), nil for _, m := range msgs[n:] {
for _, i := range m.Images {
images = append(images, llm.ImageData{
ID: len(images),
Data: i,
})
}
}
return b.String(), images, nil
} }

View File

@@ -1,215 +1,209 @@
package server package server
import ( import (
"strings" "bytes"
"context"
"testing" "testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
) )
func TestPrompt(t *testing.T) {
tests := []struct {
name string
template string
system string
prompt string
response string
generate bool
want string
}{
{
name: "simple prompt",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
},
{
name: "implicit response",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
},
{
name: "response",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
},
{
name: "cut",
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
generate: true,
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.",
},
{
name: "nocut",
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.</assistant>",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.Parse(tc.template)
if err != nil {
t.Fatal(err)
}
got, err := Prompt(tmpl, tc.system, tc.prompt, tc.response, tc.generate)
if err != nil {
t.Errorf("error = %v", err)
}
if got != tc.want {
t.Errorf("got = %v, want %v", got, tc.want)
}
})
}
}
func TestChatPrompt(t *testing.T) { func TestChatPrompt(t *testing.T) {
tests := []struct { type expect struct {
name string prompt string
template string images [][]byte
messages []api.Message }
window int
want string cases := []struct {
name string
limit int
msgs []api.Message
expect
}{ }{
{ {
name: "simple prompt", name: "messages",
template: "[INST] {{ .Prompt }} [/INST]", limit: 64,
messages: []api.Message{ msgs: []api.Message{
{Role: "user", Content: "Hello"}, {Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
}, },
window: 1024,
want: "[INST] Hello [/INST]",
}, },
{ {
name: "with system message", name: "truncate messages",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]", limit: 1,
messages: []api.Message{ msgs: []api.Message{
{Role: "system", Content: "You are a Wizard."}, {Role: "user", Content: "You're a test, Harry!"},
{Role: "user", Content: "Hello"}, {Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "A test. And a thumping good one at that, I'd wager. ",
}, },
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
}, },
{ {
name: "with response", name: "truncate messages with image",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}", limit: 64,
messages: []api.Message{ msgs: []api.Message{
{Role: "system", Content: "You are a Wizard."}, {Role: "user", Content: "You're a test, Harry!"},
{Role: "user", Content: "Hello"}, {Role: "assistant", Content: "I-I'm a what?"},
{Role: "assistant", Content: "I am?"}, {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
},
expect: expect{
prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
},
}, },
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST] I am?",
}, },
{ {
name: "with implicit response", name: "truncate messages with images",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]", limit: 64,
messages: []api.Message{ msgs: []api.Message{
{Role: "system", Content: "You are a Wizard."}, {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
{Role: "user", Content: "Hello"}, {Role: "assistant", Content: "I-I'm a what?"},
{Role: "assistant", Content: "I am?"}, {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
},
expect: expect{
prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("somethingelse"),
},
}, },
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]I am?",
}, },
{ {
name: "with conversation", name: "messages with images",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ", limit: 2048,
messages: []api.Message{ msgs: []api.Message{
{Role: "system", Content: "You are a Wizard."}, {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
{Role: "user", Content: "What are the potion ingredients?"}, {Role: "assistant", Content: "I-I'm a what?"},
{Role: "assistant", Content: "sugar"}, {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
{Role: "user", Content: "Anything else?"}, },
expect: expect{
prompt: "[img-0] You're a test, Harry! I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
[]byte("somethingelse"),
},
}, },
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
}, },
{ {
name: "with truncation", name: "message with image tag",
template: "{{ .System }} {{ .Prompt }} {{ .Response }} ", limit: 2048,
messages: []api.Message{ msgs: []api.Message{
{Role: "system", Content: "You are a Wizard."}, {Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
{Role: "user", Content: "Hello"}, {Role: "assistant", Content: "I-I'm a what?"},
{Role: "assistant", Content: "I am?"}, {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
{Role: "user", Content: "Why is the sky blue?"}, },
{Role: "assistant", Content: "The sky is blue from rayleigh scattering"}, expect: expect{
prompt: "You're a test, Harry! [img-0] I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
[]byte("somethingelse"),
},
}, },
window: 10,
want: "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
}, },
{ {
name: "images", name: "messages with interleaved images",
template: "{{ .System }} {{ .Prompt }}", limit: 2048,
messages: []api.Message{ msgs: []api.Message{
{Role: "system", Content: "You are a Wizard."}, {Role: "user", Content: "You're a test, Harry!"},
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}}, {Role: "user", Images: []api.ImageData{[]byte("something")}},
{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry!\n\n[img-0]\n\n[img-1] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
[]byte("somethingelse"),
},
}, },
window: 1024,
want: "You are a Wizard. [img-0] Hello",
}, },
{ {
name: "images truncated", name: "truncate message with interleaved images",
template: "{{ .System }} {{ .Prompt }}", limit: 1024,
messages: []api.Message{ msgs: []api.Message{
{Role: "system", Content: "You are a Wizard."}, {Role: "user", Content: "You're a test, Harry!"},
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}}, {Role: "user", Images: []api.ImageData{[]byte("something")}},
{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("somethingelse"),
},
}, },
window: 1024,
want: "You are a Wizard. [img-0] [img-1] Hello",
}, },
{ {
name: "empty list", name: "message with system prompt",
template: "{{ .System }} {{ .Prompt }}", limit: 2048,
messages: []api.Message{}, msgs: []api.Message{
window: 1024, {Role: "system", Content: "You are the Test Who Lived."},
want: "", {Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
},
}, },
{ {
name: "empty prompt", name: "out of order system",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ", limit: 2048,
messages: []api.Message{ msgs: []api.Message{
{Role: "user", Content: ""}, {Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "system", Content: "You are the Test Who Lived."},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
}, },
window: 1024,
want: "",
}, },
} }
encode := func(s string) ([]int, error) { tmpl, err := template.Parse(`
words := strings.Fields(s) {{- if .System }}{{ .System }} {{ end }}
return make([]int, len(words)), nil {{- if .Prompt }}{{ .Prompt }} {{ end }}
{{- if .Response }}{{ .Response }} {{ end }}`)
if err != nil {
t.Fatal(err)
} }
for _, tc := range tests { for _, tt := range cases {
t.Run(tc.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tmpl, err := template.Parse(tc.template) model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
got, err := ChatPrompt(tmpl, tc.messages, tc.window, encode) if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
if err != nil { t.Errorf("mismatch (-got +want):\n%s", diff)
t.Errorf("error = %v", err)
} }
if got != tc.want { if len(images) != len(tt.images) {
t.Errorf("got: %q, want: %q", got, tc.want) t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
}
for i := range images {
if images[i].ID != i {
t.Errorf("expected ID %d, got %d", i, images[i].ID)
}
if !bytes.Equal(images[i].Data, tt.images[i]) {
t.Errorf("expected %q, got %q", tt.images[i], images[i])
}
} }
}) })
} }

View File

@@ -1,14 +1,15 @@
package server package server
import ( import (
"bytes"
"cmp" "cmp"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"log/slog" "log/slog"
"math"
"net" "net"
"net/http" "net/http"
"net/netip" "net/netip"
@@ -54,6 +55,8 @@ func init() {
gin.SetMode(mode) gin.SetMode(mode)
} }
var errRequired = errors.New("is required")
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) { func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions() opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil { if err := opts.FromMap(model.Options); err != nil {
@@ -67,250 +70,225 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
return opts, nil return opts, nil
} }
func isSupportedImageType(image []byte) bool { // scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
contentType := http.DetectContentType(image) // It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"} func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
return slices.Contains(allowedTypes, contentType) if name == "" {
return nil, nil, nil, fmt.Errorf("model %w", errRequired)
}
model, err := GetModel(name)
if err != nil {
return nil, nil, nil, err
}
if err := model.CheckCapabilities(caps...); err != nil {
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
}
opts, err := modelOptions(model, requestOpts)
if err != nil {
return nil, nil, nil, err
}
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err = <-errCh:
return nil, nil, nil, err
}
return runner.llama, model, &opts, nil
} }
func (s *Server) GenerateHandler(c *gin.Context) { func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now() checkpointStart := time.Now()
var req api.GenerateRequest var req api.GenerateRequest
err := c.ShouldBindJSON(&req) if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return return
case err != nil: } else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
// validate the request if req.Format != "" && req.Format != "json" {
switch { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return return
case len(req.Format) > 0 && req.Format != "json": } else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
return return
} }
for _, img := range req.Images { caps := []Capability{CapabilityCompletion}
if !isSupportedImageType(img) { if req.Suffix != "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"}) caps = append(caps, CapabilityInsert)
return
}
} }
model, err := GetModel(req.Model) r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if err != nil { if errors.Is(err, errCapabilityCompletion) {
var pErr *fs.PathError c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
if errors.As(err, &pErr) { return
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) } else if err != nil {
return handleScheduleError(c, req.Model, err)
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if !model.Has(CapabilityCompletion) { checkpointLoaded := time.Now()
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)})
return
}
opts, err := modelOptions(model, req.Options) if req.Prompt == "" {
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
handleErrorResponse(c, err)
return
}
// an empty request loads the model
// note: for a short while template was used in lieu
// of `raw` mode so we need to check for it too
if req.Prompt == "" && req.Template == "" && req.System == "" {
c.JSON(http.StatusOK, api.GenerateResponse{ c.JSON(http.StatusOK, api.GenerateResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: true, Done: true,
DoneReason: "load", DoneReason: "load",
}) })
return return
} }
tmpl, err := template.Parse(req.Template) images := make([]llm.ImageData, len(req.Images))
if err != nil { for i := range req.Images {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
return
} }
checkpointLoaded := time.Now() prompt := req.Prompt
if !req.Raw {
var prompt string tmpl := m.Template
switch { if req.Template != "" {
case req.Raw: tmpl, err = template.Parse(req.Template)
prompt = req.Prompt if err != nil {
case req.Prompt != "": c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
if req.Template == "" { return
tmpl = model.Template }
} }
if req.System == "" { var b bytes.Buffer
req.System = model.System
}
slog.Debug("generate handler", "prompt", req.Prompt)
slog.Debug("generate handler", "template", req.Template)
slog.Debug("generate handler", "system", req.System)
var sb strings.Builder
for i := range req.Images {
fmt.Fprintf(&sb, "[img-%d] ", i)
}
sb.WriteString(req.Prompt)
p, err := Prompt(tmpl, req.System, sb.String(), "", true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sb.Reset()
if req.Context != nil { if req.Context != nil {
prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context) s, err := r.Detokenize(c.Request.Context(), req.Context)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
sb.WriteString(prev) b.WriteString(s)
} }
sb.WriteString(p) var values template.Values
if req.Suffix != "" {
prompt = sb.String() values.Prompt = prompt
} values.Suffix = req.Suffix
} else {
slog.Debug("generate handler", "prompt", prompt) var msgs []api.Message
if req.System != "" {
ch := make(chan any) msgs = append(msgs, api.Message{Role: "system", Content: req.System})
var generated strings.Builder } else if m.System != "" {
go func() { msgs = append(msgs, api.Message{Role: "system", Content: m.System})
defer close(ch)
fn := func(r llm.CompletionResponse) {
// Build up the full response
if _, err := generated.WriteString(r.Content); err != nil {
ch <- gin.H{"error": err.Error()}
return
} }
resp := api.GenerateResponse{ for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
}
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
}
if err := tmpl.Execute(&b, values); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
prompt = b.String()
}
slog.Debug("generate request", "prompt", prompt, "images", images)
ch := make(chan any)
go func() {
// TODO (jmorganca): avoid building the response twice both here and below
var sb strings.Builder
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
}, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Done: r.Done, Response: cr.Content,
Response: r.Content, Done: cr.Done,
DoneReason: r.DoneReason, DoneReason: cr.DoneReason,
Metrics: api.Metrics{ Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount, PromptEvalCount: cr.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration, PromptEvalDuration: cr.PromptEvalDuration,
EvalCount: r.EvalCount, EvalCount: cr.EvalCount,
EvalDuration: r.EvalDuration, EvalDuration: cr.EvalDuration,
}, },
} }
if r.Done { if _, err := sb.WriteString(cr.Content); err != nil {
resp.TotalDuration = time.Since(checkpointStart) ch <- gin.H{"error": err.Error()}
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart) }
if cr.Done {
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw { if !req.Raw {
p, err := Prompt(tmpl, req.System, req.Prompt, generated.String(), false) tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// TODO (jmorganca): encode() should not strip special tokens
tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
if err != nil { if err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
return return
} }
res.Context = append(req.Context, tokens...)
resp.Context = append(req.Context, tokens...)
} }
} }
ch <- resp ch <- res
} }); err != nil {
var images []llm.ImageData
for i := range req.Images {
images = append(images, llm.ImageData{
ID: i,
Data: req.Images[i],
})
}
// Start prediction
req := llm.CompletionRequest{
Prompt: prompt,
Format: req.Format,
Images: images,
Options: opts,
}
if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
if req.Stream != nil && !*req.Stream { if req.Stream != nil && !*req.Stream {
// Accumulate responses into the final response var r api.GenerateResponse
var final api.GenerateResponse
var sb strings.Builder var sb strings.Builder
for resp := range ch { for rr := range ch {
switch r := resp.(type) { switch t := rr.(type) {
case api.GenerateResponse: case api.GenerateResponse:
sb.WriteString(r.Response) sb.WriteString(t.Response)
final = r r = t
case gin.H: case gin.H:
if errorMsg, ok := r["error"].(string); ok { msg, ok := t["error"].(string)
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg}) if !ok {
return msg = "unexpected error format in response"
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
return
} }
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
return
default: default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
return return
} }
} }
final.Response = sb.String() r.Response = sb.String()
c.JSON(http.StatusOK, final) if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
r.ToolCalls = toolCalls
r.Response = ""
}
c.JSON(http.StatusOK, r)
return return
} }
streamResponse(c, ch) streamResponse(c, ch)
} }
func (s *Server) EmbeddingsHandler(c *gin.Context) { func (s *Server) EmbedHandler(c *gin.Context) {
var req api.EmbeddingRequest var req api.EmbedRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
switch { switch {
case errors.Is(err, io.EOF): case errors.Is(err, io.EOF):
@@ -321,34 +299,122 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return return
} }
if req.Model == "" { truncate := true
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
if req.Truncate != nil && !*req.Truncate {
truncate = false
}
var input []string
switch i := req.Input.(type) {
case string:
if len(i) > 0 {
input = append(input, i)
}
case []any:
for _, v := range i {
if _, ok := v.(string); !ok {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
input = append(input, v.(string))
}
default:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return return
} }
model, err := GetModel(req.Model) if len(input) == 0 {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil { if err != nil {
var pErr *fs.PathError handleScheduleError(c, req.Model, err)
if errors.As(err, &pErr) { return
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) }
kvData, err := getKVData(m.ModelPath, false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
for i, s := range input {
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
opts, err := modelOptions(model, req.Options) ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen {
if !truncate {
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
return
}
tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
input[i] = s
}
embeddings, err := r.Embed(c.Request.Context(), input)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) slog.Error("embedding generation failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return return
} }
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive) for i, e := range embeddings {
var runner *runnerRef embeddings[i] = normalize(e)
select { }
case runner = <-rCh:
case err = <-eCh: resp := api.EmbedResponse{
handleErrorResponse(c, err) Model: req.Model,
Embeddings: embeddings,
}
c.JSON(http.StatusOK, resp)
}
func normalize(vec []float32) []float32 {
var sum float32
for _, v := range vec {
sum += v * v
}
norm := float32(0.0)
if sum > 0 {
norm = float32(1.0 / math.Sqrt(float64(sum)))
}
for i := range vec {
vec[i] *= norm
}
return vec
}
func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return return
} }
@@ -358,13 +424,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return return
} }
embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt) embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return return
} }
embedding := make([]float64, len(embeddings[0]))
for i, v := range embeddings[0] {
embedding[i] = float64(v)
}
resp := api.EmbeddingResponse{ resp := api.EmbeddingResponse{
Embedding: embedding, Embedding: embedding,
} }
@@ -642,16 +715,9 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
m.System = req.System m.System = req.System
} }
if req.Template != "" { msgs := make([]api.Message, len(m.Messages))
m.Template, err = template.Parse(req.Template) for i, msg := range m.Messages {
if err != nil { msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
return nil, err
}
}
msgs := make([]api.Message, 0)
for _, msg := range m.Messages {
msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
} }
n := model.ParseName(req.Model) n := model.ParseName(req.Model)
@@ -994,6 +1060,7 @@ func (s *Server) GenerateRoutes() http.Handler {
r.POST("/api/pull", s.PullModelHandler) r.POST("/api/pull", s.PullModelHandler)
r.POST("/api/generate", s.GenerateHandler) r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler) r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler) r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/create", s.CreateModelHandler) r.POST("/api/create", s.CreateModelHandler)
r.POST("/api/push", s.PushModelHandler) r.POST("/api/push", s.PushModelHandler)
@@ -1007,6 +1074,7 @@ func (s *Server) GenerateRoutes() http.Handler {
// Compatibility endpoints // Compatibility endpoints
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler) r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler) r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler) r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler) r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
@@ -1214,132 +1282,67 @@ func (s *Server) ProcessHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.ProcessResponse{Models: models}) c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
} }
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
func chatPrompt(ctx context.Context, runner *runnerRef, template *template.Template, messages []api.Message, numCtx int) (string, error) {
encode := func(s string) ([]int, error) {
return runner.llama.Tokenize(ctx, s)
}
prompt, err := ChatPrompt(template, messages, numCtx, encode)
if err != nil {
return "", err
}
return prompt, nil
}
func (s *Server) ChatHandler(c *gin.Context) { func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now() checkpointStart := time.Now()
var req api.ChatRequest var req api.ChatRequest
err := c.ShouldBindJSON(&req) if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return return
case err != nil: } else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
// validate the request caps := []Capability{CapabilityCompletion}
switch { if req.Tools != nil {
case req.Model == "": caps = append(caps, CapabilityTools)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
} }
model, err := GetModel(req.Model) r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if err != nil { if errors.Is(err, errCapabilityCompletion) {
var pErr *fs.PathError c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} } else if err != nil {
handleScheduleError(c, req.Model, err)
if !model.Has(CapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)})
return
}
opts, err := modelOptions(model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
handleErrorResponse(c, err)
return return
} }
checkpointLoaded := time.Now() checkpointLoaded := time.Now()
// if the first message is not a system message, then add the model's default system message if len(req.Messages) == 0 {
if len(req.Messages) > 0 && req.Messages[0].Role != "system" { c.JSON(http.StatusOK, api.ChatResponse{
req.Messages = append([]api.Message{
{
Role: "system",
Content: model.System,
},
}, req.Messages...)
}
prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// an empty request loads the model
if len(req.Messages) == 0 || prompt == "" {
resp := api.ChatResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant"},
Done: true, Done: true,
DoneReason: "load", DoneReason: "load",
Message: api.Message{Role: "assistant"}, })
}
c.JSON(http.StatusOK, resp)
return return
} }
// only send images that are in the prompt if req.Messages[0].Role != "system" && m.System != "" {
var i int req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
var images []llm.ImageData
for _, m := range req.Messages {
for _, img := range m.Images {
if !isSupportedImageType(img) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
return
}
if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
images = append(images, llm.ImageData{Data: img, ID: i})
}
i += 1
}
} }
slog.Debug("chat handler", "prompt", prompt, "images", len(images)) prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
slog.Debug("chat request", "images", len(images), "prompt", prompt)
ch := make(chan any) ch := make(chan any)
go func() { go func() {
defer close(ch) defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
fn := func(r llm.CompletionResponse) { Prompt: prompt,
resp := api.ChatResponse{ Images: images,
Format: req.Format,
Options: opts,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content}, Message: api.Message{Role: "assistant", Content: r.Content},
@@ -1354,62 +1357,62 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
if r.Done { if r.Done {
resp.TotalDuration = time.Since(checkpointStart) res.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
} }
ch <- resp ch <- res
} }); err != nil {
if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Format: req.Format,
Images: images,
Options: opts,
}, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
if req.Stream != nil && !*req.Stream { if req.Stream != nil && !*req.Stream {
// Accumulate responses into the final response var resp api.ChatResponse
var final api.ChatResponse
var sb strings.Builder var sb strings.Builder
for resp := range ch { for rr := range ch {
switch r := resp.(type) { switch t := rr.(type) {
case api.ChatResponse: case api.ChatResponse:
sb.WriteString(r.Message.Content) sb.WriteString(t.Message.Content)
final = r resp = t
case gin.H: case gin.H:
if errorMsg, ok := r["error"].(string); ok { msg, ok := t["error"].(string)
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg}) if !ok {
return msg = "unexpected error format in response"
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
return
} }
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
return
default: default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
return return
} }
} }
final.Message = api.Message{Role: "assistant", Content: sb.String()} resp.Message.Content = sb.String()
c.JSON(http.StatusOK, final) if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
resp.Message.ToolCalls = toolCalls
resp.Message.Content = ""
}
c.JSON(http.StatusOK, resp)
return return
} }
streamResponse(c, ch) streamResponse(c, ch)
} }
func handleErrorResponse(c *gin.Context, err error) { func handleScheduleError(c *gin.Context, name string, err error) {
if errors.Is(err, context.Canceled) { switch {
case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
case errors.Is(err, context.Canceled):
c.JSON(499, gin.H{"error": "request canceled"}) c.JSON(499, gin.H{"error": "request canceled"})
return case errors.Is(err, ErrMaxQueue):
}
if errors.Is(err, ErrMaxQueue) {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()}) c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
return case errors.Is(err, os.ErrNotExist):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
} }
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
} }

View File

@@ -85,6 +85,8 @@ func checkFileExists(t *testing.T, p string, expect []string) {
} }
func TestCreateFromBin(t *testing.T) { func TestCreateFromBin(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@@ -111,6 +113,8 @@ func TestCreateFromBin(t *testing.T) {
} }
func TestCreateFromModel(t *testing.T) { func TestCreateFromModel(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@@ -152,6 +156,8 @@ func TestCreateFromModel(t *testing.T) {
} }
func TestCreateRemovesLayers(t *testing.T) { func TestCreateRemovesLayers(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@@ -199,6 +205,8 @@ func TestCreateRemovesLayers(t *testing.T) {
} }
func TestCreateUnsetsSystem(t *testing.T) { func TestCreateUnsetsSystem(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@@ -255,6 +263,8 @@ func TestCreateUnsetsSystem(t *testing.T) {
} }
func TestCreateMergeParameters(t *testing.T) { func TestCreateMergeParameters(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@@ -358,6 +368,8 @@ func TestCreateMergeParameters(t *testing.T) {
} }
func TestCreateReplacesMessages(t *testing.T) { func TestCreateReplacesMessages(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@@ -434,6 +446,8 @@ func TestCreateReplacesMessages(t *testing.T) {
} }
func TestCreateTemplateSystem(t *testing.T) { func TestCreateTemplateSystem(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@@ -480,6 +494,8 @@ func TestCreateTemplateSystem(t *testing.T) {
} }
func TestCreateLicenses(t *testing.T) { func TestCreateLicenses(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@@ -526,6 +542,8 @@ func TestCreateLicenses(t *testing.T) {
} }
func TestCreateDetectTemplate(t *testing.T) { func TestCreateDetectTemplate(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@@ -545,9 +563,9 @@ func TestCreateDetectTemplate(t *testing.T) {
} }
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-2f8e594e6f34b1b4d36a246628eeb3365ce442303d656f1fcc69e821722acea0"),
filepath.Join(p, "blobs", "sha256-542b217f179c7825eeb5bca3c77d2b75ed05bafbd3451d9188891a60a85337c6"),
filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"), filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
filepath.Join(p, "blobs", "sha256-c608dc615584cd20d9d830363dabf8a4783ae5d34245c3d8c115edb3bc7b28e4"),
filepath.Join(p, "blobs", "sha256-f836ee110db21567f826332e4cedd746c06d10664fd5a9ea3659e3683a944510"),
}) })
}) })

View File

@@ -8,12 +8,15 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
func TestDelete(t *testing.T) { func TestDelete(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@@ -77,6 +80,8 @@ func TestDelete(t *testing.T) {
} }
func TestDeleteDuplicateLayers(t *testing.T) { func TestDeleteDuplicateLayers(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
var s Server var s Server

View File

@@ -0,0 +1,712 @@
package server
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
)
type mockRunner struct {
llm.LlamaServer
// CompletionRequest is only valid until the next call to Completion
llm.CompletionRequest
llm.CompletionResponse
}
func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
m.CompletionRequest = r
fn(m.CompletionResponse)
return nil
}
func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
for range strings.Fields(s) {
tokens = append(tokens, len(tokens))
}
return
}
func newMockServer(mock *mockRunner) func(gpu.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
return func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
return mock, nil
}
}
func TestGenerateChat(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: "stop",
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: gpu.GetGPUInfo,
getCpuFn: gpu.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
// add 10ms delay to simulate loading
time.Sleep(10 * time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
},
},
}
go s.sched.Run(context.TODO())
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test",
Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }}
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
`, createBinFile(t, llm.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("missing body", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, nil)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing model", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing capabilities chat", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "bert",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(0),
}, []llm.Tensor{})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "bert",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("load model", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
var actual api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
t.Fatal(err)
}
if actual.Model != "test" {
t.Errorf("expected model test, got %s", actual.Model)
}
if !actual.Done {
t.Errorf("expected done true, got false")
}
if actual.DoneReason != "load" {
t.Errorf("expected done reason load, got %s", actual.DoneReason)
}
})
checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
t.Helper()
var actual api.ChatResponse
if err := json.NewDecoder(body).Decode(&actual); err != nil {
t.Fatal(err)
}
if actual.Model != model {
t.Errorf("expected model test, got %s", actual.Model)
}
if !actual.Done {
t.Errorf("expected done false, got true")
}
if actual.DoneReason != "stop" {
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
}
if diff := cmp.Diff(actual.Message, api.Message{
Role: "assistant",
Content: content,
}); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
if actual.PromptEvalCount == 0 {
t.Errorf("expected prompt eval count > 0, got 0")
}
if actual.PromptEvalDuration == 0 {
t.Errorf("expected prompt eval duration > 0, got 0")
}
if actual.EvalCount == 0 {
t.Errorf("expected eval count > 0, got 0")
}
if actual.EvalDuration == 0 {
t.Errorf("expected eval duration > 0, got 0")
}
if actual.LoadDuration == 0 {
t.Errorf("expected load duration > 0, got 0")
}
if actual.TotalDuration == 0 {
t.Errorf("expected total duration > 0, got 0")
}
}
mock.CompletionResponse.Content = "Hi!"
t.Run("messages", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkChatResponse(t, w.Body, "test", "Hi!")
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test-system",
Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("messages with model system", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkChatResponse(t, w.Body, "test-system", "Hi!")
})
mock.CompletionResponse.Content = "Abra kadabra!"
t.Run("messages with system", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "system", Content: "You can perform magic tricks."},
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
})
t.Run("messages with interleaved system", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
{Role: "assistant", Content: "I can help you with that."},
{Role: "system", Content: "You can perform magic tricks."},
{Role: "user", Content: "Help me write tests."},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
})
}
func TestGenerate(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: "stop",
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: gpu.GetGPUInfo,
getCpuFn: gpu.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
req.successCh <- &runnerRef{
llama: &mock,
}
},
},
}
go s.sched.Run(context.TODO())
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test",
Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }}
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
`, createBinFile(t, llm.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("missing body", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, nil)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing model", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing capabilities generate", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "bert",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(0),
}, []llm.Tensor{})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "bert",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing capabilities suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "def add(",
Suffix: " return c",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("load model", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
var actual api.GenerateResponse
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
t.Fatal(err)
}
if actual.Model != "test" {
t.Errorf("expected model test, got %s", actual.Model)
}
if !actual.Done {
t.Errorf("expected done true, got false")
}
if actual.DoneReason != "load" {
t.Errorf("expected done reason load, got %s", actual.DoneReason)
}
})
checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
t.Helper()
var actual api.GenerateResponse
if err := json.NewDecoder(body).Decode(&actual); err != nil {
t.Fatal(err)
}
if actual.Model != model {
t.Errorf("expected model test, got %s", actual.Model)
}
if !actual.Done {
t.Errorf("expected done false, got true")
}
if actual.DoneReason != "stop" {
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
}
if actual.Response != content {
t.Errorf("expected response %s, got %s", content, actual.Response)
}
if actual.Context == nil {
t.Errorf("expected context not nil")
}
if actual.PromptEvalCount == 0 {
t.Errorf("expected prompt eval count > 0, got 0")
}
if actual.PromptEvalDuration == 0 {
t.Errorf("expected prompt eval duration > 0, got 0")
}
if actual.EvalCount == 0 {
t.Errorf("expected eval count > 0, got 0")
}
if actual.EvalDuration == 0 {
t.Errorf("expected eval duration > 0, got 0")
}
if actual.LoadDuration == 0 {
t.Errorf("expected load duration > 0, got 0")
}
if actual.TotalDuration == 0 {
t.Errorf("expected total duration > 0, got 0")
}
}
mock.CompletionResponse.Content = "Hi!"
t.Run("prompt", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Hello!",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkGenerateResponse(t, w.Body, "test", "Hi!")
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test-system",
Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("prompt with model system", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Hello!",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkGenerateResponse(t, w.Body, "test-system", "Hi!")
})
mock.CompletionResponse.Content = "Abra kadabra!"
t.Run("prompt with system", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Hello!",
System: "You can perform magic tricks.",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
})
t.Run("prompt with template", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Help me write tests.",
System: "You can perform magic tricks.",
Template: `{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test-suffix",
Modelfile: `FROM test
TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}"""`,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("prompt with suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-suffix",
Prompt: "def add(",
Suffix: " return c",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("prompt without suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-suffix",
Prompt: "def add(",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("raw", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Help me write tests.",
Raw: true,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}

View File

@@ -7,11 +7,14 @@ import (
"slices" "slices"
"testing" "testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
) )
func TestList(t *testing.T) { func TestList(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("OLLAMA_MODELS", t.TempDir()) t.Setenv("OLLAMA_MODELS", t.TempDir())
envconfig.LoadConfig() envconfig.LoadConfig()

View File

@@ -7,6 +7,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"math"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
@@ -272,6 +273,77 @@ func Test_Routes(t *testing.T) {
assert.Equal(t, "library", retrieveResp.OwnedBy) assert.Equal(t, "library", retrieveResp.OwnedBy)
}, },
}, },
{
Name: "Embed Handler Empty Input",
Method: http.MethodPost,
Path: "/api/embed",
Setup: func(t *testing.T, req *http.Request) {
embedReq := api.EmbedRequest{
Model: "t-bone",
Input: "",
}
jsonData, err := json.Marshal(embedReq)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
if contentType != "application/json; charset=utf-8" {
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
var embedResp api.EmbedResponse
err = json.Unmarshal(body, &embedResp)
if err != nil {
t.Fatal(err)
}
if embedResp.Model != "t-bone" {
t.Fatalf("expected model t-bone, got %s", embedResp.Model)
}
if embedResp.Embeddings == nil {
t.Fatalf("expected embeddings to not be nil, got %v", embedResp.Embeddings)
}
if len(embedResp.Embeddings) != 0 {
t.Fatalf("expected embeddings to be empty, got %v", embedResp.Embeddings)
}
},
},
{
Name: "Embed Handler Invalid Input",
Method: http.MethodPost,
Path: "/api/embed",
Setup: func(t *testing.T, req *http.Request) {
embedReq := api.EmbedRequest{
Model: "t-bone",
Input: 2,
}
jsonData, err := json.Marshal(embedReq)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
if contentType != "application/json; charset=utf-8" {
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
}
_, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected status code 400, got %d", resp.StatusCode)
}
},
},
} }
t.Setenv("OLLAMA_MODELS", t.TempDir()) t.Setenv("OLLAMA_MODELS", t.TempDir())
@@ -420,3 +492,38 @@ func TestShow(t *testing.T) {
t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"]) t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"])
} }
} }
func TestNormalize(t *testing.T) {
type testCase struct {
input []float32
}
testCases := []testCase{
{input: []float32{1}},
{input: []float32{0, 1, 2, 3}},
{input: []float32{0.1, 0.2, 0.3}},
{input: []float32{-0.1, 0.2, 0.3, -0.4}},
{input: []float32{0, 0, 0}},
}
isNormalized := func(vec []float32) (res bool) {
sum := 0.0
for _, v := range vec {
sum += float64(v * v)
}
if math.Abs(sum-1) > 1e-6 {
return sum == 0
} else {
return true
}
}
for _, tc := range testCases {
t.Run("", func(t *testing.T) {
normalized := normalize(tc.input)
if !isNormalized(normalized) {
t.Errorf("Vector %v is not normalized", tc.input)
}
})
}
}

View File

@@ -133,17 +133,8 @@ func (s *Scheduler) processPending(ctx context.Context) {
numParallel = 1 numParallel = 1
slog.Warn("multimodal models don't support parallel requests yet") slog.Warn("multimodal models don't support parallel requests yet")
} }
// Keep NumCtx and numParallel in sync
if numParallel > 1 {
pending.opts.NumCtx = pending.origNumCtx * numParallel
}
for { for {
cpus := s.getCpuFn()
var systemMem gpu.GpuInfo
if len(cpus) > 0 {
systemMem = cpus[0]
}
var runnerToExpire *runnerRef var runnerToExpire *runnerRef
s.loadedMu.Lock() s.loadedMu.Lock()
runner := s.loaded[pending.model.ModelPath] runner := s.loaded[pending.model.ModelPath]
@@ -197,46 +188,15 @@ func (s *Scheduler) processPending(ctx context.Context) {
break break
} }
estimate := llm.EstimateGPULayers(gpus, ggml, pending.model.ProjectorPaths, pending.opts)
maxSize := systemMem.FreeMemory
// Add available GPU memory to the total pool
// macOS hardware has unified memory so don't double count
if runtime.GOOS != "darwin" {
for _, gpu := range gpus {
if gpu.Library == "cpu" {
continue
}
if loadedCount == 0 {
// If no other models are loaded, set the limit based on what's available
maxSize += gpu.FreeMemory
} else {
// Other models could be unloaded, favor total memory for limit
maxSize += gpu.TotalMemory
}
}
}
// Block attempting to load a model larger than system memory + GPU memory
if estimate.TotalSize > maxSize {
slog.Warn("model request too large for system", "requested", format.HumanBytes2(estimate.TotalSize), "system", format.HumanBytes2(maxSize))
// Linux will crash if over-allocating memory - return an error to the user.
// TODO (jmorganca): add reasonable upper limits for darwin and windows as well
if runtime.GOOS == "linux" {
pending.errCh <- fmt.Errorf("requested model (%s) is too large for this system (%s)", format.HumanBytes2(estimate.TotalSize), format.HumanBytes2(maxSize))
break
}
}
// Evaluate if the model will fit in the available system memory, or if we should unload a model first // Evaluate if the model will fit in the available system memory, or if we should unload a model first
if len(gpus) == 1 && gpus[0].Library == "cpu" { if len(gpus) == 1 && gpus[0].Library == "cpu" {
// simplifying assumption of defaultParallel when in CPU mode // simplifying assumption of defaultParallel when in CPU mode
if numParallel <= 0 { if numParallel <= 0 {
numParallel = defaultParallel numParallel = defaultParallel
pending.opts.NumCtx = pending.origNumCtx * numParallel
} }
pending.opts.NumCtx = pending.origNumCtx * numParallel
if loadedCount == 0 { if loadedCount == 0 {
slog.Debug("cpu mode with first model, loading") slog.Debug("cpu mode with first model, loading")
s.loadFn(pending, ggml, gpus, numParallel) s.loadFn(pending, ggml, gpus, numParallel)

View File

@@ -642,8 +642,8 @@ type mockLlm struct {
pingResp error pingResp error
waitResp error waitResp error
completionResp error completionResp error
embeddingResp []float64 embedResp [][]float32
embeddingRespErr error embedRespErr error
tokenizeResp []int tokenizeResp []int
tokenizeRespErr error tokenizeRespErr error
detokenizeResp string detokenizeResp string
@@ -660,8 +660,8 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
return s.completionResp return s.completionResp
} }
func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) { func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
return s.embeddingResp, s.embeddingRespErr return s.embedResp, s.embedRespErr
} }
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) { func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
return s.tokenizeResp, s.tokenizeRespErr return s.tokenizeResp, s.tokenizeRespErr

View File

@@ -0,0 +1,67 @@
{{- if or .Tools .System }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>
{{- if .Tools }}# Safety Preamble
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
# System Preamble
## Basic Rules
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
{{ if .System }}# User Preamble
{{ .System }}
{{- end }}
## Available Tools
Here is a list of tools that you have available to you:
{{- range .Tools }}
```python
def {{ .Function.Name }}(
{{- range $name, $property := .Function.Parameters.Properties }}{{ $name }}: {{ $property.Type }}, {{ end }}) -> List[Dict]:
"""{{ .Function.Description }}
{{- if .Function.Parameters.Properties }}
Args:
{{- range $name, $property := .Function.Parameters.Properties }}
{{ $name }} ({{ $property.Type }}): {{ $property.Description }}
{{- end }}
{{- end }}
"""
pass
```
{{- end }}
{{- else if .System }}{{ .System }}
{{- end }}<|END_OF_TURN_TOKEN|>
{{- end }}
{{- range .Messages }}
{{- if eq .Role "system" }}
{{- continue }}
{{- end }}<|START_OF_TURN_TOKEN|>
{{- if eq .Role "user" }}<|USER_TOKEN|>{{ .Content }}
{{- else if eq .Role "assistant" }}<|CHATBOT_TOKEN|>
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }}
Action: ```json
[
{{- range .ToolCalls }}
{
"tool_name": "{{ .Function.Name }}",
"parameters": {{ json .Function.Arguments }}
}
{{- end }}
]```
{{ continue }}
{{ end }}
{{- else if eq .Role "tool" }}<|SYSTEM_TOKEN|><results>
{{ .Content }}</results>
{{- end }}<|END_OF_TURN_TOKEN|>
{{- end }}
{{- if .Tools }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
```json
[
{
"tool_name": title of the tool in the specification,
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
}
]```
{{- end }}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>

View File

@@ -0,0 +1,39 @@
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
# System Preamble
## Basic Rules
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
# User Preamble
You are a knowledgable assistant. You can answer questions and perform tasks.
## Available Tools
Here is a list of tools that you have available to you:
```python
def get_current_weather(format: string, location: string, ) -> List[Dict]:
"""Get the current weather
Args:
format (string): The temperature unit to use. Infer this from the users location.
location (string): The city and state, e.g. San Francisco, CA
"""
pass
```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in Paris?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
Action: ```json
[
{
"tool_name": "get_current_weather",
"parameters": {"format":"celsius","location":"Paris, France"}
}
]```
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><results>
22</results><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>The current temperature in Paris, France is 22 degrees Celsius.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in San Francisco and Toronto?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
```json
[
{
"tool_name": title of the tool in the specification,
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
}
]```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>

View File

@@ -0,0 +1,31 @@
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
{{- if .System }}
{{ .System }}
{{- end }}
In addition to plain text responses, you can chose to call one or more of the provided functions.
Use the following rule to decide when to call a function:
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
If you decide to call functions:
* prefix function calls with functools marker (no closing marker required)
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
* make sure you pick the right functions that match the user intent
Available functions as JSON spec:
{{- if .Tools }}
{{ json .Tools }}
{{- end }}<|eot_id|>
{{- end }}
{{- range .Messages }}<|start_header_id|>
{{- if or (eq .Role "user") (eq .Role "assistant") (eq .Role "tool") }}{{ .Role }}
{{- end }}<|end_header_id|>
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }} functools[
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}{{ "}" }}
{{- end }}]
{{- end }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|>

17
server/testdata/tools/firefunction.out vendored Normal file
View File

@@ -0,0 +1,17 @@
<|start_header_id|>system<|end_header_id|>
You are a knowledgable assistant. You can answer questions and perform tasks.
In addition to plain text responses, you can chose to call one or more of the provided functions.
Use the following rule to decide when to call a function:
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
If you decide to call functions:
* prefix function calls with functools marker (no closing marker required)
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
* make sure you pick the right functions that match the user intent
Available functions as JSON spec:
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]<|eot_id|><|start_header_id|><|end_header_id|>You are a knowledgable assistant. You can answer questions and perform tasks.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> functools[{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]<|eot_id|><|start_header_id|>tool<|end_header_id|>22<|eot_id|><|start_header_id|>assistant<|end_header_id|>The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

39
server/testdata/tools/messages.json vendored Normal file
View File

@@ -0,0 +1,39 @@
[
{
"role": "system",
"content": "You are a knowledgable assistant. You can answer questions and perform tasks."
},
{
"role": "user",
"content": "What's the weather like today in Paris?"
},
{
"role": "assistant",
"tool_calls": [
{
"id": "89a1e453-0bce-4de3-a456-c54bed09c520",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": {
"location": "Paris, France",
"format": "celsius"
}
}
}
]
},
{
"role": "tool",
"tool_call_id": "89a1e453-0bce-4de3-a456-c54bed09c520",
"content": "22"
},
{
"role": "assistant",
"content": "The current temperature in Paris, France is 22 degrees Celsius."
},
{
"role": "user",
"content": "What's the weather like today in San Francisco and Toronto?"
}
]

15
server/testdata/tools/mistral.gotmpl vendored Normal file
View File

@@ -0,0 +1,15 @@
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
{{ end }}{{ .Content }}[/INST]
{{- else if eq .Role "assistant" }}
{{- if .Content }} {{ .Content }}</s>
{{- else if .ToolCalls }}[TOOL_CALLS] [
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
{{- end }}]</s>
{{- end }}
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
{{- end }}
{{- end }}

3
server/testdata/tools/mistral.out vendored Normal file
View File

@@ -0,0 +1,3 @@
[INST] What's the weather like today in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]</s>[TOOL_RESULTS] {"content": 22}[/TOOL_RESULTS] The current temperature in Paris, France is 22 degrees Celsius.</s>[AVAILABLE_TOOLS] [{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}][/AVAILABLE_TOOLS][INST] You are a knowledgable assistant. You can answer questions and perform tasks.
What's the weather like today in San Francisco and Toronto?[/INST]

30
server/testdata/tools/tools.json vendored Normal file
View File

@@ -0,0 +1,30 @@
[
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"format": {
"type": "string",
"enum": [
"celsius",
"fahrenheit"
],
"description": "The temperature unit to use. Infer this from the users location."
}
},
"required": [
"location",
"format"
]
}
}
}
]

View File

@@ -4,4 +4,5 @@
{{ .Prompt }} {{ .Prompt }}
{{ end }}### Response: {{ end }}### Response:
{{ .Response }} {{ .Response }}

View File

@@ -3,4 +3,4 @@
{{ end }}{{ if .Prompt }}<|im_start|>user {{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|> {{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant {{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|> {{ .Response }}<|im_end|>

View File

@@ -2,4 +2,5 @@
{{ end }}{{ if .Prompt }}User: {{ .Prompt }} {{ end }}{{ if .Prompt }}User: {{ .Prompt }}
{{ end }}Assistant: <|begin_of_text|>{{ .Response }} {{ end }}Assistant: {{ .Response }}

View File

@@ -1,8 +1,10 @@
{{ if .System }} Source: system {{ if .System }}Source: system
{{ .System }} <step>{{ end }} Source: user {{ .System }} <step> {{ end }}Source: user
{{ .Prompt }} <step> Source: assistant {{ .Prompt }} <step> Source: assistant
{{- if not .Response }}
Destination: user Destination: user
{{- end }}
{{ .Response }}<step> {{ .Response }} <step>

View File

@@ -1,3 +1,5 @@
{{ if .System }}{{ .System }} {{ if .System }}System: {{ .System }}
{{ end }}{{ if .Prompt }}User: {{ .Prompt }} {{ end }}{{ if .Prompt }}User:
{{ end }}Assistant: {{ .Response }} {{ .Prompt }}
{{ end }}Falcon:
{{ .Response }}

View File

@@ -1,4 +1,5 @@
<start_of_turn>user <start_of_turn>user
{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn> {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}<end_of_turn>
<start_of_turn>model <start_of_turn>model
{{ .Response }}<end_of_turn> {{ .Response }}<end_of_turn>

View File

@@ -1,9 +1,9 @@
{{ if .System }} {{ if .System }}System:
System:
{{ .System }} {{ .System }}
{{ end }}{{ if .Prompt }}Question: {{ end }}{{ if .Prompt }}Question:
{{ .Prompt }} {{ .Prompt }}
{{ end }}Answer: {{ end }}Answer:
{{ .Response }} {{ .Response }}

View File

@@ -1,3 +1,6 @@
[INST] <<SYS>>{{ .System }}<</SYS>> [INST] <<SYS>>
{{- if .System }}
{{ .System }}
{{ end }}<</SYS>>
{{ .Prompt }} [/INST] {{ .Response }} {{ .Prompt }} [/INST] {{ .Response }}</s><s>

View File

@@ -4,4 +4,5 @@
{{ .Prompt }} {{ .Prompt }}
{{ end }}@@ Response {{ end }}@@ Response
{{ .Response }} {{ .Response }}

View File

@@ -1,6 +1,3 @@
{{ if .System }}<|im_start|>system [INST] {{ if .System }}{{ .System }}
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user {{ end }}{{ .Prompt }}[/INST] {{ .Response }}</s>
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>

View File

@@ -1 +1 @@
{{ .System }}<|end_of_turn|>GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|> {{ if .System }}GPT4 Correct System: {{ .System }}<|end_of_turn|>{{ end }}GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>

View File

@@ -3,4 +3,4 @@
{{ end }}{{ if .Prompt }}<|user|> {{ end }}{{ if .Prompt }}<|user|>
{{ .Prompt }}<|end|> {{ .Prompt }}<|end|>
{{ end }}<|assistant|> {{ end }}<|assistant|>
{{ .Response }}<|end|> {{ .Response }}<|end|>

View File

@@ -5,4 +5,5 @@
{{ .Prompt }} {{ .Prompt }}
{{ end }}### Assistant: {{ end }}### Assistant:
{{ .Response }} {{ .Response }}</s>

View File

@@ -3,7 +3,6 @@
{{ end }}{{ if .Prompt }}### Instruction {{ end }}{{ if .Prompt }}### Instruction
{{ .Prompt }} {{ .Prompt }}
{{ end }}### Response {{ end }}### Response
{{ .Response }}<|endoftext|> {{ .Response }}<|endoftext|>

View File

@@ -5,6 +5,7 @@ import (
"embed" "embed"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"math" "math"
"slices" "slices"
@@ -14,6 +15,7 @@ import (
"text/template/parse" "text/template/parse"
"github.com/agnivade/levenshtein" "github.com/agnivade/levenshtein"
"github.com/ollama/ollama/api"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
) )
@@ -74,30 +76,66 @@ func Named(s string) (*named, error) {
return nil, errors.New("no matching template found") return nil, errors.New("no matching template found")
} }
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
type Template struct { type Template struct {
*template.Template *template.Template
raw string raw string
} }
// response is a template node that can be added to templates that don't already have one
var response = parse.ActionNode{
NodeType: parse.NodeAction,
Pipe: &parse.PipeNode{
NodeType: parse.NodePipe,
Cmds: []*parse.CommandNode{
{
NodeType: parse.NodeCommand,
Args: []parse.Node{
&parse.FieldNode{
NodeType: parse.NodeField,
Ident: []string{"Response"},
},
},
},
},
},
}
var funcs = template.FuncMap{
"json": func(v any) string {
b, _ := json.Marshal(v)
return string(b)
},
}
func Parse(s string) (*Template, error) {
tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
tmpl, err := tmpl.Parse(s)
if err != nil {
return nil, err
}
t := Template{Template: tmpl, raw: s}
if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
// touch up the template and append {{ .Response }}
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
}
return &t, nil
}
func (t *Template) String() string { func (t *Template) String() string {
return t.raw return t.raw
} }
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
func Parse(s string) (*Template, error) {
t, err := template.New("").Option("missingkey=zero").Parse(s)
if err != nil {
return nil, err
}
return &Template{Template: t, raw: s}, nil
}
func (t *Template) Vars() []string { func (t *Template) Vars() []string {
var vars []string var vars []string
for _, n := range t.Tree.Root.Nodes { for _, tt := range t.Templates() {
vars = append(vars, parseNode(n)...) for _, n := range tt.Root.Nodes {
vars = append(vars, Identifiers(n)...)
}
} }
set := make(map[string]struct{}) set := make(map[string]struct{})
@@ -110,49 +148,282 @@ func (t *Template) Vars() []string {
return vars return vars
} }
func parseNode(n parse.Node) []string { type Values struct {
Messages []api.Message
Tools []api.Tool
Prompt string
Suffix string
// forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy bool
}
func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
var walk func(parse.Node) parse.Node
walk = func(n parse.Node) parse.Node {
if fn(n) {
return n
}
switch t := n.(type) {
case *parse.ListNode:
for _, c := range t.Nodes {
if n := walk(c); n != nil {
return n
}
}
case *parse.BranchNode:
for _, n := range []*parse.ListNode{t.List, t.ElseList} {
if n != nil {
if n := walk(n); n != nil {
return n
}
}
}
case *parse.IfNode:
return walk(&t.BranchNode)
case *parse.WithNode:
return walk(&t.BranchNode)
case *parse.RangeNode:
return walk(&t.BranchNode)
}
return nil
}
if n := walk(t.Tree.Root); n != nil {
return (&template.Template{
Tree: &parse.Tree{
Root: &parse.ListNode{
Nodes: []parse.Node{n},
},
},
}).Funcs(funcs)
}
return nil
}
func (t *Template) Execute(w io.Writer, v Values) error {
system, messages := collate(v.Messages)
if v.Prompt != "" && v.Suffix != "" {
return t.Template.Execute(w, map[string]any{
"Prompt": v.Prompt,
"Suffix": v.Suffix,
"Response": "",
})
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{
"System": system,
"Messages": messages,
"Tools": v.Tools,
})
}
system = ""
var b bytes.Buffer
var prompt, response string
for _, m := range messages {
execute := func() error {
if err := t.Template.Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
"Response": response,
}); err != nil {
return err
}
system = ""
prompt = ""
response = ""
return nil
}
switch m.Role {
case "system":
if prompt != "" || response != "" {
if err := execute(); err != nil {
return err
}
}
system = m.Content
case "user":
if response != "" {
if err := execute(); err != nil {
return err
}
}
prompt = m.Content
case "assistant":
response = m.Content
}
}
var cut bool
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
cut = true
}
return cut
})
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
}); err != nil {
return err
}
_, err := io.Copy(w, &b)
return err
}
// collate messages based on role. consecutive messages of the same role are merged
// into a single message. collate also collects and returns all system messages.
// collate mutates message content adding image tags ([img-%d]) as needed
func collate(msgs []api.Message) (string, []*api.Message) {
var n int
var system []string
var collated []*api.Message
for i := range msgs {
msg := msgs[i]
for range msg.Images {
imageTag := fmt.Sprintf("[img-%d]", n)
if !strings.Contains(msg.Content, "[img]") {
msg.Content = strings.TrimSpace("[img] " + msg.Content)
}
msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
n++
}
if msg.Role == "system" {
system = append(system, msg.Content)
}
if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
collated[len(collated)-1].Content += "\n\n" + msg.Content
} else {
collated = append(collated, &msg)
}
}
return strings.Join(system, "\n\n"), collated
}
// Identifiers walks the node tree returning any identifiers it finds along the way
func Identifiers(n parse.Node) []string {
switch n := n.(type) { switch n := n.(type) {
case *parse.ListNode:
var names []string
for _, n := range n.Nodes {
names = append(names, Identifiers(n)...)
}
return names
case *parse.TemplateNode:
return Identifiers(n.Pipe)
case *parse.ActionNode: case *parse.ActionNode:
return parseNode(n.Pipe) return Identifiers(n.Pipe)
case *parse.BranchNode:
names := Identifiers(n.Pipe)
for _, n := range []*parse.ListNode{n.List, n.ElseList} {
if n != nil {
names = append(names, Identifiers(n)...)
}
}
return names
case *parse.IfNode: case *parse.IfNode:
names := parseNode(n.Pipe) return Identifiers(&n.BranchNode)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.RangeNode: case *parse.RangeNode:
names := parseNode(n.Pipe) return Identifiers(&n.BranchNode)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.WithNode: case *parse.WithNode:
names := parseNode(n.Pipe) return Identifiers(&n.BranchNode)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.PipeNode: case *parse.PipeNode:
var names []string var names []string
for _, c := range n.Cmds { for _, c := range n.Cmds {
for _, a := range c.Args { for _, a := range c.Args {
names = append(names, parseNode(a)...) names = append(names, Identifiers(a)...)
} }
} }
return names
case *parse.ListNode:
var names []string
for _, n := range n.Nodes {
names = append(names, parseNode(n)...)
}
return names return names
case *parse.FieldNode: case *parse.FieldNode:
return n.Ident return n.Ident
case *parse.VariableNode:
return n.Ident
} }
return nil return nil
} }
// deleteNode walks the node list and deletes nodes that match the predicate
// this is currently to remove the {{ .Response }} node from templates
func deleteNode(n parse.Node, fn func(parse.Node) bool) parse.Node {
var walk func(n parse.Node) parse.Node
walk = func(n parse.Node) parse.Node {
if fn(n) {
return nil
}
switch t := n.(type) {
case *parse.ListNode:
var nodes []parse.Node
for _, c := range t.Nodes {
if n := walk(c); n != nil {
nodes = append(nodes, n)
}
}
t.Nodes = nodes
return t
case *parse.IfNode:
t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
case *parse.WithNode:
t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
case *parse.RangeNode:
t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
case *parse.BranchNode:
t.List = walk(t.List).(*parse.ListNode)
if t.ElseList != nil {
t.ElseList = walk(t.ElseList).(*parse.ListNode)
}
case *parse.ActionNode:
n := walk(t.Pipe)
if n == nil {
return nil
}
t.Pipe = n.(*parse.PipeNode)
case *parse.PipeNode:
var commands []*parse.CommandNode
for _, c := range t.Cmds {
var args []parse.Node
for _, a := range c.Args {
if n := walk(a); n != nil {
args = append(args, n)
}
}
if len(args) == 0 {
return nil
}
c.Args = args
commands = append(commands, c)
}
if len(commands) == 0 {
return nil
}
t.Cmds = commands
}
return n
}
return walk(n)
}

View File

@@ -8,9 +8,11 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"slices" "slices"
"strings"
"testing" "testing"
"text/template"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
) )
@@ -46,7 +48,7 @@ func TestNamed(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
tmpl, err := template.New(s).Parse(b.String()) tmpl, err := Parse(b.String())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -59,18 +61,125 @@ func TestNamed(t *testing.T) {
} }
} }
func TestTemplate(t *testing.T) {
cases := make(map[string][]api.Message)
for _, mm := range [][]api.Message{
{
{Role: "user", Content: "Hello, how are you?"},
},
{
{Role: "user", Content: "Hello, how are you?"},
{Role: "assistant", Content: "I'm doing great. How can I help you today?"},
{Role: "user", Content: "I'd like to show off how chat templating works!"},
},
{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello, how are you?"},
{Role: "assistant", Content: "I'm doing great. How can I help you today?"},
{Role: "user", Content: "I'd like to show off how chat templating works!"},
},
} {
var roles []string
for _, m := range mm {
roles = append(roles, m.Role)
}
cases[strings.Join(roles, "-")] = mm
}
matches, err := filepath.Glob("*.gotmpl")
if err != nil {
t.Fatal(err)
}
for _, match := range matches {
t.Run(match, func(t *testing.T) {
bts, err := os.ReadFile(match)
if err != nil {
t.Fatal(err)
}
tmpl, err := Parse(string(bts))
if err != nil {
t.Fatal(err)
}
for n, tt := range cases {
var actual bytes.Buffer
t.Run(n, func(t *testing.T) {
if err := tmpl.Execute(&actual, Values{Messages: tt}); err != nil {
t.Fatal(err)
}
expect, err := os.ReadFile(filepath.Join("testdata", match, n))
if err != nil {
t.Fatal(err)
}
bts := actual.Bytes()
if slices.Contains([]string{"chatqa.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && bts[len(bts)-1] == ' ' {
t.Log("removing trailing space from output")
bts = bts[:len(bts)-1]
}
if diff := cmp.Diff(bts, expect); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("legacy", func(t *testing.T) {
t.Skip("legacy outputs are currently default outputs")
var legacy bytes.Buffer
if err := tmpl.Execute(&legacy, Values{Messages: tt, forceLegacy: true}); err != nil {
t.Fatal(err)
}
legacyBytes := legacy.Bytes()
if slices.Contains([]string{"chatqa.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && legacyBytes[len(legacyBytes)-1] == ' ' {
t.Log("removing trailing space from legacy output")
legacyBytes = legacyBytes[:len(legacyBytes)-1]
} else if slices.Contains([]string{"codellama-70b-instruct.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl"}, match) {
t.Skip("legacy outputs cannot be compared to messages outputs")
}
if diff := cmp.Diff(legacyBytes, actual.Bytes()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
})
}
}
func TestParse(t *testing.T) { func TestParse(t *testing.T) {
cases := []struct { cases := []struct {
template string template string
vars []string vars []string
}{ }{
{"{{ .Prompt }}", []string{"prompt"}}, {"{{ .Prompt }}", []string{"prompt", "response"}},
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}}, {"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}},
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}}, {"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "system", "tools"}}, {"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}}, {"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}}, {`{{- range .Messages }}
{"{{ .Prompt }} {{ .Suffix }}", []string{"prompt", "suffix"}}, {{- if eq .Role "system" }}SYSTEM:
{{- else if eq .Role "user" }}USER:
{{- else if eq .Role "assistant" }}ASSISTANT:
{{- end }} {{ .Content }}
{{- end }}`, []string{"content", "messages", "role"}},
{`{{- if .Messages }}
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ else -}}
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
{{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}},
} }
for _, tt := range cases { for _, tt := range cases {
@@ -80,9 +189,207 @@ func TestParse(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
vars := tmpl.Vars() if diff := cmp.Diff(tmpl.Vars(), tt.vars); diff != "" {
if !slices.Equal(tt.vars, vars) { t.Errorf("mismatch (-got +want):\n%s", diff)
t.Errorf("expected %v, got %v", tt.vars, vars) }
})
}
}
func TestExecuteWithMessages(t *testing.T) {
type template struct {
name string
template string
}
cases := []struct {
name string
templates []template
values Values
expected string
}{
{
"mistral",
[]template{
{"no response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] `},
{"response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `[INST] {{ if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
{{- end }}`},
},
Values{
Messages: []api.Message{
{Role: "user", Content: "Hello friend!"},
{Role: "assistant", Content: "Hello human!"},
{Role: "user", Content: "What is your name?"},
},
},
`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
},
{
"mistral system",
[]template{
{"no response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] `},
{"response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `[INST] {{ if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
{{- end }}`},
},
Values{
Messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant!"},
{Role: "user", Content: "Hello friend!"},
{Role: "assistant", Content: "Hello human!"},
{Role: "user", Content: "What is your name?"},
},
},
`[INST] You are a helpful assistant!
Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
},
{
"chatml",
[]template{
// this does not have a "no response" test because it's impossible to render the same output
{"response", `{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
`},
{"messages", `
{{- range $index, $_ := .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
`},
},
Values{
Messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant!"},
{Role: "user", Content: "Hello friend!"},
{Role: "assistant", Content: "Hello human!"},
{Role: "user", Content: "What is your name?"},
},
},
`<|im_start|>system
You are a helpful assistant!<|im_end|>
<|im_start|>user
Hello friend!<|im_end|>
<|im_start|>assistant
Hello human!<|im_end|>
<|im_start|>user
What is your name?<|im_end|>
<|im_start|>assistant
`,
},
{
"moondream",
[]template{
// this does not have a "no response" test because it's impossible to render the same output
{"response", `{{ if .Prompt }}Question: {{ .Prompt }}
{{ end }}Answer: {{ .Response }}
`},
{"messages", `
{{- range .Messages }}
{{- if eq .Role "user" }}Question: {{ .Content }}
{{ else if eq .Role "assistant" }}Answer: {{ .Content }}
{{ end }}
{{- end }}Answer: `},
},
Values{
Messages: []api.Message{
{Role: "user", Content: "What's in this image?", Images: []api.ImageData{[]byte("")}},
{Role: "assistant", Content: "It's a hot dog."},
{Role: "user", Content: "What's in _this_ image?"},
{Role: "user", Images: []api.ImageData{[]byte("")}},
{Role: "user", Content: "Is it a hot dog?"},
},
},
`Question: [img-0] What's in this image?
Answer: It's a hot dog.
Question: What's in _this_ image?
[img-1]
Is it a hot dog?
Answer: `,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
for _, ttt := range tt.templates {
t.Run(ttt.name, func(t *testing.T) {
tmpl, err := Parse(ttt.template)
if err != nil {
t.Fatal(err)
}
var b bytes.Buffer
if err := tmpl.Execute(&b, tt.values); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(b.String(), tt.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
})
}
}
func TestExecuteWithSuffix(t *testing.T) {
tmpl, err := Parse(`{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}`)
if err != nil {
t.Fatal(err)
}
cases := []struct {
name string
values Values
expect string
}{
{
"message", Values{Messages: []api.Message{{Role: "user", Content: "hello"}}}, "hello",
},
{
"prompt suffix", Values{Prompt: "def add(", Suffix: "return x"}, "<PRE> def add( <SUF>return x <MID>",
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
var b bytes.Buffer
if err := tmpl.Execute(&b, tt.values); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(b.String(), tt.expect); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
} }
}) })
} }

View File

@@ -0,0 +1 @@
<start_system>You are a helpful assistant.<end_message><start_user>Hello, how are you?<end_message><start_assistant>I'm doing great. How can I help you today?<end_message><start_user>I'd like to show off how chat templating works!<end_message><start_assistant>

1
template/testdata/alfred.gotmpl/user vendored Normal file
View File

@@ -0,0 +1 @@
<start_user>Hello, how are you?<end_message><start_assistant>

View File

@@ -0,0 +1 @@
<start_user>Hello, how are you?<end_message><start_assistant>I'm doing great. How can I help you today?<end_message><start_user>I'd like to show off how chat templating works!<end_message><start_assistant>

View File

@@ -0,0 +1,12 @@
You are a helpful assistant.
### Instruction:
Hello, how are you?
### Response:
I'm doing great. How can I help you today?
### Instruction:
I'd like to show off how chat templating works!
### Response:

4
template/testdata/alpaca.gotmpl/user vendored Normal file
View File

@@ -0,0 +1,4 @@
### Instruction:
Hello, how are you?
### Response:

View File

@@ -0,0 +1,10 @@
### Instruction:
Hello, how are you?
### Response:
I'm doing great. How can I help you today?
### Instruction:
I'd like to show off how chat templating works!
### Response:

View File

@@ -0,0 +1,9 @@
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant
I'm doing great. How can I help you today?<|im_end|>
<|im_start|>user
I'd like to show off how chat templating works!<|im_end|>
<|im_start|>assistant

3
template/testdata/chatml.gotmpl/user vendored Normal file
View File

@@ -0,0 +1,3 @@
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant

View File

@@ -0,0 +1,7 @@
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant
I'm doing great. How can I help you today?<|im_end|>
<|im_start|>user
I'd like to show off how chat templating works!<|im_end|>
<|im_start|>assistant

View File

@@ -0,0 +1,9 @@
System: You are a helpful assistant.
User: Hello, how are you?
Assistant: I'm doing great. How can I help you today?
User: I'd like to show off how chat templating works!
Assistant:

3
template/testdata/chatqa.gotmpl/user vendored Normal file
View File

@@ -0,0 +1,3 @@
User: Hello, how are you?
Assistant:

View File

@@ -0,0 +1,7 @@
User: Hello, how are you?
Assistant: I'm doing great. How can I help you today?
User: I'd like to show off how chat templating works!
Assistant:

View File

@@ -0,0 +1,12 @@
Source: system
You are a helpful assistant. <step> Source: user
Hello, how are you? <step> Source: assistant
I'm doing great. How can I help you today? <step> Source: user
I'd like to show off how chat templating works! <step> Source: assistant
Destination: user

View File

@@ -0,0 +1,6 @@
Source: user
Hello, how are you? <step> Source: assistant
Destination: user

View File

@@ -0,0 +1,10 @@
Source: user
Hello, how are you? <step> Source: assistant
I'm doing great. How can I help you today? <step> Source: user
I'd like to show off how chat templating works! <step> Source: assistant
Destination: user

View File

@@ -0,0 +1,8 @@
System: You are a helpful assistant.
User:
Hello, how are you?
Falcon:
I'm doing great. How can I help you today?
User:
I'd like to show off how chat templating works!
Falcon:

View File

@@ -0,0 +1,3 @@
User:
Hello, how are you?
Falcon:

View File

@@ -0,0 +1,7 @@
User:
Hello, how are you?
Falcon:
I'm doing great. How can I help you today?
User:
I'd like to show off how chat templating works!
Falcon:

View File

@@ -0,0 +1,8 @@
<start_of_turn>user
You are a helpful assistant.
Hello, how are you?<end_of_turn>
<start_of_turn>model
I'm doing great. How can I help you today?<end_of_turn>
<start_of_turn>user
I'd like to show off how chat templating works!<end_of_turn>
<start_of_turn>model

View File

@@ -0,0 +1,3 @@
<start_of_turn>user
Hello, how are you?<end_of_turn>
<start_of_turn>model

View File

@@ -0,0 +1,7 @@
<start_of_turn>user
Hello, how are you?<end_of_turn>
<start_of_turn>model
I'm doing great. How can I help you today?<end_of_turn>
<start_of_turn>user
I'd like to show off how chat templating works!<end_of_turn>
<start_of_turn>model

View File

@@ -0,0 +1,13 @@
System:
You are a helpful assistant.
Question:
Hello, how are you?
Answer:
I'm doing great. How can I help you today?
Question:
I'd like to show off how chat templating works!
Answer:

View File

@@ -0,0 +1,4 @@
Question:
Hello, how are you?
Answer:

View File

@@ -0,0 +1,10 @@
Question:
Hello, how are you?
Answer:
I'm doing great. How can I help you today?
Question:
I'd like to show off how chat templating works!
Answer:

View File

@@ -0,0 +1,7 @@
[INST] <<SYS>>
You are a helpful assistant.
<</SYS>>
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] <<SYS>><</SYS>>
I'd like to show off how chat templating works! [/INST]

View File

@@ -0,0 +1,3 @@
[INST] <<SYS>><</SYS>>
Hello, how are you? [/INST]

View File

@@ -0,0 +1,5 @@
[INST] <<SYS>><</SYS>>
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] <<SYS>><</SYS>>
I'd like to show off how chat templating works! [/INST]

View File

@@ -0,0 +1,10 @@
<|start_header_id|>system<|end_header_id|>
You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
Hello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
I'm doing great. How can I help you today?<|eot_id|><|start_header_id|>user<|end_header_id|>
I'd like to show off how chat templating works!<|eot_id|><|start_header_id|>assistant<|end_header_id|>

View File

@@ -0,0 +1,4 @@
<|start_header_id|>user<|end_header_id|>
Hello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

View File

@@ -0,0 +1,8 @@
<|start_header_id|>user<|end_header_id|>
Hello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
I'm doing great. How can I help you today?<|eot_id|><|start_header_id|>user<|end_header_id|>
I'd like to show off how chat templating works!<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Some files were not shown because too many files have changed in this diff Show More