Compare commits


57 Commits

Author SHA1 Message Date
Roy Han
2647a0e443 num parallel embed 2024-07-26 15:18:35 -07:00
Michael Yang
ec4c35fe99 Merge pull request #5512 from ollama/mxyng/detect-stop
autodetect stop parameters from template
2024-07-26 13:48:23 -07:00
Jeffrey Morgan
f5e3939220 Update api.md (#5968) 2024-07-25 23:10:18 -04:00
Jeffrey Morgan
ae27d9dcfd Update openai.md 2024-07-25 20:27:33 -04:00
Michael Yang
37096790a7 Merge pull request #5552 from ollama/mxyng/messages-docs
docs
2024-07-25 16:26:19 -07:00
Michael Yang
997c903884 Update docs/template.md
Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
2024-07-25 16:23:40 -07:00
Blake Mizerany
c8af3c2d96 server: reuse original download URL for images (#5962)
This changes the registry client to reuse the original download URL
it gets on the first redirect response for all subsequent requests,
preventing thundering herd issues when hot new LLMs are released.
2024-07-25 15:58:30 -07:00
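The redirect-reuse idea described in this commit can be sketched in a few lines. This is a hedged illustration only, not the actual registry client code; the blob URL and helper name are hypothetical.

```go
package main

import (
	"errors"
	"fmt"
	"net/http"
)

// resolveOnce follows a single redirect and returns the target URL so that
// later chunk requests can reuse it instead of hitting the registry (and its
// redirect logic) again for every part.
func resolveOnce(blobURL string) (string, error) {
	var redirected string
	client := &http.Client{
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			redirected = req.URL.String()  // remember where the registry sent us
			return http.ErrUseLastResponse // and stop following further hops
		},
	}
	resp, err := client.Get(blobURL)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	if redirected == "" {
		return "", errors.New("registry did not redirect")
	}
	return redirected, nil
}

func main() {
	// Hypothetical blob URL, for illustration only.
	u, err := resolveOnce("https://registry.ollama.ai/v2/library/llama3/blobs/sha256-abc123")
	if err != nil {
		fmt.Println("resolve failed:", err)
		return
	}
	fmt.Println("reuse this URL for all subsequent chunk downloads:", u)
}
```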
Jeffrey Morgan
455e61170d Update openai.md 2024-07-25 18:34:47 -04:00
royjhan
4de1370a9d openai tools doc (#5617) 2024-07-25 18:34:06 -04:00
Jeffrey Morgan
bbf8f102ee Revert "llm(llama): pass rope factors (#5924)" (#5963)
This reverts commit bb46bbcf5e.
2024-07-25 18:24:55 -04:00
Michael Yang
bb46bbcf5e llm(llama): pass rope factors (#5924) 2024-07-24 16:05:59 -04:00
royjhan
ac33aa7d37 Fix Embed Test Flakes (#5893)
* float cmp

* increase tolerance
2024-07-24 11:15:46 -07:00
Ajay Chintala
a6cd8f6169 Update README.md to add LLMStack integration (#5799) 2024-07-23 14:40:23 -04:00
Daniel Hiltgen
c78089263a Merge pull request #5864 from dhiltgen/bump_go
Bump Go patch version
2024-07-22 16:34:18 -07:00
Daniel Hiltgen
3e5ea035d5 Merge pull request #5757 from lreed-mdsol/lreed/bump-go-version-fix-vulnerabilities
bump go version to 1.22.5 to fix security vulnerabilities in docker
2024-07-22 16:32:43 -07:00
Daniel Hiltgen
5d604eec5b Bump Go patch version 2024-07-22 16:16:28 -07:00
Josh
db0968f30c fix dupe err message (#5857) 2024-07-22 15:48:15 -07:00
Michael Yang
9b60a038e5 update api.md 2024-07-22 13:49:51 -07:00
Michael Yang
83a0cb8d88 docs 2024-07-22 13:38:09 -07:00
royjhan
c0648233f2 api embed docs (#5282) 2024-07-22 13:37:08 -07:00
Jeffrey Morgan
d835368eb8 convert: capture head_dim for mistral (#5818) 2024-07-22 16:16:22 -04:00
Daniel Hiltgen
5784c05397 Merge pull request #5854 from dhiltgen/win_exit_status
Refine error reporting for subprocess crash
2024-07-22 10:40:22 -07:00
Daniel Hiltgen
f14aa5435d Merge pull request #5855 from dhiltgen/remove_max_vram
Remove no longer supported max vram var
2024-07-22 10:35:29 -07:00
Jeffrey Morgan
f8fedbda20 Update llama.cpp submodule commit to d94c6e0c (#5805) 2024-07-22 12:42:00 -04:00
Jeffrey Morgan
b3e5491e41 server: collect nested tool call objects when parsing (#5824) 2024-07-22 12:38:03 -04:00
Daniel Hiltgen
cc269ba094 Remove no longer supported max vram var
The OLLAMA_MAX_VRAM env var was a temporary workaround for OOM
scenarios. With concurrency support this was no longer wired up, and the simplistic
value doesn't map to multi-GPU setups. Users can still set `num_gpu`
to limit memory usage and avoid OOM if we get our predictions wrong.
2024-07-22 09:08:11 -07:00
Daniel Hiltgen
a3c20e3f18 Refine error reporting for subprocess crash
On Windows, the exit status winds up being the term many
users search for, and they end up piling onto unrelated issues.
This refines the reporting so that if we have a more detailed message
we suppress the exit status portion of the message.
2024-07-22 08:52:16 -07:00
Jeffrey Morgan
80ee9b5e47 Remove out of space test temporarily (#5825) 2024-07-21 00:22:11 -04:00
Jeffrey Morgan
5534f2cc6a llm: consider head_dim in llama arch (#5817) 2024-07-20 21:48:12 -04:00
Daniel Hiltgen
d321297d8a Merge pull request #5815 from dhiltgen/win_rocm_gfx_features
Adjust windows ROCm discovery
2024-07-20 16:02:55 -07:00
Daniel Hiltgen
06e5d74e34 Merge pull request #5506 from dhiltgen/sched_tests
Refine scheduler unit tests for reliability
2024-07-20 15:48:39 -07:00
Daniel Hiltgen
5d707e6fd5 Merge pull request #5583 from dhiltgen/integration_improvements
Fix context exhaustion integration test for small gpus
2024-07-20 15:48:21 -07:00
Daniel Hiltgen
283948c83b Adjust windows ROCm discovery
The v5 HIP library returns unsupported GPUs that won't enumerate at
inference time in the runner, so this makes sure discovery is aligned. The
gfx906 cards are no longer supported, so we shouldn't compile with that
GPU type as it won't enumerate at runtime.
2024-07-20 15:17:50 -07:00
Jeffrey Morgan
1475eab95f add patch for tekken (#5807) 2024-07-20 13:41:21 -04:00
Jeffrey Morgan
20090f3172 preserve last assistant message (#5802) 2024-07-19 20:19:26 -07:00
Jeffrey Morgan
69a2d4ccff Fix generate test flakyness (#5804) 2024-07-19 19:11:25 -07:00
Josh
e8b954c646 server: validate template (#5734)
add template validation to modelfile
2024-07-19 15:24:29 -07:00
royjhan
c57317cbf0 OpenAI: Function Based Testing (#5752)
* distinguish error forwarding

* more coverage

* rm comment
2024-07-19 11:37:12 -07:00
royjhan
51b2fd299c adjust openai chat msg processing (#5729) 2024-07-19 11:19:20 -07:00
Michael Yang
d0634b1596 Merge pull request #5780 from ollama/mxyng/tools
fix parsing tool calls: break on unexpected eofs
2024-07-18 12:14:10 -07:00
Michael Yang
43606d6d6a fix parsing tool calls 2024-07-18 12:08:11 -07:00
Jeffrey Morgan
70b1010fa5 server: check for empty tools array too (#5779) 2024-07-18 11:44:57 -07:00
Jeffrey Morgan
84e5721f3a always provide content even if empty (#5778) 2024-07-18 11:28:19 -07:00
Jeffrey Morgan
319fb1ce03 server: only parse tool calls if tools are provided (#5771)
* server: only parse tool calls if tools are provided

* still set `resp.Message.Content`
2024-07-18 08:50:23 -07:00
Michael Yang
b255445557 marshal json automatically for some template values (#5758) 2024-07-17 15:35:11 -07:00
lreed
f02f83660c bump go version to 1.22.5 to fix security vulnerabilities 2024-07-17 21:44:19 +00:00
Michael Yang
b23424bb3c Merge pull request #5753 from ollama/mxyng/parse-tool-call
parse tool call as individual objects
2024-07-17 11:47:53 -07:00
Michael Yang
5fd6988126 parse tool call as individual objects 2024-07-17 11:19:04 -07:00
Michael Yang
5b82960df8 stub response (#5750) 2024-07-17 10:39:22 -07:00
Michael Yang
cc9a252d8c Merge pull request #5732 from ollama/mxyng/cleanup
remove ToolCall from GenerateResponse
2024-07-17 10:26:54 -07:00
Pákozdi György
d281a6e603 add sidellama link (#5702) 2024-07-17 10:24:44 -07:00
royjhan
154f6f45d4 OpenAI: Support Tools (#5614)
* reopen pr

* tools

* remove tc from stream for now

* ID and Function

* openai expects arguments to be a string (#5739)

* mutually exclusive content and tool calls

* clean up

---------

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
2024-07-16 20:52:59 -07:00
royjhan
0d41623b52 OpenAI: Add Suffix to v1/completions (#5611)
* add suffix

* remove todo

* remove TODO

* add to test

* rm outdated prompt tokens info md

* fix test

* fix test
2024-07-16 20:50:14 -07:00
Michael Yang
c279f96371 remove ToolCall from GenerateResponse 2024-07-16 15:22:49 -07:00
Michael Yang
ebc529cbb3 autodetect stop parameters from template 2024-07-12 16:01:23 -07:00
Daniel Hiltgen
73e2c8f68f Fix context exhaustion integration test for small gpus
On the smaller GPUs, the initial model load of llama2 took over 30s (the
default timeout for the DoGenerate helper)
2024-07-09 16:24:14 -07:00
Daniel Hiltgen
f4408219e9 Refine scheduler unit tests for reliability
This breaks up some of the test scenarios to create a
more reliable set of tests, as well as adding a little more
coverage.
2024-07-09 16:00:08 -07:00
65 changed files with 1938 additions and 493 deletions

View File

@@ -31,7 +31,7 @@ jobs:
           security set-keychain-settings -lut 3600 build.keychain
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: "stable"
           cache: true
       - name: Build Darwin
         env:
@@ -87,7 +87,7 @@ jobs:
           write-host "plugin installed"
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: "stable"
           cache: true
       - run: go get ./...
       - run: |
@@ -141,7 +141,7 @@ jobs:
           write-host "plugin installed"
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: "stable"
           cache: true
       - name: 'Install ROCm'
         run: |
@@ -218,7 +218,7 @@ jobs:
           write-host "plugin installed"
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: "stable"
           cache: true
       - name: 'Install CUDA'
         run: |
@@ -306,7 +306,7 @@ jobs:
           write-host "plugin installed"
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: "stable"
           cache: true
       - run: go get
       - uses: actions/download-artifact@v4

View File

@@ -63,7 +63,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: "stable"
           cache: true
       - run: go get ./...
       - run: |
@@ -163,7 +163,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: "stable"
           cache: true
       - name: 'Install ROCm'
         run: |
@@ -200,7 +200,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: "stable"
           cache: true
       - name: 'Install CUDA'
         run: |
@@ -255,7 +255,7 @@ jobs:
           submodules: recursive
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: "stable"
           cache: false
       - run: |
           case ${{ matrix.arch }} in
@@ -297,7 +297,7 @@ jobs:
           submodules: recursive
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: "stable"
           cache: true
       - run: |
           case ${{ matrix.arch }} in

View File

@@ -1,4 +1,4 @@
-ARG GOLANG_VERSION=1.22.1
+ARG GOLANG_VERSION=1.22.5
 ARG CMAKE_VERSION=3.22.1
 # this CUDA_VERSION corresponds with the one specified in docs/gpu.md
 ARG CUDA_VERSION=11.3.1

View File

@@ -64,7 +64,8 @@ Here are some example models that can be downloaded:
 | LLaVA | 7B | 4.5GB | `ollama run llava` |
 | Solar | 10.7B | 6.1GB | `ollama run solar` |

-> Note: You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
+> [!NOTE]
+> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.

 ## Customize a model

@@ -295,6 +296,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
 - [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
 - [AI Studio](https://github.com/MindWorkAI/AI-Studio)
+- [Sidellama](https://github.com/gyopak/sidellama) (browser-based LLM client)
+- [LLMStack](https://github.com/trypromptly/LLMStack) (No-code multi-agent framework to build LLM agents and workflows)

 ### Terminal

View File

@@ -101,46 +101,29 @@ type ChatRequest struct {
 	KeepAlive *Duration `json:"keep_alive,omitempty"`

 	// Tools is an optional list of tools the model has access to.
-	Tools []Tool `json:"tools,omitempty"`
+	Tools `json:"tools,omitempty"`

 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 }

+type Tools []Tool
+
+func (t Tools) String() string {
+	bts, _ := json.Marshal(t)
+	return string(bts)
+}
+
 // Message is a single message in a chat sequence. The message contains the
 // role ("system", "user", or "assistant"), the content and an optional list
 // of images.
 type Message struct {
 	Role      string      `json:"role"`
-	Content   string      `json:"content,omitempty"`
+	Content   string      `json:"content"`
 	Images    []ImageData `json:"images,omitempty"`
 	ToolCalls []ToolCall  `json:"tool_calls,omitempty"`
 }

-type ToolCall struct {
-	Function struct {
-		Name      string         `json:"name"`
-		Arguments map[string]any `json:"arguments"`
-	} `json:"function"`
-}
-
-type Tool struct {
-	Type     string `json:"type"`
-	Function struct {
-		Name        string `json:"name"`
-		Description string `json:"description"`
-		Parameters  struct {
-			Type       string   `json:"type"`
-			Required   []string `json:"required"`
-			Properties map[string]struct {
-				Type        string   `json:"type"`
-				Description string   `json:"description"`
-				Enum        []string `json:"enum,omitempty"`
-			} `json:"properties"`
-		} `json:"parameters"`
-	} `json:"function"`
-}
-
 func (m *Message) UnmarshalJSON(b []byte) error {
 	type Alias Message
 	var a Alias
@@ -153,6 +136,46 @@ func (m *Message) UnmarshalJSON(b []byte) error {
 	return nil
 }

+type ToolCall struct {
+	Function ToolCallFunction `json:"function"`
+}
+
+type ToolCallFunction struct {
+	Name      string                    `json:"name"`
+	Arguments ToolCallFunctionArguments `json:"arguments"`
+}
+
+type ToolCallFunctionArguments map[string]any
+
+func (t *ToolCallFunctionArguments) String() string {
+	bts, _ := json.Marshal(t)
+	return string(bts)
+}
+
+type Tool struct {
+	Type     string       `json:"type"`
+	Function ToolFunction `json:"function"`
+}
+
+type ToolFunction struct {
+	Name        string `json:"name"`
+	Description string `json:"description"`
+	Parameters  struct {
+		Type       string   `json:"type"`
+		Required   []string `json:"required"`
+		Properties map[string]struct {
+			Type        string   `json:"type"`
+			Description string   `json:"description"`
+			Enum        []string `json:"enum,omitempty"`
+		} `json:"properties"`
+	} `json:"parameters"`
+}
+
+func (t *ToolFunction) String() string {
+	bts, _ := json.Marshal(t)
+	return string(bts)
+}
+
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
 // similar to [GenerateResponse].
 type ChatResponse struct {
@@ -405,9 +428,6 @@ type GenerateResponse struct {
 	// Response is the textual response itself.
 	Response string `json:"response"`

-	// ToolCalls is the list of tools the model wants to call
-	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
-
 	// Done specifies if the response is complete.
 	Done bool `json:"done"`
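As an editor's aside, here is a minimal sketch of how the reshaped types above are consumed, assuming the `github.com/ollama/ollama/api` package as defined in this diff; treat it as illustrative rather than canonical.

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	// An assistant message in the shape the updated Message type expects.
	raw := []byte(`{
		"role": "assistant",
		"content": "",
		"tool_calls": [{
			"function": {
				"name": "get_current_weather",
				"arguments": {"location": "Paris, FR", "format": "celsius"}
			}
		}]
	}`)

	var msg api.Message
	if err := json.Unmarshal(raw, &msg); err != nil {
		panic(err)
	}

	for _, tc := range msg.ToolCalls {
		// ToolCallFunctionArguments.String() re-marshals the argument map to JSON.
		fmt.Println(tc.Function.Name, tc.Function.Arguments.String())
	}
}
```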

View File

@@ -1344,7 +1344,6 @@ func NewCLI() *cobra.Command {
 				envVars["OLLAMA_TMPDIR"],
 				envVars["OLLAMA_FLASH_ATTENTION"],
 				envVars["OLLAMA_LLM_LIBRARY"],
-				envVars["OLLAMA_MAX_VRAM"],
 			})
 		default:
 			appendEnvDocs(cmd, envs)

View File

@@ -71,6 +71,11 @@ func (m *MistralModel) WriteGGUF(ws io.WriteSeeker) error {
 		"tokenizer.ggml.unknown_token_id": uint32(0),
 	}

+	if m.Params.HeadDimension > 0 {
+		kv["llama.attention.key_length"] = uint32(m.Params.HeadDimension)
+		kv["llama.attention.value_length"] = uint32(m.Params.HeadDimension)
+	}
+
 	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }
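For context, the head dimension written here is the explicit value from the model config; when a config omits it, the conventional fallback is `hidden_size / num_attention_heads`. A hedged sketch of that convention (the example numbers are illustrative):

```go
package main

import "fmt"

// headDim returns the explicit head_dim when the model config provides one,
// otherwise the conventional hidden_size / num_attention_heads fallback.
func headDim(explicit, hiddenSize, numHeads int) int {
	if explicit > 0 {
		return explicit
	}
	return hiddenSize / numHeads
}

func main() {
	fmt.Println(headDim(128, 5120, 32)) // config sets head_dim explicitly -> 128
	fmt.Println(headDim(0, 4096, 32))   // no head_dim in config -> 4096/32 = 128
}
```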

View File

@@ -40,6 +40,7 @@ Generate a response for a given prompt with a provided model. This is a streamin
 - `model`: (required) the [model name](#model-names)
 - `prompt`: the prompt to generate a response for
+- `suffix`: the text after the model response
 - `images`: (optional) a list of base64-encoded images (for multimodal models such as `llava`)

 Advanced parameters (optional):

@@ -57,7 +58,8 @@ Advanced parameters (optional):
 Enable JSON mode by setting the `format` parameter to `json`. This will structure the response as a valid JSON object. See the JSON mode [example](#request-json-mode) below.

-> Note: it's important to instruct the model to use JSON in the `prompt`. Otherwise, the model may generate large amounts of whitespace.
+> [!IMPORTANT]
+> It's important to instruct the model to use JSON in the `prompt`. Otherwise, the model may generate large amounts of whitespace.

 ### Examples

@@ -148,8 +150,44 @@ If `stream` is set to `false`, the response will be a single JSON object:
 }
 ```
#### Request (with suffix)
##### Request
```shell
curl http://localhost:11434/api/generate -d '{
"model": "codellama:code",
"prompt": "def compute_gcd(a, b):",
"suffix": " return result",
"options": {
"temperature": 0
},
"stream": false
}'
```
##### Response
```json
{
"model": "codellama:code",
"created_at": "2024-07-22T20:47:51.147561Z",
"response": "\n if a == 0:\n return b\n else:\n return compute_gcd(b % a, a)\n\ndef compute_lcm(a, b):\n result = (a * b) / compute_gcd(a, b)\n",
"done": true,
"done_reason": "stop",
"context": [...],
"total_duration": 1162761250,
"load_duration": 6683708,
"prompt_eval_count": 17,
"prompt_eval_duration": 201222000,
"eval_count": 63,
"eval_duration": 953997000
}
```
 #### Request (JSON mode)

+> [!IMPORTANT]
 > When `format` is set to `json`, the output will always be a well-formed JSON object. It's important to also instruct the model to respond in JSON.

 ##### Request

@@ -380,12 +418,14 @@ Generate the next message in a chat with a provided model. This is a streaming e
 - `model`: (required) the [model name](#model-names)
 - `messages`: the messages of the chat, this can be used to keep a chat memory
+- `tools`: tools for the model to use if supported. Requires `stream` to be set to `false`

 The `message` object has the following fields:

-- `role`: the role of the message, either `system`, `user` or `assistant`
+- `role`: the role of the message, either `system`, `user`, `assistant`, or `tool`
 - `content`: the content of the message
 - `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
+- `tool_calls` (optional): a list of tools the model wants to use

 Advanced parameters (optional):

@@ -622,6 +662,79 @@ curl http://localhost:11434/api/chat -d '{
 }
 ```
#### Chat request (with tools)
##### Request
```
curl http://localhost:11434/api/chat -d '{
"model": "mistral",
"messages": [
{
"role": "user",
"content": "What is the weather today in Paris?"
}
],
"stream": false,
"tools": [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather for a location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The location to get the weather for, e.g. San Francisco, CA"
},
"format": {
"type": "string",
"description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location", "format"]
}
}
}
]
}'
```
##### Response
```json
{
"model": "mistral:7b-instruct-v0.3-q4_K_M",
"created_at": "2024-07-22T20:33:28.123648Z",
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": "get_current_weather",
"arguments": {
"format": "celsius",
"location": "Paris, FR"
}
}
}
]
},
"done_reason": "stop",
"done": true,
"total_duration": 885095291,
"load_duration": 3753500,
"prompt_eval_count": 122,
"prompt_eval_duration": 328493000,
"eval_count": 33,
"eval_duration": 552222000
}
```
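For Go callers, a hedged sketch of the same tool-calling request using the `github.com/ollama/ollama/api` client and the types introduced earlier in this compare; the model name and tool schema mirror the curl example above.

```go
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Same tool definition as the curl example, parsed into api.Tools.
	toolJSON := `[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather for a location","parameters":{"type":"object","required":["location","format"],"properties":{"location":{"type":"string","description":"The location to get the weather for, e.g. San Francisco, CA"},"format":{"type":"string","description":"celsius or fahrenheit","enum":["celsius","fahrenheit"]}}}}}]`
	var tools api.Tools
	if err := json.Unmarshal([]byte(toolJSON), &tools); err != nil {
		log.Fatal(err)
	}

	stream := false // tool calls currently require stream to be false
	req := &api.ChatRequest{
		Model:    "mistral",
		Messages: []api.Message{{Role: "user", Content: "What is the weather today in Paris?"}},
		Tools:    tools,
		Stream:   &stream,
	}

	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		for _, tc := range resp.Message.ToolCalls {
			fmt.Println("tool:", tc.Function.Name, "args:", tc.Function.Arguments.String())
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```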
 ## Create a Model

 ```shell
@@ -1026,7 +1139,7 @@ If `stream` is set to `false`, then the response is a single JSON object:
 ## Generate Embeddings

 ```shell
-POST /api/embeddings
+POST /api/embed
 ```

 Generate embeddings from a model

@@ -1034,10 +1147,11 @@ Generate embeddings from a model
 ### Parameters

 - `model`: name of model to generate embeddings from
-- `prompt`: text to generate embeddings for
+- `input`: text or list of text to generate embeddings for

 Advanced parameters:

+- `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
 - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
 - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)

@@ -1046,9 +1160,9 @@ Advanced parameters:
 #### Request

 ```shell
-curl http://localhost:11434/api/embeddings -d '{
+curl http://localhost:11434/api/embed -d '{
   "model": "all-minilm",
-  "prompt": "Here is an article about llamas..."
+  "input": "Why is the sky blue?"
 }'
 ```

@@ -1056,10 +1170,35 @@ curl http://localhost:11434/api/embeddings -d '{
 ```json
 {
-  "embedding": [
-    0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
-    0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
-  ]
+  "model": "all-minilm",
+  "embeddings": [[
+    0.010071029, -0.0017594862, 0.05007221, 0.04692972, 0.054916814,
+    0.008599704, 0.105441414, -0.025878139, 0.12958129, 0.031952348
+  ]]
 }
 ```
#### Request (Multiple input)
```shell
curl http://localhost:11434/api/embed -d '{
"model": "all-minilm",
"input": ["Why is the sky blue?", "Why is the grass green?"]
}'
```
#### Response
```json
{
"model": "all-minilm",
"embeddings": [[
0.010071029, -0.0017594862, 0.05007221, 0.04692972, 0.054916814,
0.008599704, 0.105441414, -0.025878139, 0.12958129, 0.031952348
],[
-0.0098027075, 0.06042469, 0.025257962, -0.006364387, 0.07272725,
0.017194884, 0.09032035, -0.051705178, 0.09951512, 0.09072481
]]
 }
 ```

@@ -1106,3 +1245,45 @@ A single JSON object will be returned.
   ]
 }
 ```
## Generate Embedding
> Note: this endpoint has been superseded by `/api/embed`
```shell
POST /api/embeddings
```
Generate embeddings from a model
### Parameters
- `model`: name of model to generate embeddings from
- `prompt`: text to generate embeddings for
Advanced parameters:
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
### Examples
#### Request
```shell
curl http://localhost:11434/api/embeddings -d '{
"model": "all-minilm",
"prompt": "Here is an article about llamas..."
}'
```
#### Response
```json
{
"embedding": [
0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
]
}
```
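For completeness, a hedged Go sketch of the new `/api/embed` endpoint through the client package, assuming `Client.Embed` and `api.EmbedRequest` match the integration tests later in this compare.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Input accepts either a single string or a list of strings.
	resp, err := client.Embed(context.Background(), &api.EmbedRequest{
		Model: "all-minilm",
		Input: []string{"Why is the sky blue?", "Why is the grass green?"},
	})
	if err != nil {
		log.Fatal(err)
	}

	for i, emb := range resp.Embeddings {
		fmt.Printf("input %d: %d dimensions, first value %f\n", i, len(emb), emb[0])
	}
}
```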

View File

@@ -46,13 +46,24 @@ sudo modprobe nvidia_uvm`
 ## AMD Radeon

 Ollama supports the following AMD GPUs:

+### Linux Support
+
 | Family | Cards and accelerators |
 | -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
 | AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56` |
 | AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` |
 | AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50` |

-### Overrides
+### Windows Support
+With ROCm v6.1, the following GPUs are supported on Windows.
+
+| Family | Cards and accelerators |
+| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
+| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` |
+| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
+
+### Overrides on Linux
 Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In
 some cases you can force the system to try to use a similar LLVM target that is
 close. For example The Radeon RX 5400 is `gfx1034` (also known as 10.3.4)
@@ -63,7 +74,7 @@ would set `HSA_OVERRIDE_GFX_VERSION="10.3.0"` as an environment variable for the
 server. If you have an unsupported AMD GPU you can experiment using the list of
 supported types below.

-At this time, the known supported GPU types are the following LLVM Targets.
+At this time, the known supported GPU types on linux are the following LLVM Targets.
 This table shows some example GPUs that map to these LLVM targets:
 | **LLVM Target** | **An Example GPU** |
 |-----------------|---------------------|

View File

@@ -1,6 +1,7 @@
 # Ollama Model File

-> Note: `Modelfile` syntax is in development
+> [!NOTE]
+> `Modelfile` syntax is in development

 A model file is the blueprint to create and share models with Ollama.

View File

@@ -78,8 +78,8 @@ curl http://localhost:11434/v1/chat/completions \
 - [x] Streaming
 - [x] JSON mode
 - [x] Reproducible outputs
+- [x] Tools (streaming support coming soon)
 - [ ] Vision
-- [ ] Function calling
 - [ ] Logprobs

 #### Supported request fields

@@ -97,16 +97,12 @@ curl http://localhost:11434/v1/chat/completions \
 - [x] `temperature`
 - [x] `top_p`
 - [x] `max_tokens`
-- [ ] `logit_bias`
-- [ ] `tools`
+- [x] `tools`
 - [ ] `tool_choice`
+- [ ] `logit_bias`
 - [ ] `user`
 - [ ] `n`

-#### Notes
-
-- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
-
 ## Models

 Before using a model, pull it locally `ollama pull`:
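A hedged sketch of exercising the newly checked `tools` field against the OpenAI-compatible endpoint, using only the Go standard library; the payload shape follows the OpenAI chat completions schema and is illustrative rather than exhaustive.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"log"
	"net/http"
)

func main() {
	body := []byte(`{
		"model": "mistral",
		"messages": [{"role": "user", "content": "What is the weather in Paris?"}],
		"tools": [{
			"type": "function",
			"function": {
				"name": "get_current_weather",
				"description": "Get the current weather for a location",
				"parameters": {
					"type": "object",
					"required": ["location"],
					"properties": {"location": {"type": "string"}}
				}
			}
		}]
	}`)

	resp, err := http.Post("http://localhost:11434/v1/chat/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	// When the model decides to call a tool, the assistant message carries
	// a tool_calls array instead of text content.
	fmt.Println(string(out))
}
```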

docs/template.md (new file, 173 lines)
View File

@@ -0,0 +1,173 @@
# Template
Ollama provides a powerful templating engine backed by Go's built-in templating engine to construct prompts for your large language model. This feature is a valuable tool to get the most out of your models.
## Basic Template Structure
A basic Go template consists of three main parts:
* **Layout**: The overall structure of the template.
* **Variables**: Placeholders for dynamic data that will be replaced with actual values when the template is rendered.
* **Functions**: Custom functions or logic that can be used to manipulate the template's content.
Here's an example of a simple chat template:
```gotmpl
{{- range .Messages }}
{{ .Role }}: {{ .Content }}
{{- end }}
```
In this example, we have:
* A basic messages structure (layout)
* Three variables: `Messages`, `Role`, and `Content` (variables)
* A custom function (action) that iterates over an array of items (`range .Messages`) and displays each item
## Adding templates to your model
By default, models imported into Ollama have a default template of `{{ .Prompt }}`, i.e. user inputs are sent verbatim to the LLM. This is appropriate for text or code completion models but lacks essential markers for chat or instruction models.
Omitting a template in these models puts the responsibility of correctly templating input onto the user. Adding a template allows users to easily get the best results from the model.
To add templates in your model, you'll need to add a `TEMPLATE` command to the Modelfile. Here's an example using Meta's Llama 3.
```dockerfile
FROM llama3
TEMPLATE """{{- if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>
{{- end }}
{{- range .Messages }}<|start_header_id|>{{ .Role }}<|end_header_id|>
{{ .Content }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|>
"""
```
## Variables
`System` (string): system prompt
`Prompt` (string): user prompt
`Response` (string): assistant response
`Suffix` (string): text inserted after the assistant's response
`Messages` (list): list of messages
`Messages[].Role` (string): role which can be one of `system`, `user`, `assistant`, or `tool`
`Messages[].Content` (string): message content
`Messages[].ToolCalls` (list): list of tools the model wants to call
`Messages[].ToolCalls[].Function` (object): function to call
`Messages[].ToolCalls[].Function.Name` (string): function name
`Messages[].ToolCalls[].Function.Arguments` (map): mapping of argument name to argument value
`Tools` (list): list of tools the model can access
`Tools[].Type` (string): schema type. `type` is always `function`
`Tools[].Function` (object): function definition
`Tools[].Function.Name` (string): function name
`Tools[].Function.Description` (string): function description
`Tools[].Function.Parameters` (object): function parameters
`Tools[].Function.Parameters.Type` (string): schema type. `type` is always `object`
`Tools[].Function.Parameters.Required` (list): list of required properties
`Tools[].Function.Parameters.Properties` (map): mapping of property name to property definition
`Tools[].Function.Parameters.Properties[].Type` (string): property type
`Tools[].Function.Parameters.Properties[].Description` (string): property description
`Tools[].Function.Parameters.Properties[].Enum` (list): list of valid values
## Tips and Best Practices
Keep the following tips and best practices in mind when working with Go templates:
* **Be mindful of dot**: Control flow structures like `range` and `with` change the value of `.`
* **Out-of-scope variables**: Use `$.` to reference variables not currently in scope, starting from the root
* **Whitespace control**: Use `-` to trim leading (`{{-`) and trailing (`-}}`) whitespace
## Examples
### Example Messages
#### ChatML
ChatML is a popular template format. It can be used for models such as Databricks' DBRX, Intel's Neural Chat, and Microsoft's Orca 2.
```gotmpl
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ else }}
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
```
### Example Tools
Tools support can be added to a model by adding a `{{ .Tools }}` node to the template. This feature is useful for models trained to call external tools and can be a powerful tool for retrieving real-time data or performing complex tasks.
#### Mistral
Mistral v0.3 and Mixtral 8x22B support tool calling.
```gotmpl
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}
{{- if and (le (len (slice $.Messages $index)) 2) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
{{ end }}{{ .Content }}[/INST]
{{- else if eq .Role "assistant" }}
{{- if .Content }} {{ .Content }}</s>
{{- else if .ToolCalls }}[TOOL_CALLS] [
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
{{- end }}]</s>
{{- end }}
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
{{- end }}
{{- end }}
```
### Example Fill-in-Middle
Fill-in-middle support can be added to a model by adding a `{{ .Suffix }}` node to the template. This feature is useful for models that are trained to generate text in the middle of user input, such as code completion models.
#### CodeLlama
CodeLlama [7B](https://ollama.com/library/codellama:7b-code) and [13B](https://ollama.com/library/codellama:13b-code) code completion models support fill-in-middle.
```gotmpl
<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
```
> [!NOTE]
> CodeLlama 34B and 70B code completion and all instruct and Python fine-tuned models do not support fill-in-middle.
#### Codestral
Codestral [22B](https://ollama.com/library/codestral:22b) supports fill-in-middle.
```gotmpl
[SUFFIX]{{ .Suffix }}[PREFIX] {{ .Prompt }}
```
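A hedged Go sketch of driving fill-in-middle through `/api/generate`, assuming the client's `GenerateRequest` gained a `Suffix` field to match the documented `suffix` parameter; the prompt and suffix mirror the CodeLlama example in the API docs.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	stream := false
	req := &api.GenerateRequest{
		Model:   "codellama:7b-code",
		Prompt:  "def compute_gcd(a, b):",
		Suffix:  "    return result",
		Options: map[string]interface{}{"temperature": 0},
		Stream:  &stream,
	}

	err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		fmt.Print(resp.Response) // the infilled middle section
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```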

View File

@@ -43,8 +43,6 @@ var (
 	MaxRunners int
 	// Set via OLLAMA_MAX_QUEUE in the environment
 	MaxQueuedRequests int
-	// Set via OLLAMA_MAX_VRAM in the environment
-	MaxVRAM uint64
 	// Set via OLLAMA_MODELS in the environment
 	ModelsDir string
 	// Set via OLLAMA_NOHISTORY in the environment
@@ -89,7 +87,6 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_LLM_LIBRARY":       {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
 		"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"},
 		"OLLAMA_MAX_QUEUE":         {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
-		"OLLAMA_MAX_VRAM":          {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"},
 		"OLLAMA_MODELS":            {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"},
 		"OLLAMA_NOHISTORY":         {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
 		"OLLAMA_NOPRUNE":           {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
@@ -194,16 +191,6 @@ func LoadConfig() {
 	TmpDir = clean("OLLAMA_TMPDIR")

-	userLimit := clean("OLLAMA_MAX_VRAM")
-	if userLimit != "" {
-		avail, err := strconv.ParseUint(userLimit, 10, 64)
-		if err != nil {
-			slog.Error("invalid setting, ignoring", "OLLAMA_MAX_VRAM", userLimit, "error", err)
-		} else {
-			MaxVRAM = avail
-		}
-	}
-
 	LLMLibrary = clean("OLLAMA_LLM_LIBRARY")

 	if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" {
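With `OLLAMA_MAX_VRAM` removed, the commit message points to `num_gpu` as the remaining lever for capping GPU memory use; a hedged sketch of passing it per request (the option name comes from the Modelfile reference, the layer count is illustrative):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	stream := false
	req := &api.GenerateRequest{
		Model:  "llama3",
		Prompt: "Why is the sky blue?",
		// Offload only 16 layers to the GPU to bound VRAM usage.
		Options: map[string]interface{}{"num_gpu": 16},
		Stream: &stream,
	}

	err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		fmt.Print(resp.Response)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```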

View File

@@ -33,9 +33,10 @@ type HipLib struct {
 }

 func NewHipLib() (*HipLib, error) {
-	h, err := windows.LoadLibrary("amdhip64.dll")
+	// At runtime we depend on v6, so discover GPUs with the same library for a consistent set of GPUs
+	h, err := windows.LoadLibrary("amdhip64_6.dll")
 	if err != nil {
-		return nil, fmt.Errorf("unable to load amdhip64.dll: %w", err)
+		return nil, fmt.Errorf("unable to load amdhip64_6.dll, please make sure to upgrade to the latest amd driver: %w", err)
 	}
 	hl := &HipLib{}
 	hl.dll = h

View File

@@ -92,7 +92,8 @@ func AMDGetGPUInfo() []RocmGPUInfo {
 			continue
 		}
 		if gfxOverride == "" {
-			if !slices.Contains[[]string, string](supported, gfx) {
+			// Strip off Target Features when comparing
+			if !slices.Contains[[]string, string](supported, strings.Split(gfx, ":")[0]) {
 				slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
 				// TODO - consider discrete markdown just for ROCM troubleshooting?
 				slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")

View File

@@ -69,7 +69,7 @@ func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
 	reqLimit := len(req)
 	iterLimit := 5

-	vram := os.Getenv("OLLAMA_MAX_VRAM")
+	vram := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
 	if vram != "" {
 		max, err := strconv.ParseUint(vram, 10, 64)
 		require.NoError(t, err)
@@ -106,7 +106,7 @@ func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {

 // Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
 func TestMultiModelStress(t *testing.T) {
-	vram := os.Getenv("OLLAMA_MAX_VRAM")
+	vram := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
 	if vram == "" {
 		t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
 	}

View File

@@ -12,7 +12,7 @@ import (

 func TestContextExhaustion(t *testing.T) {
 	// Longer needed for small footprint GPUs
-	ctx, cancel := context.WithTimeout(context.Background(), 6*time.Minute)
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
 	// Set up the test data
 	req := api.GenerateRequest{
@@ -25,5 +25,10 @@ func TestContextExhaustion(t *testing.T) {
 			"num_ctx": 128,
 		},
 	}
-	GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+	if err := PullIfMissing(ctx, client, req.Model); err != nil {
+		t.Fatalf("PullIfMissing failed: %v", err)
+	}
+	DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
 }

View File

@@ -4,12 +4,45 @@ package integration

 import (
 	"context"
+	"math"
 	"testing"
 	"time"

 	"github.com/ollama/ollama/api"
 )
func floatsEqual32(a, b float32) bool {
return math.Abs(float64(a-b)) <= 1e-4
}
func floatsEqual64(a, b float64) bool {
return math.Abs(a-b) <= 1e-4
}
func TestAllMiniLMEmbeddings(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
req := api.EmbeddingRequest{
Model: "all-minilm",
Prompt: "why is the sky blue?",
}
res, err := embeddingTestHelper(ctx, t, req)
if err != nil {
t.Fatalf("error: %v", err)
}
if len(res.Embedding) != 384 {
t.Fatalf("expected 384 floats, got %d", len(res.Embedding))
}
if !floatsEqual64(res.Embedding[0], 0.06642947345972061) {
t.Fatalf("expected 0.06642947345972061, got %.16f", res.Embedding[0])
}
}
 func TestAllMiniLMEmbed(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
@@ -33,8 +66,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
 		t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
 	}

-	if res.Embeddings[0][0] != 0.010071031 {
-		t.Fatalf("expected 0.010071031, got %f", res.Embeddings[0][0])
+	if !floatsEqual32(res.Embeddings[0][0], 0.010071031) {
+		t.Fatalf("expected 0.010071031, got %.8f", res.Embeddings[0][0])
 	}
 }
@@ -61,12 +94,12 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
 		t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
 	}

-	if res.Embeddings[0][0] != 0.010071031 || res.Embeddings[1][0] != -0.009802706 {
-		t.Fatalf("expected 0.010071031 and -0.009802706, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0])
+	if !floatsEqual32(res.Embeddings[0][0], 0.010071031) || !floatsEqual32(res.Embeddings[1][0], -0.009802706) {
+		t.Fatalf("expected 0.010071031 and -0.009802706, got %.8f and %.8f", res.Embeddings[0][0], res.Embeddings[1][0])
 	}
 }

-func TestAllMiniLmEmbedTruncate(t *testing.T) {
+func TestAllMiniLMEmbedTruncate(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
@@ -135,6 +168,22 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
 	}
 }
func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err)
}
response, err := client.Embeddings(ctx, &req)
if err != nil {
return nil, err
}
return response, nil
}
 func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
 	client, _, cleanup := InitServerConnection(ctx, t)
 	defer cleanup()

View File

@@ -7,8 +7,8 @@ function amdGPUs {
         return $env:AMDGPU_TARGETS
     }
     # Current supported rocblas list from ROCm v6.1.2 on windows
+    # https://rocm.docs.amd.com/projects/install-on-windows/en/latest/reference/system-requirements.html#windows-supported-gpus
     $GPU_LIST = @(
-        "gfx906:xnack-"
         "gfx1030"
         "gfx1100"
         "gfx1101"

View File

@@ -1,8 +1,8 @@
 diff --git a/src/llama.cpp b/src/llama.cpp
-index 2b9ace28..172640e2 100644
+index 8fe51971..7113ba64 100644
 --- a/src/llama.cpp
 +++ b/src/llama.cpp
-@@ -5357,16 +5357,7 @@ static void llm_load_vocab(
+@@ -5433,16 +5433,7 @@ static void llm_load_vocab(
      if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
          vocab.tokenizer_add_space_prefix = false;
          vocab.tokenizer_clean_spaces = true;
@@ -20,9 +20,9 @@ index 2b9ace28..172640e2 100644
              vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
          } else if (
              tokenizer_pre == "llama3" ||
-@@ -5439,7 +5430,8 @@ static void llm_load_vocab(
-             tokenizer_pre == "jais") {
-             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
+@@ -5526,7 +5517,8 @@ static void llm_load_vocab(
+             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMOLLM;
+             vocab.tokenizer_clean_spaces = false;
          } else {
 -            throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
 +            LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);

View File

@@ -1,13 +0,0 @@
diff --git a/src/llama.cpp b/src/llama.cpp
index 40d2ec2c..f34eb79a 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -6943,7 +6943,7 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

llm/patches/09-lora.diff (new file, 360 lines)
View File

@@ -0,0 +1,360 @@
diff --git a/common/common.cpp b/common/common.cpp
index dbb724fb..c26fe6ee 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -2087,14 +2087,29 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
float lora_scale = std::get<1>(params.lora_adapter[i]);
+
+ // try to load as gguf
auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
if (adapter == nullptr) {
- fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
- llama_free(lctx);
- llama_free_model(model);
- return std::make_tuple(nullptr, nullptr);
+ fprintf(stderr, "%s: error: failed to apply lora adapter, trying ggla\n", __func__);
+
+ // if that fails, try loading as ggla for compatibility
+ int err = llama_model_apply_lora_from_file(model,
+ lora_adapter.c_str(),
+ lora_scale,
+ ((i > 0) || params.lora_base.empty())
+ ? NULL
+ : params.lora_base.c_str(),
+ params.n_threads);
+ if (err != 0) {
+ fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+ llama_free(lctx);
+ llama_free_model(model);
+ return std::make_tuple(nullptr, nullptr);
+ }
+ } else {
+ llama_lora_adapter_set(lctx, adapter, lora_scale);
}
- llama_lora_adapter_set(lctx, adapter, lora_scale);
}
if (params.ignore_eos) {
diff --git a/include/llama.h b/include/llama.h
index 93fd77ca..b0fb37a6 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1160,6 +1160,20 @@ extern "C" {
LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
+ // Apply a LoRA adapter to a loaded model
+ // path_base_model is the path to a higher quality model to use as a base for
+ // the layers modified by the adapter. Can be NULL to use the current loaded model.
+ // The model needs to be reloaded before applying a new adapter, otherwise the adapter
+ // will be applied on top of the previous one
+ // Returns 0 on success
+ LLAMA_API int32_t llama_model_apply_lora_from_file(
+ const struct llama_model * model,
+ const char * path_lora,
+ float scale,
+ const char * path_base_model,
+ int32_t n_threads);
+
+
#ifdef __cplusplus
}
#endif
diff --git a/src/llama.cpp b/src/llama.cpp
index 80a0dd0f..9d7b0e17 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -21880,3 +21880,290 @@ static void llama_log_callback_default(ggml_log_level level, const char * text,
fputs(text, stderr);
fflush(stderr);
}
+
+static int llama_apply_lora_from_file_internal(
+ const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads
+) {
+ LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
+
+ const int64_t t_start_lora_us = ggml_time_us();
+
+ llama_file fin(path_lora, "rb");
+
+ // verify magic and version
+ {
+ uint32_t magic = fin.read_u32();
+ if (magic != LLAMA_FILE_MAGIC_GGLA) {
+ LLAMA_LOG_ERROR("%s: bad file magic\n", __func__);
+ return 1;
+ }
+
+ uint32_t format_version = fin.read_u32();
+ if (format_version != 1) {
+ LLAMA_LOG_ERROR("%s: unsupported file version\n", __func__ );
+ return 1;
+ }
+ }
+
+ int32_t lora_r = fin.read_u32();
+ int32_t lora_alpha = fin.read_u32();
+ float scaling = scale * (float)lora_alpha / (float)lora_r;
+
+ LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
+
+ // load base model
+ std::unique_ptr<llama_model_loader> ml;
+ if (path_base_model) {
+ LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
+ ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*check_tensors*/ false, /*kv_overrides*/ nullptr));
+ ml->init_mappings(/*prefetch*/ false); // no prefetching
+ }
+
+ struct tensor_meta {
+ std::string name;
+ ggml_type type;
+ int32_t ne[2];
+ size_t offset;
+ };
+ std::map<std::string, tensor_meta> tensor_meta_map;
+
+ // load all tensor meta
+ while (true) {
+ if (fin.tell() == fin.size) {
+ // eof
+ break;
+ }
+
+ int32_t n_dims;
+ int32_t name_len;
+ int32_t ftype;
+
+ fin.read_raw(&n_dims, sizeof(n_dims));
+ fin.read_raw(&name_len, sizeof(name_len));
+ fin.read_raw(&ftype, sizeof(ftype));
+
+ if (n_dims != 1 && n_dims != 2) {
+ LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims);
+ return 1;
+ }
+
+ int32_t ne[2] = { 1, 1 };
+ for (int i = 0; i < n_dims; ++i) {
+ fin.read_raw(&ne[i], sizeof(ne[i]));
+ }
+
+ std::string name;
+ {
+ GGML_ASSERT(name_len < GGML_MAX_NAME);
+ char buf[GGML_MAX_NAME];
+ fin.read_raw(buf, name_len);
+ name = std::string(buf, name_len);
+ }
+
+ // check for lora suffix
+ std::string lora_suffix;
+ if (name.length() > 6) {
+ lora_suffix = name.substr(name.length() - 6);
+ }
+ if (lora_suffix != ".loraA" && lora_suffix != ".loraB") {
+ LLAMA_LOG_ERROR("%s: error: '%s' is not a lora tensor\n", __func__, name.c_str());
+ return 1;
+ }
+
+ // tensor type
+ ggml_type wtype;
+ switch (ftype) {
+ case 0: wtype = GGML_TYPE_F32; break;
+ case 1: wtype = GGML_TYPE_F16; break;
+ default:
+ {
+ LLAMA_LOG_ERROR("%s: invalid tensor data type '%d'\n",
+ __func__, ftype);
+ return 1;
+ }
+ }
+
+ // data offset
+ size_t offset = fin.tell();
+ offset = (offset + 31) & -32;
+
+ // skip tensor data
+ fin.seek(offset + ggml_row_size(wtype, ne[0]) * ne[1], SEEK_SET);
+
+ tensor_meta_map.emplace(name, tensor_meta{ name, wtype, { ne[0], ne[1] }, offset });
+ }
+
+ bool warned = false;
+ int n_tensors = 0;
+
+ // apply
+ ggml_backend_t backend_cpu = ggml_backend_cpu_init();
+ if (backend_cpu == nullptr) {
+ LLAMA_LOG_ERROR("%s: error: failed to initialize cpu backend\n", __func__);
+ return 1;
+ }
+ ggml_backend_cpu_set_n_threads(backend_cpu, n_threads);
+
+ std::vector<no_init<uint8_t>> read_buf;
+ for (const auto & it : model.tensors_by_name) {
+ const std::string & base_name = it.first;
+ ggml_tensor * model_t = it.second;
+
+ if (tensor_meta_map.find(base_name + ".loraA") == tensor_meta_map.end() ||
+ tensor_meta_map.find(base_name + ".loraB") == tensor_meta_map.end()) {
+ continue;
+ }
+
+ tensor_meta & metaA = tensor_meta_map.at(base_name + ".loraA");
+ tensor_meta & metaB = tensor_meta_map.at(base_name + ".loraB");
+
+ ggml_init_params lora_init_params = {
+ /* .mem_size */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
+ /* .mem_buffer */ nullptr,
+ /* .no_alloc */ true,
+ };
+ ggml_context * lora_ctx = ggml_init(lora_init_params);
+ if (lora_ctx == nullptr) {
+ LLAMA_LOG_ERROR("%s: error: failed to initialize lora context\n", __func__);
+ ggml_backend_free(backend_cpu);
+ return 1;
+ }
+
+ // create tensors
+ ggml_tensor * loraA = ggml_new_tensor_2d(lora_ctx, metaA.type, metaA.ne[0], metaA.ne[1]);
+ ggml_tensor * loraB = ggml_new_tensor_2d(lora_ctx, metaB.type, metaB.ne[0], metaB.ne[1]);
+ ggml_set_name(loraA, metaA.name.c_str());
+ ggml_set_name(loraB, metaB.name.c_str());
+
+ ggml_tensor * base_t;
+ if (ml) {
+ if (!ml->get_tensor_meta(base_name.c_str())) {
+ LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
+ return 1;
+ }
+ base_t = ggml_dup_tensor(lora_ctx, ml->get_tensor_meta(base_name.c_str()));
+ } else {
+ base_t = ggml_dup_tensor(lora_ctx, model_t);
+ }
+ ggml_set_name(base_t, base_name.c_str());
+
+ // allocate in backend buffer
+ ggml_backend_buffer_t lora_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type());
+ if (lora_buf == nullptr) {
+ LLAMA_LOG_ERROR("%s: error: failed to allocate lora tensors\n", __func__);
+ return 1;
+ }
+
+ // load tensor data
+ auto load_tensor = [&read_buf, &fin](const tensor_meta & tensor_meta, ggml_tensor * tensor) {
+ read_buf.resize(ggml_nbytes(tensor));
+ fin.seek(tensor_meta.offset, SEEK_SET);
+ fin.read_raw(read_buf.data(), ggml_nbytes(tensor));
+ ggml_backend_tensor_set(tensor, read_buf.data(), 0, read_buf.size());
+ };
+ load_tensor(metaA, loraA);
+ load_tensor(metaB, loraB);
+
+ // load base model tensor data
+ if (ml) {
+ ml->load_data_for(base_t);
+ } else {
+ ggml_backend_tensor_copy(model_t, base_t);
+ }
+
+ if (ggml_is_quantized(base_t->type) && !warned) {
+ LLAMA_LOG_WARN("%s: warning: using a lora adapter with a quantized model may result in poor quality, "
+ "use a f16 or f32 base model with --lora-base\n", __func__);
+ warned = true;
+ }
+
+ if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) {
+ LLAMA_LOG_ERROR("%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
+ " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]);
+ ggml_free(lora_ctx);
+ ggml_backend_buffer_free(lora_buf);
+ ggml_backend_free(backend_cpu);
+ return 1;
+ }
+
+ auto build_lora_graph = [&]() {
+ // w = w + BA*s
+ ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB);
+ ggml_set_name(BA, "BA");
+
+ if (scaling != 1.0f) {
+ BA = ggml_scale(lora_ctx, BA, scaling);
+ ggml_set_name(BA, "BA_scaled");
+ }
+
+ ggml_tensor * r;
+ r = ggml_add_inplace(lora_ctx, base_t, BA);
+ ggml_set_name(r, "r_add");
+
+ if (base_t->type != model_t->type) {
+ // convert the result to the model type
+ r = ggml_cast(lora_ctx, r, model_t->type);
+ ggml_set_name(r, "r_cast");
+ }
+
+ return r;
+ };
+
+ ggml_cgraph * gf = ggml_new_graph(lora_ctx);
+ ggml_tensor * r = build_lora_graph();
+ ggml_build_forward_expand(gf, r);
+
+ ggml_backend_buffer_t graph_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type());
+ if (graph_buf == nullptr) {
+ LLAMA_LOG_ERROR("%s: error: failed to allocate graph tensors\n", __func__);
+ ggml_free(lora_ctx);
+ ggml_backend_buffer_free(lora_buf);
+ ggml_backend_free(backend_cpu);
+ return 1;
+ }
+
+ ggml_backend_graph_compute(backend_cpu, gf);
+
+ ggml_backend_tensor_set(model_t, r->data, 0, ggml_nbytes(r));
+
+#if 0
+ // TODO: use scheduler with fallback to CPU for less copies between CPU and GPU
+ //ggml_backend_sched_t sched = ggml_backend_sched_new(backends.data(), backends.size(), GGML_DEFAULT_GRAPH_SIZE);
+
+ // sched compute
+ ggml_build_forward_expand(gf, build_graph());
+ ggml_backend_sched_init_measure(sched, gf);
+
+ // create the graph again, since the previous one was destroyed by the measure
+ ggml_graph_clear(gf);
+ ggml_build_forward_expand(gf, build_graph());
+ ggml_backend_sched_graph_compute(sched, gf);
+ ggml_backend_sched_free(sched);
+#endif
+
+ ggml_backend_buffer_free(lora_buf);
+ ggml_backend_buffer_free(graph_buf);
+ ggml_free(lora_ctx);
+
+ n_tensors++;
+ if (n_tensors % 4 == 0) {
+ LLAMA_LOG_INFO(".");
+ }
+ }
+
+ ggml_backend_free(backend_cpu);
+
+ const int64_t t_lora_us = ggml_time_us() - t_start_lora_us;
+ LLAMA_LOG_INFO(" done (%.2f ms)\n", t_lora_us / 1000.0);
+
+ return 0;
+}
+
+int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, float scale, const char * path_base_model, int32_t n_threads) {
+ try {
+ return llama_apply_lora_from_file_internal(*model, path_lora, scale, path_base_model, n_threads);
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
+ return 1;
+ }
+}
\ No newline at end of file
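For intuition, the graph assembled above computes a plain low-rank update per tensor, w = w + (B·A)·s. A minimal Go sketch of that math on dense slices (hypothetical shapes and helper name; the real routine runs this through ggml on the CPU backend):

// applyLoRA adds scale*(B*A) to base in place. base is rows x cols,
// a is rank x cols, b is rows x rank, all row-major. Illustrative only.
package main

import "fmt"

func applyLoRA(base []float64, rows, cols int, a, b []float64, rank int, scale float64) {
	for i := 0; i < rows; i++ {
		for j := 0; j < cols; j++ {
			var delta float64
			for r := 0; r < rank; r++ {
				delta += b[i*rank+r] * a[r*cols+j]
			}
			base[i*cols+j] += scale * delta
		}
	}
}

func main() {
	base := []float64{1, 0, 0, 1} // 2x2 identity
	a := []float64{1, 1}          // 1x2 (rank 1)
	b := []float64{1, 2}          // 2x1
	applyLoRA(base, 2, 2, a, b, 1, 0.5)
	fmt.Println(base) // [1.5 0.5 1 2]
}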

View File

@@ -385,8 +385,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 	filteredEnv := []string{}
 	for _, ev := range s.cmd.Env {
 		if strings.HasPrefix(ev, "CUDA_") ||
+			strings.HasPrefix(ev, "ROCR_") ||
 			strings.HasPrefix(ev, "ROCM_") ||
 			strings.HasPrefix(ev, "HIP_") ||
+			strings.HasPrefix(ev, "GPU_") ||
 			strings.HasPrefix(ev, "HSA_") ||
 			strings.HasPrefix(ev, "GGML_") ||
 			strings.HasPrefix(ev, "PATH=") ||
@@ -415,7 +417,17 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 	// reap subprocess when it exits
 	go func() {
-		s.done <- s.cmd.Wait()
+		err := s.cmd.Wait()
+		// Favor a more detailed message over the process exit status
+		if err != nil && s.status != nil && s.status.LastErrMsg != "" {
+			slog.Debug("llama runner terminated", "error", err)
+			if strings.Contains(s.status.LastErrMsg, "unknown model") {
+				s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
+			}
+			s.done <- fmt.Errorf(s.status.LastErrMsg)
+		} else {
+			s.done <- err
+		}
 	}()

 	return s, nil
@@ -578,14 +590,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 			slog.Warn("client connection closed before server finished loading, aborting load")
 			return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
 		case err := <-s.done:
-			msg := ""
-			if s.status != nil && s.status.LastErrMsg != "" {
-				msg = s.status.LastErrMsg
-			}
-			if strings.Contains(msg, "unknown model") {
-				return fmt.Errorf("this model is not supported by your version of Ollama. You may need to upgrade")
-			}
-			return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
+			return fmt.Errorf("llama runner process has terminated: %w", err)
 		default:
 		}

 		if time.Now().After(stallTimer) {
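The reworked reaper goroutine and WaitUntilRunning above prefer the runner's last recorded error message over the bare exit status. A condensed sketch of that selection logic with stand-in types (not the real llmServer):

// terminationError picks the more descriptive of the two failure signals,
// mirroring the selection added above (stand-in types, sketch only).
package main

import (
	"errors"
	"fmt"
)

type runnerStatus struct{ LastErrMsg string }

func terminationError(waitErr error, st *runnerStatus) error {
	if waitErr != nil && st != nil && st.LastErrMsg != "" {
		return errors.New(st.LastErrMsg)
	}
	return waitErr
}

func main() {
	exitErr := errors.New("exit status 2")
	fmt.Println(terminationError(exitErr, &runnerStatus{LastErrMsg: "unknown model architecture"}))
	fmt.Println(terminationError(exitErr, &runnerStatus{}))
}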

View File

@@ -7,6 +7,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"log/slog"
 	"math/rand"
 	"net/http"
 	"strings"
@@ -29,8 +30,9 @@ type ErrorResponse struct {
 }

 type Message struct {
 	Role      string     `json:"role"`
 	Content   any        `json:"content"`
+	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
 }

 type Choice struct {
@@ -78,6 +80,7 @@ type ChatCompletionRequest struct {
 	PresencePenalty  *float64        `json:"presence_penalty_penalty"`
 	TopP             *float64        `json:"top_p"`
 	ResponseFormat   *ResponseFormat `json:"response_format"`
+	Tools            []api.Tool      `json:"tools"`
 }

 type ChatCompletion struct {
@@ -111,6 +114,7 @@ type CompletionRequest struct {
 	Stream      bool     `json:"stream"`
 	Temperature *float32 `json:"temperature"`
 	TopP        float32  `json:"top_p"`
+	Suffix      string   `json:"suffix"`
 }

 type Completion struct {
@@ -132,6 +136,15 @@ type CompletionChunk struct {
 	SystemFingerprint string `json:"system_fingerprint"`
 }

+type ToolCall struct {
+	ID       string `json:"id"`
+	Type     string `json:"type"`
+	Function struct {
+		Name      string `json:"name"`
+		Arguments string `json:"arguments"`
+	} `json:"function"`
+}
+
 type Model struct {
 	Id      string `json:"id"`
 	Object  string `json:"object"`
@@ -170,7 +183,31 @@ func NewError(code int, message string) ErrorResponse {
 	return ErrorResponse{Error{Type: etype, Message: message}}
 }

+func toolCallId() string {
+	const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
+	b := make([]byte, 8)
+	for i := range b {
+		b[i] = letterBytes[rand.Intn(len(letterBytes))]
+	}
+	return "call_" + strings.ToLower(string(b))
+}
+
 func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
+	toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
+	for i, tc := range r.Message.ToolCalls {
+		toolCalls[i].ID = toolCallId()
+		toolCalls[i].Type = "function"
+		toolCalls[i].Function.Name = tc.Function.Name
+
+		args, err := json.Marshal(tc.Function.Arguments)
+		if err != nil {
+			slog.Error("could not marshall function arguments to json", "error", err)
+			continue
+		}
+
+		toolCalls[i].Function.Arguments = string(args)
+	}
+
 	return ChatCompletion{
 		Id:                id,
 		Object:            "chat.completion",
@@ -179,7 +216,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 		SystemFingerprint: "fp_ollama",
 		Choices: []Choice{{
 			Index:   0,
-			Message: Message{Role: r.Message.Role, Content: r.Message.Content},
+			Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
 			FinishReason: func(reason string) *string {
 				if len(reason) > 0 {
 					return &reason
@@ -188,7 +225,6 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 			}(r.DoneReason),
 		}},
 		Usage: Usage{
-			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
 			PromptTokens:     r.PromptEvalCount,
 			CompletionTokens: r.EvalCount,
 			TotalTokens:      r.PromptEvalCount + r.EvalCount,
@@ -234,7 +270,6 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
 			}(r.DoneReason),
 		}},
 		Usage: Usage{
-			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
 			PromptTokens:     r.PromptEvalCount,
 			CompletionTokens: r.EvalCount,
 			TotalTokens:      r.PromptEvalCount + r.EvalCount,
@@ -316,7 +351,6 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 		case string:
 			messages = append(messages, api.Message{Role: msg.Role, Content: content})
 		case []any:
-			message := api.Message{Role: msg.Role}
 			for _, c := range content {
 				data, ok := c.(map[string]any)
 				if !ok {
@@ -328,7 +362,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 					if !ok {
 						return nil, fmt.Errorf("invalid message format")
 					}
-					message.Content = text
+					messages = append(messages, api.Message{Role: msg.Role, Content: text})
 				case "image_url":
 					var url string
 					if urlMap, ok := data["image_url"].(map[string]any); ok {
@@ -360,14 +394,26 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 					if err != nil {
 						return nil, fmt.Errorf("invalid message format")
 					}
-					message.Images = append(message.Images, img)
+
+					messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
 				default:
 					return nil, fmt.Errorf("invalid message format")
 				}
 			}
-			messages = append(messages, message)
 		default:
-			return nil, fmt.Errorf("invalid message content type: %T", content)
+			if msg.ToolCalls == nil {
+				return nil, fmt.Errorf("invalid message content type: %T", content)
+			}
+
+			toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
+			for i, tc := range msg.ToolCalls {
+				toolCalls[i].Function.Name = tc.Function.Name
+
+				err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
+				if err != nil {
+					return nil, fmt.Errorf("invalid tool call arguments")
+				}
+			}
+			messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
 		}
 	}
@@ -425,6 +471,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 		Format:   format,
 		Options:  options,
 		Stream:   &r.Stream,
+		Tools:    r.Tools,
 	}, nil
 }
@@ -475,6 +522,7 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
 		Prompt:  r.Prompt,
 		Options: options,
 		Stream:  &r.Stream,
+		Suffix:  r.Suffix,
 	}, nil
 }
@@ -829,6 +877,7 @@ func ChatMiddleware() gin.HandlerFunc {
 		chatReq, err := fromChatRequest(req)
 		if err != nil {
 			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
 		}

 		if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
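Most of the conversion work above is translating between OpenAI-style tool calls, whose arguments travel as a JSON string, and a decoded map form. A small standalone sketch of that round trip (it does not use the real package types):

// Decode string-encoded tool-call arguments into a map and marshal them
// back, mirroring fromChatRequest/toChatCompletion above (sketch only).
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Incoming OpenAI tool call: arguments arrive as a JSON string.
	raw := `{"location": "Paris, France", "format": "celsius"}`

	var args map[string]any
	if err := json.Unmarshal([]byte(raw), &args); err != nil {
		panic(err)
	}
	fmt.Println(args["location"]) // Paris, France

	// Outgoing response: arguments are marshalled back into a string.
	out, err := json.Marshal(args)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
}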

View File

@@ -20,108 +20,59 @@ const prefix = `data:image/jpeg;base64,`
 const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
 const imageURL = prefix + image

-func TestMiddlewareRequests(t *testing.T) {
+func prepareRequest(req *http.Request, body any) {
+	bodyBytes, _ := json.Marshal(body)
+	req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+	req.Header.Set("Content-Type", "application/json")
+}
+
+func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
+	return func(c *gin.Context) {
+		bodyBytes, _ := io.ReadAll(c.Request.Body)
+		c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+		err := json.Unmarshal(bodyBytes, capturedRequest)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
+		}
+		c.Next()
+	}
+}
+
+func TestChatMiddleware(t *testing.T) {
 	type testCase struct {
 		Name     string
-		Method   string
-		Path     string
-		Handler  func() gin.HandlerFunc
 		Setup    func(t *testing.T, req *http.Request)
-		Expected func(t *testing.T, req *http.Request)
+		Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
 	}

-	var capturedRequest *http.Request
-
-	captureRequestMiddleware := func() gin.HandlerFunc {
-		return func(c *gin.Context) {
-			bodyBytes, _ := io.ReadAll(c.Request.Body)
-			c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-			capturedRequest = c.Request
-			c.Next()
-		}
-	}
+	var capturedRequest *api.ChatRequest

 	testCases := []testCase{
 		{
 			Name: "chat handler",
-			Method:  http.MethodPost,
-			Path:    "/api/chat",
-			Handler: ChatMiddleware,
 			Setup: func(t *testing.T, req *http.Request) {
 				body := ChatCompletionRequest{
 					Model:    "test-model",
 					Messages: []Message{{Role: "user", Content: "Hello"}},
 				}
-
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
+				prepareRequest(req, body)
 			},
-			Expected: func(t *testing.T, req *http.Request) {
-				var chatReq api.ChatRequest
-				if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
-					t.Fatal(err)
+			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusOK {
+					t.Fatalf("expected 200, got %d", resp.Code)
 				}

-				if chatReq.Messages[0].Role != "user" {
-					t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
+				if req.Messages[0].Role != "user" {
+					t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
 				}

-				if chatReq.Messages[0].Content != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
+				if req.Messages[0].Content != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
 				}
 			},
 		},
 		{
-			Name:    "completions handler",
-			Method:  http.MethodPost,
-			Path:    "/api/generate",
-			Handler: CompletionsMiddleware,
-			Setup: func(t *testing.T, req *http.Request) {
-				temp := float32(0.8)
-				body := CompletionRequest{
-					Model:       "test-model",
-					Prompt:      "Hello",
-					Temperature: &temp,
-					Stop:        []string{"\n", "stop"},
-				}
-
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
-			},
-			Expected: func(t *testing.T, req *http.Request) {
-				var genReq api.GenerateRequest
-				if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
-					t.Fatal(err)
-				}
-
-				if genReq.Prompt != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
-				}
-
-				if genReq.Options["temperature"] != 1.6 {
-					t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
-				}
-
-				stopTokens, ok := genReq.Options["stop"].([]any)
-				if !ok {
-					t.Fatalf("expected stop tokens to be a list")
-				}
-
-				if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
-					t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
-				}
-			},
-		},
-		{
 			Name: "chat handler with image content",
-			Method:  http.MethodPost,
-			Path:    "/api/chat",
-			Handler: ChatMiddleware,
 			Setup: func(t *testing.T, req *http.Request) {
 				body := ChatCompletionRequest{
 					Model: "test-model",
@@ -134,87 +85,254 @@ func TestMiddlewareRequests(t *testing.T) {
 					},
 				},
 			}
-
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
+				prepareRequest(req, body)
 			},
-			Expected: func(t *testing.T, req *http.Request) {
-				var chatReq api.ChatRequest
-				if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
-					t.Fatal(err)
+			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusOK {
+					t.Fatalf("expected 200, got %d", resp.Code)
 				}

-				if chatReq.Messages[0].Role != "user" {
-					t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
+				if req.Messages[0].Role != "user" {
+					t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
 				}

-				if chatReq.Messages[0].Content != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
+				if req.Messages[0].Content != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
 				}

 				img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
-				if !bytes.Equal(chatReq.Messages[0].Images[0], img) {
-					t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0])
+				if req.Messages[1].Role != "user" {
+					t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
+				}
+
+				if !bytes.Equal(req.Messages[1].Images[0], img) {
+					t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
 				}
 			},
 		},
 		{
-			Name:    "embed handler single input",
-			Method:  http.MethodPost,
-			Path:    "/api/embed",
-			Handler: EmbeddingsMiddleware,
+			Name: "chat handler with tools",
+			Setup: func(t *testing.T, req *http.Request) {
+				body := ChatCompletionRequest{
+					Model: "test-model",
Messages: []Message{
{Role: "user", Content: "What's the weather like in Paris Today?"},
{Role: "assistant", ToolCalls: []ToolCall{{
ID: "id",
Type: "function",
Function: struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}{
Name: "get_current_weather",
Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
},
}}},
},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != 200 {
t.Fatalf("expected 200, got %d", resp.Code)
}
if req.Messages[0].Content != "What's the weather like in Paris Today?" {
t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
}
if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
}
if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
}
},
},
{
Name: "chat handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: 2}},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid message content type") {
t.Fatalf("error was not forwarded")
}
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/chat", endpoint)
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
tc.Setup(t, req)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp)
capturedRequest = nil
})
}
}
func TestCompletionsMiddleware(t *testing.T) {
type testCase struct {
Name string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
}
var capturedRequest *api.GenerateRequest
testCases := []testCase{
{
Name: "completions handler",
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
Stop: []string{"\n", "stop"},
Suffix: "suffix",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
if req.Prompt != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Prompt)
}
if req.Options["temperature"] != 1.6 {
t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
}
stopTokens, ok := req.Options["stop"].([]any)
if !ok {
t.Fatalf("expected stop tokens to be a list")
}
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
}
if req.Suffix != "suffix" {
t.Fatalf("expected 'suffix', got %s", req.Suffix)
}
},
},
{
Name: "completions handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: nil,
Stop: []int{1, 2},
Suffix: "suffix",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
t.Fatalf("error was not forwarded")
}
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/generate", endpoint)
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
tc.Setup(t, req)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp)
capturedRequest = nil
})
}
}
func TestEmbeddingsMiddleware(t *testing.T) {
type testCase struct {
Name string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
}
var capturedRequest *api.EmbedRequest
testCases := []testCase{
{
Name: "embed handler single input",
 			Setup: func(t *testing.T, req *http.Request) {
 				body := EmbedRequest{
 					Input: "Hello",
 					Model: "test-model",
 				}
-
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
+				prepareRequest(req, body)
 			},
-			Expected: func(t *testing.T, req *http.Request) {
-				var embedReq api.EmbedRequest
-				if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
-					t.Fatal(err)
-				}
-
-				if embedReq.Input != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", embedReq.Input)
-				}
-
-				if embedReq.Model != "test-model" {
-					t.Fatalf("expected 'test-model', got %s", embedReq.Model)
+			Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
+				if req.Input != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", req.Input)
+				}
+
+				if req.Model != "test-model" {
+					t.Fatalf("expected 'test-model', got %s", req.Model)
 				}
 			},
 		},
 		{
 			Name: "embed handler batch input",
-			Method:  http.MethodPost,
-			Path:    "/api/embed",
-			Handler: EmbeddingsMiddleware,
 			Setup: func(t *testing.T, req *http.Request) {
 				body := EmbedRequest{
 					Input: []string{"Hello", "World"},
 					Model: "test-model",
 				}
-
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
+				prepareRequest(req, body)
 			},
-			Expected: func(t *testing.T, req *http.Request) {
-				var embedReq api.EmbedRequest
-				if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
-					t.Fatal(err)
-				}
-
-				input, ok := embedReq.Input.([]any)
+			Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
+				input, ok := req.Input.([]any)
 				if !ok {
 					t.Fatalf("expected input to be a list")
@@ -228,36 +346,52 @@ func TestMiddlewareRequests(t *testing.T) {
 					t.Fatalf("expected 'World', got %s", input[1])
 				}

-				if embedReq.Model != "test-model" {
-					t.Fatalf("expected 'test-model', got %s", embedReq.Model)
+				if req.Model != "test-model" {
+					t.Fatalf("expected 'test-model', got %s", req.Model)
 				}
 			},
 		},
+		{
+			Name: "embed handler error forwarding",
+			Setup: func(t *testing.T, req *http.Request) {
+				body := EmbedRequest{
+					Model: "test-model",
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusBadRequest {
+					t.Fatalf("expected 400, got %d", resp.Code)
+				}
+
+				if !strings.Contains(resp.Body.String(), "invalid input") {
+					t.Fatalf("error was not forwarded")
+				}
+			},
+		},
 	}

-	gin.SetMode(gin.TestMode)
-	router := gin.New()
-
 	endpoint := func(c *gin.Context) {
 		c.Status(http.StatusOK)
 	}

+	gin.SetMode(gin.TestMode)
+	router := gin.New()
+	router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
+	router.Handle(http.MethodPost, "/api/embed", endpoint)
+
 	for _, tc := range testCases {
 		t.Run(tc.Name, func(t *testing.T) {
-			router = gin.New()
-			router.Use(captureRequestMiddleware())
-			router.Use(tc.Handler())
-			router.Handle(tc.Method, tc.Path, endpoint)
-			req, _ := http.NewRequest(tc.Method, tc.Path, nil)
+			req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)

-			if tc.Setup != nil {
-				tc.Setup(t, req)
-			}
+			tc.Setup(t, req)

 			resp := httptest.NewRecorder()
 			router.ServeHTTP(resp, req)

-			tc.Expected(t, capturedRequest)
+			tc.Expected(t, capturedRequest, resp)
+
+			capturedRequest = nil
 		})
 	}
 }
@@ -275,36 +409,6 @@ func TestMiddlewareResponses(t *testing.T) {
 	}

 	testCases := []testCase{
-		{
-			Name:     "completions handler error forwarding",
-			Method:   http.MethodPost,
-			Path:     "/api/generate",
-			TestPath: "/api/generate",
-			Handler:  CompletionsMiddleware,
-			Endpoint: func(c *gin.Context) {
-				c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
-			},
-			Setup: func(t *testing.T, req *http.Request) {
-				body := CompletionRequest{
-					Model:  "test-model",
-					Prompt: "Hello",
-				}
-
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
-			},
-			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
-				if resp.Code != http.StatusBadRequest {
-					t.Fatalf("expected 400, got %d", resp.Code)
-				}
-
-				if !strings.Contains(resp.Body.String(), `"invalid request"`) {
-					t.Fatalf("error was not forwarded")
-				}
-			},
-		},
 		{
 			Name:     "list handler",
 			Method:   http.MethodGet,
@@ -321,8 +425,6 @@ func TestMiddlewareResponses(t *testing.T) {
 			})
 		},
 		Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
-			assert.Equal(t, http.StatusOK, resp.Code)
-
 			var listResp ListCompletion
 			if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
 				t.Fatal(err)
@@ -386,6 +488,8 @@ func TestMiddlewareResponses(t *testing.T) {
 			resp := httptest.NewRecorder()
 			router.ServeHTTP(resp, req)

+			assert.Equal(t, http.StatusOK, resp.Code)
+
 			tc.Expected(t, resp)
 		})
 	}

View File

@@ -8,6 +8,7 @@ import (
 	"io"
 	"log/slog"
 	"math"
+	"math/rand/v2"
 	"net/http"
 	"net/url"
 	"os"
@@ -141,6 +142,32 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *regis
 	b.err = b.run(ctx, requestURL, opts)
 }

+func newBackoff(maxBackoff time.Duration) func(ctx context.Context) error {
+	var n int
+	return func(ctx context.Context) error {
+		if ctx.Err() != nil {
+			return ctx.Err()
+		}
+
+		n++
+
+		// n^2 backoff timer is a little smoother than the
+		// common choice of 2^n.
+		d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
+		// Randomize the delay between 0.5-1.5 x msec, in order
+		// to prevent accidental "thundering herd" problems.
+		d = time.Duration(float64(d) * (rand.Float64() + 0.5))
+		t := time.NewTimer(d)
+		defer t.Stop()
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-t.C:
+			return nil
+		}
+	}
+}
+
 func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
 	defer blobDownloadManager.Delete(b.Digest)
 	ctx, b.CancelFunc = context.WithCancel(ctx)
@@ -153,6 +180,52 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 	_ = file.Truncate(b.Total)

+	directURL, err := func() (*url.URL, error) {
+		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+		defer cancel()
+
+		backoff := newBackoff(10 * time.Second)
+		for {
+			// shallow clone opts to be used in the closure
+			// without affecting the outer opts.
+			newOpts := new(registryOptions)
+			*newOpts = *opts
+
+			newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
+				if len(via) > 10 {
+					return errors.New("maxium redirects exceeded (10) for directURL")
+				}
+
+				// if the hostname is the same, allow the redirect
+				if req.URL.Hostname() == requestURL.Hostname() {
+					return nil
+				}
+
+				// stop at the first redirect that is not
+				// the same hostname as the original
+				// request.
+				return http.ErrUseLastResponse
+			}
+
+			resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, nil, nil, newOpts)
+			if err != nil {
+				slog.Warn("failed to get direct URL; backing off and retrying", "err", err)
+				if err := backoff(ctx); err != nil {
+					return nil, err
+				}
+				continue
+			}
+			defer resp.Body.Close()
+
+			if resp.StatusCode != http.StatusTemporaryRedirect {
+				return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
+			}
+			return resp.Location()
+		}
+	}()
+	if err != nil {
+		return err
+	}
+
 	g, inner := errgroup.WithContext(ctx)
 	g.SetLimit(numDownloadParts)
 	for i := range b.Parts {
@@ -165,7 +238,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 			var err error
 			for try := 0; try < maxRetries; try++ {
 				w := io.NewOffsetWriter(file, part.StartsAt())
-				err = b.downloadChunk(inner, requestURL, w, part, opts)
+				err = b.downloadChunk(inner, directURL, w, part, opts)
 				switch {
 				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
 					// return immediately if the context is canceled or the device is out of space
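The direct-URL retry loop above waits on a small jittered backoff between attempts. A self-contained version of the same n²-plus-jitter idea, runnable on its own (Go 1.22+ for math/rand/v2):

// Quadratic backoff with 0.5-1.5x jitter, capped at maxBackoff; the jitter
// spreads out retries so many clients don't hammer the registry in lockstep.
package main

import (
	"context"
	"fmt"
	"math/rand/v2"
	"time"
)

func newBackoff(maxBackoff time.Duration) func(ctx context.Context) error {
	var n int
	return func(ctx context.Context) error {
		n++
		d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
		d = time.Duration(float64(d) * (rand.Float64() + 0.5))
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(d):
			return nil
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	backoff := newBackoff(200 * time.Millisecond)
	for attempt := 1; attempt <= 3; attempt++ {
		fmt.Println("attempt", attempt, "failed, backing off")
		if err := backoff(ctx); err != nil {
			fmt.Println("giving up:", err)
			return
		}
	}
}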

View File

@@ -54,6 +54,8 @@ type registryOptions struct {
 	Username string
 	Password string
 	Token    string
+
+	CheckRedirect func(req *http.Request, via []*http.Request) error
 }

 type Model struct {
@@ -492,6 +494,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 			layers = append(layers, baseLayer.Layer)
 		}
 	case "license", "template", "system":
+		if c.Name == "template" {
+			if _, err := template.Parse(c.Args); err != nil {
+				return fmt.Errorf("%w: %s", errBadTemplate, err)
+			}
+		}
+
 		if c.Name != "license" {
 			// replace
 			layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
@@ -1125,7 +1133,9 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
 		req.ContentLength = contentLength
 	}

-	resp, err := http.DefaultClient.Do(req)
+	resp, err := (&http.Client{
+		CheckRedirect: regOpts.CheckRedirect,
+	}).Do(req)
 	if err != nil {
 		return nil, err
 	}
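The CheckRedirect hook added above is how the client stops at the first redirect that leaves the original host and reads its Location rather than following it. A minimal illustration using only net/http (the URLs are placeholders):

// Follow same-host redirects, but surface the first cross-host redirect
// response so its Location header can be reused directly (sketch only).
package main

import (
	"fmt"
	"net/http"
)

func main() {
	const originHost = "registry.example.com"

	client := &http.Client{
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			if len(via) > 10 {
				return fmt.Errorf("maximum redirects exceeded")
			}
			if req.URL.Hostname() == originHost {
				return nil // same host: keep following
			}
			return http.ErrUseLastResponse // different host: stop here
		},
	}

	resp, err := client.Get("https://registry.example.com/v2/library/blob")
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusTemporaryRedirect {
		if loc, err := resp.Location(); err == nil {
			fmt.Println("direct URL:", loc)
		}
	}
}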

View File

@@ -263,13 +263,27 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
 		if t, err := template.Named(s); err != nil {
 			slog.Debug("template detection", "error", err)
 		} else {
-			tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
+			layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
 			if err != nil {
 				return nil, err
 			}

-			tmpl.status = fmt.Sprintf("using autodetected template %s", t.Name)
-			layers = append(layers, &layerGGML{tmpl, nil})
+			layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
+			layers = append(layers, &layerGGML{layer, nil})
+
+			if t.Parameters != nil {
+				var b bytes.Buffer
+				if err := json.NewEncoder(&b).Encode(t.Parameters); err != nil {
+					return nil, err
+				}
+
+				layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
+				if err != nil {
+					return nil, err
+				}
+
+				layers = append(layers, &layerGGML{layer, nil})
+			}
 		}
 	}
 }
@@ -311,12 +325,14 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 	}

 	var b bytes.Buffer
-	if err := tmpl.Execute(&b, map[string][]map[string]any{
+	if err := tmpl.Execute(&b, map[string][]api.ToolCall{
 		"ToolCalls": {
 			{
-				"Function": map[string]any{
-					"Name":      "@@name@@",
-					"Arguments": "@@arguments@@",
+				Function: api.ToolCallFunction{
+					Name: "@@name@@",
+					Arguments: api.ToolCallFunctionArguments{
+						"@@argument@@": 1,
+					},
 				},
 			},
 		},
@@ -324,7 +340,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 		return nil, false
 	}

-	var kv map[string]string
+	var kv map[string]any
 	// execute the subtree with placeholders to identify the keys
 	// trim any commands that might exist in the template
 	if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
@@ -334,17 +350,23 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 	// find the keys that correspond to the name and arguments fields
 	var name, arguments string
 	for k, v := range kv {
-		switch v {
-		case "@@name@@":
+		switch v.(type) {
+		case string:
 			name = k
-		case "@@arguments@@":
+		case map[string]any:
 			arguments = k
 		}
 	}

+	if name == "" || arguments == "" {
+		return nil, false
+	}
+
 	var objs []map[string]any
 	for offset := 0; offset < len(s); {
-		if err := json.NewDecoder(strings.NewReader(s[offset:])).Decode(&objs); errors.Is(err, io.EOF) {
+		var obj map[string]any
+		decoder := json.NewDecoder(strings.NewReader(s[offset:]))
+		if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
 			break
 		} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
 			// skip over any syntax errors
@@ -353,26 +375,44 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 			// skip over any unmarshalable types
 			offset += int(unmarshalType.Offset)
 		} else if err != nil {
+			slog.Error("parseToolCalls", "error", err)
 			return nil, false
 		} else {
-			// break when an object is decoded
-			break
+			offset += int(decoder.InputOffset())
+
+			// collect all nested objects
+			var collect func(any) []map[string]any
+			collect = func(obj any) (all []map[string]any) {
+				switch o := obj.(type) {
+				case map[string]any:
+					all = append(all, o)
+					for _, v := range o {
+						all = append(all, collect(v)...)
+					}
+				case []any:
+					for _, v := range o {
+						all = append(all, collect(v)...)
+					}
+				}
+
+				return all
+			}
+			objs = append(objs, collect(obj)...)
 		}
 	}

 	var toolCalls []api.ToolCall
 	for _, kv := range objs {
-		var call api.ToolCall
-		for k, v := range kv {
-			switch k {
-			case name:
-				call.Function.Name = v.(string)
-			case arguments:
-				call.Function.Arguments = v.(map[string]any)
-			}
+		n, nok := kv[name].(string)
+		a, aok := kv[arguments].(map[string]any)
+		if nok && aok {
+			toolCalls = append(toolCalls, api.ToolCall{
+				Function: api.ToolCallFunction{
+					Name:      n,
+					Arguments: a,
+				},
+			})
 		}
-
-		toolCalls = append(toolCalls, call)
 	}

 	return toolCalls, len(toolCalls) > 0
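The parsing change above decodes consecutive JSON objects out of the model output and collects every nested object, so tool calls wrapped in an outer envelope such as {"tool_calls": [...]} are still found. A reduced sketch of that traversal:

// Decode objects from a stream of model text and flatten all nested maps,
// then pick out the ones that look like tool calls (sketch only).
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"
)

func collect(v any) (all []map[string]any) {
	switch o := v.(type) {
	case map[string]any:
		all = append(all, o)
		for _, vv := range o {
			all = append(all, collect(vv)...)
		}
	case []any:
		for _, vv := range o {
			all = append(all, collect(vv)...)
		}
	}
	return all
}

func main() {
	s := `{"tool_calls": [{"name": "get_current_weather", "arguments": {"location": "Paris"}}]}`

	var objs []map[string]any
	dec := json.NewDecoder(strings.NewReader(s))
	for {
		var obj map[string]any
		if err := dec.Decode(&obj); errors.Is(err, io.EOF) {
			break
		} else if err != nil {
			break
		}
		objs = append(objs, collect(obj)...)
	}

	for _, o := range objs {
		if name, ok := o["name"].(string); ok {
			fmt.Println("tool call:", name, o["arguments"])
		}
	}
}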

View File

@@ -115,11 +115,6 @@ func TestExtractFromZipFile(t *testing.T) {
 	}
 }

-type function struct {
-	Name      string         `json:"name"`
-	Arguments map[string]any `json:"arguments"`
-}
-
 func readFile(t *testing.T, base, name string) *bytes.Buffer {
 	t.Helper()
@@ -167,6 +162,11 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
 		{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
 		{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
 		{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
+		{"llama3-groq-tool-use", `<tool_call>
+{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
+{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
+</tool_call>`, true},
+		{"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
 	}

 	var tools []api.Tool
@@ -181,18 +181,18 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
 	calls := []api.ToolCall{
 		{
-			Function: function{
+			Function: api.ToolCallFunction{
 				Name: "get_current_weather",
-				Arguments: map[string]any{
+				Arguments: api.ToolCallFunctionArguments{
 					"format":   "fahrenheit",
 					"location": "San Francisco, CA",
 				},
 			},
 		},
 		{
-			Function: function{
+			Function: api.ToolCallFunction{
 				Name: "get_current_weather",
-				Arguments: map[string]any{
+				Arguments: api.ToolCallFunctionArguments{
 					"format":   "celsius",
 					"location": "Toronto, Canada",
 				},

View File

@@ -56,6 +56,7 @@ func init() {
 }

 var errRequired = errors.New("is required")
+var errBadTemplate = errors.New("template error")

 func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
 	opts := api.DefaultOptions()
@@ -275,11 +276,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		}

 		r.Response = sb.String()
-		if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
-			r.ToolCalls = toolCalls
-			r.Response = ""
-		}
-
 		c.JSON(http.StatusOK, r)
 		return
 	}
@@ -613,7 +609,9 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
 		defer cancel()

 		quantization := cmp.Or(r.Quantize, r.Quantization)
-		if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
+		if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); errors.Is(err, errBadTemplate) {
+			ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
+		} else if err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
@@ -1201,11 +1199,15 @@ func waitForStream(c *gin.Context, ch chan interface{}) {
 				return
 			}
 		case gin.H:
+			status, ok := r["status"].(int)
+			if !ok {
+				status = http.StatusInternalServerError
+			}
+
 			if errorMsg, ok := r["error"].(string); ok {
-				c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
+				c.JSON(status, gin.H{"error": errorMsg})
 				return
 			} else {
-				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
+				c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
 				return
 			}
 		default:
@@ -1295,7 +1297,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}

 	caps := []Capability{CapabilityCompletion}
-	if req.Tools != nil {
+	if len(req.Tools) > 0 {
 		caps = append(caps, CapabilityTools)
 	}
@@ -1390,9 +1392,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		}

 		resp.Message.Content = sb.String()
-		if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
-			resp.Message.ToolCalls = toolCalls
-			resp.Message.Content = ""
+
+		if len(req.Tools) > 0 {
+			if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
+				resp.Message.ToolCalls = toolCalls
+				resp.Message.Content = ""
+			}
 		}

 		c.JSON(http.StatusOK, resp)

View File

@@ -491,6 +491,42 @@ func TestCreateTemplateSystem(t *testing.T) {
 	if string(system) != "Say bye!" {
 		t.Errorf("expected \"Say bye!\", actual %s", system)
 	}
+
+	t.Run("incomplete template", func(t *testing.T) {
+		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name:      "test",
+			Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)),
+			Stream:    &stream,
+		})
+
+		if w.Code != http.StatusBadRequest {
+			t.Fatalf("expected status code 400, actual %d", w.Code)
+		}
+	})
+
+	t.Run("template with unclosed if", func(t *testing.T) {
+		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name:      "test",
+			Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)),
+			Stream:    &stream,
+		})
+
+		if w.Code != http.StatusBadRequest {
+			t.Fatalf("expected status code 400, actual %d", w.Code)
+		}
+	})
+
+	t.Run("template with undefined function", func(t *testing.T) {
+		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name:      "test",
+			Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)),
+			Stream:    &stream,
+		})
+
+		if w.Code != http.StatusBadRequest {
+			t.Fatalf("expected status code 400, actual %d", w.Code)
+		}
+	})
 }

 func TestCreateLicenses(t *testing.T) {
@@ -563,9 +599,10 @@ func TestCreateDetectTemplate(t *testing.T) {
 		}

 		checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+			filepath.Join(p, "blobs", "sha256-0d79f567714c62c048378f2107fb332dabee0135d080c302d884317da9433cc5"),
 			filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
 			filepath.Join(p, "blobs", "sha256-c608dc615584cd20d9d830363dabf8a4783ae5d34245c3d8c115edb3bc7b28e4"),
-			filepath.Join(p, "blobs", "sha256-f836ee110db21567f826332e4cedd746c06d10664fd5a9ea3659e3683a944510"),
+			filepath.Join(p, "blobs", "sha256-ea34c57ba5b78b740aafe2aeb74dc6507fc3ad14170b64c26a04fb9e36c88d75"),
 		})
 	})

View File

@@ -73,8 +73,8 @@ func TestGenerateChat(t *testing.T) {
 			getCpuFn:     gpu.GetCPUInfo,
 			reschedDelay: 250 * time.Millisecond,
 			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
-				// add 10ms delay to simulate loading
-				time.Sleep(10 * time.Millisecond)
+				// add small delay to simulate loading
+				time.Sleep(time.Millisecond)
 				req.successCh <- &runnerRef{
 					llama: &mock,
 				}
@@ -371,6 +371,8 @@ func TestGenerate(t *testing.T) {
 			getCpuFn:     gpu.GetCPUInfo,
 			reschedDelay: 250 * time.Millisecond,
 			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
+				// add small delay to simulate loading
+				time.Sleep(time.Millisecond)
 				req.successCh <- &runnerRef{
 					llama: &mock,
 				}

View File

@@ -132,6 +132,8 @@ func (s *Scheduler) processPending(ctx context.Context) {
 				if len(pending.model.ProjectorPaths) > 0 && numParallel != 1 {
 					numParallel = 1
 					slog.Warn("multimodal models don't support parallel requests yet")
+				} else if strings.Contains(pending.model.Config.ModelFamily, "bert") {
+					numParallel = runtime.NumCPU()
 				}

 				for {

View File

@@ -94,7 +94,7 @@ func TestLoad(t *testing.T) {
 	require.Len(t, s.expiredCh, 1)
 }

-type bundle struct {
+type reqBundle struct {
 	ctx     context.Context //nolint:containedctx
 	ctxDone func()
 	srv     *mockLlm
@@ -102,13 +102,13 @@ type reqBundle struct {
 	ggml *llm.GGML
 }

-func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
+func (scenario *reqBundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
 	return scenario.srv, nil
 }

-func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
-	scenario := &bundle{}
-	scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
+func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64, duration *api.Duration) *reqBundle {
+	b := &reqBundle{}
+	b.ctx, b.ctxDone = context.WithCancel(ctx)
 	t.Helper()

 	f, err := os.CreateTemp(t.TempDir(), modelName)
@@ -135,124 +135,154 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
fname := f.Name() fname := f.Name()
model := &Model{Name: modelName, ModelPath: fname} model := &Model{Name: modelName, ModelPath: fname}
scenario.ggml, err = llm.LoadModel(model.ModelPath, 0) b.ggml, err = llm.LoadModel(model.ModelPath, 0)
require.NoError(t, err) require.NoError(t, err)
scenario.req = &LlmRequest{ if duration == nil {
ctx: scenario.ctx, duration = &api.Duration{Duration: 5 * time.Millisecond}
}
b.req = &LlmRequest{
ctx: b.ctx,
model: model, model: model,
opts: api.DefaultOptions(), opts: api.DefaultOptions(),
sessionDuration: &api.Duration{Duration: 5 * time.Millisecond}, sessionDuration: duration,
successCh: make(chan *runnerRef, 1), successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1), errCh: make(chan error, 1),
} }
scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}} b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
return scenario return b
} }
func TestRequests(t *testing.T) { func getGpuFn() gpu.GpuInfoList {
ctx, done := context.WithTimeout(context.Background(), 10*time.Second) g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
func getCpuFn() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "cpu"}
g.TotalMemory = 32 * format.GigaByte
g.FreeMemory = 26 * format.GigaByte
return []gpu.GpuInfo{g}
}
func TestRequestsSameModelSameRequest(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer done() defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
scenario1b.req.model = scenario1a.req.model
scenario1b.ggml = scenario1a.ggml
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
// simple reload of same model
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
tmpModel := *scenario1a.req.model
scenario2a.req.model = &tmpModel
scenario2a.ggml = scenario1a.ggml
scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
// Multiple loaded models
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
scenario3c := newScenario(t, ctx, "ollama-model-4a", 30)
scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed
scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
s := InitScheduler(ctx) s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList { s.getGpuFn = getGpuFn
g := gpu.GpuInfo{Library: "metal"} s.getCpuFn = getCpuFn
g.TotalMemory = 24 * format.GigaByte a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
g.FreeMemory = 12 * format.GigaByte b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0})
return []gpu.GpuInfo{g} b.req.model = a.req.model
} b.ggml = a.ggml
s.getCpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "cpu"} s.newServerFn = a.newServer
g.TotalMemory = 32 * format.GigaByte slog.Info("a")
g.FreeMemory = 26 * format.GigaByte s.pendingReqCh <- a.req
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req
require.Len(t, s.pendingReqCh, 1) require.Len(t, s.pendingReqCh, 1)
s.Run(ctx) s.Run(ctx)
select { select {
case resp := <-scenario1a.req.successCh: case resp := <-a.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv) require.Equal(t, resp.llama, a.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario1a.req.errCh) require.Empty(t, a.req.errCh)
case err := <-scenario1a.req.errCh: case err := <-a.req.errCh:
t.Fatal(err.Error()) t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
// Same runner as first request due to not needing a reload // Same runner as first request due to not needing a reload
s.newServerFn = scenario1b.newServer s.newServerFn = b.newServer
slog.Info("scenario1b") slog.Info("b")
s.pendingReqCh <- scenario1b.req s.pendingReqCh <- b.req
select { select {
case resp := <-scenario1b.req.successCh: case resp := <-b.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv) require.Equal(t, resp.llama, a.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario1b.req.errCh) require.Empty(t, b.req.errCh)
case err := <-scenario1b.req.errCh: case err := <-b.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
}
func TestRequestsSimpleReloadSameModel(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond})
tmpModel := *a.req.model
b.req.model = &tmpModel
b.ggml = a.ggml
s.newServerFn = a.newServer
slog.Info("a")
s.pendingReqCh <- a.req
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
case resp := <-a.req.successCh:
require.Equal(t, resp.llama, a.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, a.req.errCh)
case err := <-a.req.errCh:
t.Fatal(err.Error()) t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
// Trigger a reload // Trigger a reload
s.newServerFn = scenario2a.newServer s.newServerFn = b.newServer
scenario2a.req.model.AdapterPaths = []string{"new"} b.req.model.AdapterPaths = []string{"new"}
slog.Info("scenario2a") slog.Info("b")
s.pendingReqCh <- scenario2a.req s.pendingReqCh <- b.req
// finish first two requests, so model can reload // finish first two requests, so model can reload
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
scenario1a.ctxDone() a.ctxDone()
scenario1b.ctxDone()
select { select {
case resp := <-scenario2a.req.successCh: case resp := <-b.req.successCh:
require.Equal(t, resp.llama, scenario2a.srv) require.Equal(t, resp.llama, b.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario2a.req.errCh) require.Empty(t, b.req.errCh)
case err := <-scenario2a.req.errCh: case err := <-b.req.errCh:
t.Fatal(err.Error()) t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
}
func TestRequestsMultipleLoadedModels(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
// Multiple loaded models
a := newScenarioRequest(t, ctx, "ollama-model-3a", 1*format.GigaByte, nil)
b := newScenarioRequest(t, ctx, "ollama-model-3b", 24*format.GigaByte, nil)
c := newScenarioRequest(t, ctx, "ollama-model-4a", 30, nil)
c.req.opts.NumGPU = 0 // CPU load, will be allowed
d := newScenarioRequest(t, ctx, "ollama-model-3c", 30, nil) // Needs prior unloaded
envconfig.MaxRunners = 1 envconfig.MaxRunners = 1
s.newServerFn = scenario3a.newServer s.newServerFn = a.newServer
slog.Info("scenario3a") slog.Info("a")
s.pendingReqCh <- scenario3a.req s.pendingReqCh <- a.req
// finish prior request, so new model can load s.Run(ctx)
time.Sleep(1 * time.Millisecond)
scenario2a.ctxDone()
select { select {
case resp := <-scenario3a.req.successCh: case resp := <-a.req.successCh:
require.Equal(t, resp.llama, scenario3a.srv) require.Equal(t, resp.llama, a.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario3a.req.errCh) require.Empty(t, a.req.errCh)
case err := <-scenario3a.req.errCh: case err := <-a.req.errCh:
t.Fatal(err.Error()) t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
@@ -262,15 +292,15 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
 	s.loadedMu.Unlock()

 	envconfig.MaxRunners = 0
-	s.newServerFn = scenario3b.newServer
-	slog.Info("scenario3b")
-	s.pendingReqCh <- scenario3b.req
+	s.newServerFn = b.newServer
+	slog.Info("b")
+	s.pendingReqCh <- b.req
 	select {
-	case resp := <-scenario3b.req.successCh:
-		require.Equal(t, resp.llama, scenario3b.srv)
+	case resp := <-b.req.successCh:
+		require.Equal(t, resp.llama, b.srv)
 		require.Empty(t, s.pendingReqCh)
-		require.Empty(t, scenario3b.req.errCh)
-	case err := <-scenario3b.req.errCh:
+		require.Empty(t, b.req.errCh)
+	case err := <-b.req.errCh:
 		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
@@ -280,15 +310,15 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
 	s.loadedMu.Unlock()

 	// This is a CPU load with NumGPU = 0 so it should load
-	s.newServerFn = scenario3c.newServer
-	slog.Info("scenario3c")
-	s.pendingReqCh <- scenario3c.req
+	s.newServerFn = c.newServer
+	slog.Info("c")
+	s.pendingReqCh <- c.req
 	select {
-	case resp := <-scenario3c.req.successCh:
-		require.Equal(t, resp.llama, scenario3c.srv)
+	case resp := <-c.req.successCh:
+		require.Equal(t, resp.llama, c.srv)
 		require.Empty(t, s.pendingReqCh)
-		require.Empty(t, scenario3c.req.errCh)
-	case err := <-scenario3c.req.errCh:
+		require.Empty(t, c.req.errCh)
+	case err := <-c.req.errCh:
 		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
@@ -298,25 +328,25 @@ func TestRequests(t *testing.T) {
s.loadedMu.Unlock()
// Try to load a model that wont fit
s.newServerFn = d.newServer
slog.Info("d")
s.loadedMu.Lock()
require.Len(t, s.loaded, 3)
s.loadedMu.Unlock()
a.ctxDone() // Won't help since this one isn't big enough to make room
time.Sleep(2 * time.Millisecond)
s.pendingReqCh <- d.req
// finish prior request, so new model can load
time.Sleep(6 * time.Millisecond)
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
s.loadedMu.Unlock()
b.ctxDone()
select {
case resp := <-d.req.successCh:
require.Equal(t, resp.llama, d.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, d.req.errCh)
case <-ctx.Done():
t.Fatal("timeout")
}
@@ -329,26 +359,19 @@ func TestGetRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer done()
a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond})
b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond})
c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond})
envconfig.MaxQueuedRequests = 1
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
s.newServerFn = a.newServer
slog.Info("a")
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
slog.Info("b")
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
require.Empty(t, successCh1b)
require.Len(t, errCh1b, 1)
@@ -357,22 +380,24 @@ func TestGetRunner(t *testing.T) {
s.Run(ctx)
select {
case resp := <-successCh1a:
require.Equal(t, resp.llama, a.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, errCh1a)
case err := <-errCh1a:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
a.ctxDone() // Set "a" model to idle so it can unload
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
c.req.model.ModelPath = "bad path"
slog.Info("c")
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration)
// Starts in pending channel, then should be quickly processsed to return an error
time.Sleep(20 * time.Millisecond) // Long enough for the "a" model to expire and unload
require.Empty(t, successCh1c)
s.loadedMu.Lock()
require.Empty(t, s.loaded)
@@ -380,7 +405,7 @@ func TestGetRunner(t *testing.T) {
require.Len(t, errCh1c, 1)
err = <-errCh1c
require.Contains(t, err.Error(), "bad path")
b.ctxDone()
}
// TODO - add one scenario that triggers the bogus finished event with positive ref count // TODO - add one scenario that triggers the bogus finished event with positive ref count
@@ -389,7 +414,7 @@ func TestPrematureExpired(t *testing.T) {
defer done()
// Same model, same request
scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil)
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
@@ -411,6 +436,8 @@ func TestPrematureExpired(t *testing.T) {
s.loadedMu.Unlock()
slog.Info("sending premature expired event now")
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
case err := <-errCh1a:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
@@ -446,6 +473,8 @@ func TestUseLoadedRunner(t *testing.T) {
select {
case success := <-req.successCh:
require.Equal(t, r1, success)
case err := <-req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
@@ -625,8 +654,7 @@ func TestAlreadyCanceled(t *testing.T) {
defer done()
dctx, done2 := context.WithCancel(ctx)
done2()
scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0})
s := InitScheduler(ctx)
slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req

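Aside: every assertion in the scheduler test diff above follows the same select-on-channels shape. Below is a minimal, self-contained sketch of that pattern for reference; the channel names and payload type are illustrative, not the scheduler's actual helpers.

```go
package scheduler_test

import (
	"context"
	"testing"
	"time"
)

// TestSelectPattern mirrors the shape used in the scheduler tests above: each
// request exposes a success channel and an error channel, and the test waits
// on both plus the context deadline so a hang surfaces as a "timeout" failure
// instead of blocking the suite.
func TestSelectPattern(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	successCh := make(chan string, 1)
	errCh := make(chan error, 1)

	// Simulate the scheduler completing the request on another goroutine.
	go func() { successCh <- "runner-ready" }()

	select {
	case resp := <-successCh:
		if resp != "runner-ready" {
			t.Fatalf("unexpected response: %s", resp)
		}
	case err := <-errCh:
		t.Fatal(err.Error())
	case <-ctx.Done():
		t.Fatal("timeout")
	}
}
```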

@@ -46,7 +46,7 @@ Action: ```json
{{- range .ToolCalls }}
{
"tool_name": "{{ .Function.Name }}",
"parameters": {{ .Function.Arguments }}
}
{{- end }}
]```


@@ -17,7 +17,7 @@ If you decide to call functions:
Available functions as JSON spec:
{{- if .Tools }}
{{ .Tools }}
{{- end }}<|eot_id|>
{{- end }}
{{- range .Messages }}<|start_header_id|>
@@ -25,7 +25,7 @@ Available functions as JSON spec:
{{- end }}<|end_header_id|>
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }} functools[
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }}
{{- end }}]
{{- end }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|>


@@ -0,0 +1,43 @@
{{- if .Messages }}
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
{{ .System }}
{{- if .Tools }} You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{"name": <function-name>,"arguments": <args-dict>}
</tool_call>
Here are the available tools:
<tools>
{{- range .Tools }} {{ .Function }}
{{- end }} </tools>
{{- end }}
{{- end }}<|eot_id|>
{{- range .Messages }}
{{- if ne .Role "system" }}<|start_header_id|>{{ .Role }}<|end_header_id|>
{{ if eq .Role "user" }}{{ .Content }}
{{- else if eq .Role "assistant" }}
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }}<tool_call>
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}
</tool_call>
{{- end }}
{{- else if eq .Role "tool" }}<tool_response>
{{ .Content }}
</tool_response>
{{- end }}<|eot_id|>
{{- end }}
{{- end }}<|start_header_id|>assistant<|end_header_id|>
{{ else }}
{{ if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
{{ end }}{{ .Response }}
{{- if .Response }}<|eot_id|>
{{- end }}


@@ -0,0 +1,24 @@
<|start_header_id|>system<|end_header_id|>
You are a knowledgable assistant. You can answer questions and perform tasks. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{"name": <function-name>,"arguments": <args-dict>}
</tool_call>
Here are the available tools:
<tools> {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}} </tools><|eot_id|><|start_header_id|>user<|end_header_id|>
What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
<tool_call>
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
</tool_call><|eot_id|><|start_header_id|>tool<|end_header_id|>
<tool_response>
22
</tool_response><|eot_id|><|start_header_id|>assistant<|end_header_id|>
The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>
What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

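Aside: the expected output above wraps each call in <tool_call> tags around a JSON object with name and arguments. Below is a rough, standalone sketch of how such output could be extracted; the regular expression and struct are illustrative, not the server's actual parser.

```go
package main

import (
	"encoding/json"
	"fmt"
	"regexp"
)

// toolCall matches the JSON shape shown between <tool_call> tags in the
// testdata above: {"name": "...", "arguments": {...}}.
type toolCall struct {
	Name      string         `json:"name"`
	Arguments map[string]any `json:"arguments"`
}

var toolCallRE = regexp.MustCompile(`(?s)<tool_call>\s*(.*?)\s*</tool_call>`)

// parseToolCalls pulls every <tool_call> block out of the model output and
// unmarshals its JSON payload.
func parseToolCalls(s string) ([]toolCall, error) {
	var calls []toolCall
	for _, m := range toolCallRE.FindAllStringSubmatch(s, -1) {
		var tc toolCall
		if err := json.Unmarshal([]byte(m[1]), &tc); err != nil {
			return nil, err
		}
		calls = append(calls, tc)
	}
	return calls, nil
}

func main() {
	out := `<tool_call>
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
</tool_call>`
	calls, err := parseToolCalls(out)
	if err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", calls)
}
```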

@@ -1,13 +1,13 @@
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
{{ end }}{{ .Content }}[/INST]
{{- else if eq .Role "assistant" }}
{{- if .Content }} {{ .Content }}</s>
{{- else if .ToolCalls }}[TOOL_CALLS] [
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}]</s>
{{- end }}
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
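Aside: the template diffs above drop the json helper and render {{ .Function.Arguments }} directly, which works when the arguments value knows how to print itself as JSON. Below is a standalone sketch of that idea using a fmt.Stringer; the types here are stand-ins and may not match the project's real ones.

```go
package main

import (
	"encoding/json"
	"os"
	"text/template"
)

// arguments renders itself as JSON when printed, so a template can write
// {{ .Arguments }} directly instead of piping through a json helper.
// This is an illustrative stand-in, not the project's actual type.
type arguments map[string]any

func (a arguments) String() string {
	b, _ := json.Marshal(a)
	return string(b)
}

type function struct {
	Name      string
	Arguments arguments
}

func main() {
	// text/template prints values with fmt semantics, so String() is honored.
	tmpl := template.Must(template.New("call").Parse(
		`{"name": "{{ .Name }}", "arguments": {{ .Arguments }}}` + "\n"))

	f := function{
		Name:      "get_current_weather",
		Arguments: arguments{"format": "celsius", "location": "Paris, France"},
	}
	if err := tmpl.Execute(os.Stdout, f); err != nil {
		panic(err)
	}
}
```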

server/testdata/tools/xlam.gotmpl vendored Normal file

@@ -0,0 +1,45 @@
{{- if .System }}{{ .System }}
{{ end }}
{{- range $i, $_ := .Messages }}
{{- if eq .Role "user" }}### Instruction:
{{- if and $.Tools (le (len (slice $.Messages $i)) 2) }}
[BEGIN OF TASK INSTRUCTION]
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out.
[END OF TASK INSTRUCTION]
[BEGIN OF AVAILABLE TOOLS]
{{ $.Tools }}
[END OF AVAILABLE TOOLS]
[BEGIN OF FORMAT INSTRUCTION]
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
```
{
"tool_calls": [
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]
}
```
[END OF FORMAT INSTRUCTION]
[BEGIN OF QUERY]
{{ .Content }}
[END OF QUERY]
{{ else }}
{{ .Content }}
{{ end }}
{{- else if .ToolCalls }}### Response:
{"tool_calls": [{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{ end }}]}
<|EOT|>
{{ else if eq .Role "assistant" }}### Response:
{{ .Content }}
<|EOT|>
{{ end }}
{{- end }}### Response:

server/testdata/tools/xlam.out vendored Normal file

@@ -0,0 +1,40 @@
You are a knowledgable assistant. You can answer questions and perform tasks.
### Instruction:
What's the weather like today in Paris?
### Response:
{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]}
<|EOT|>
### Response:
The current temperature in Paris, France is 22 degrees Celsius.
<|EOT|>
### Instruction:
[BEGIN OF TASK INSTRUCTION]
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out.
[END OF TASK INSTRUCTION]
[BEGIN OF AVAILABLE TOOLS]
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]
[END OF AVAILABLE TOOLS]
[BEGIN OF FORMAT INSTRUCTION]
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
```
{
"tool_calls": [
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]
}
```
[END OF FORMAT INSTRUCTION]
[BEGIN OF QUERY]
What's the weather like today in San Francisco and Toronto?
[END OF QUERY]
### Response:

template/alfred.json Normal file

@@ -0,0 +1,8 @@
{
"stop": [
"<start_system>",
"<end_message>",
"<start_user>",
"<start_assistant>"
]
}

template/alpaca.json Normal file

@@ -0,0 +1,6 @@
{
"stop": [
"### Instruction:",
"### Response"
]
}

template/chatml.json Normal file

@@ -0,0 +1,6 @@
{
"stop": [
"<|im_start|>",
"<|im_end|>"
]
}

template/chatqa.json Normal file

@@ -0,0 +1,8 @@
{
"stop": [
"System:",
"User:",
"Assistant:",
"<|begin_of_text|>"
]
}


@@ -0,0 +1,7 @@
{
"stop": [
"Source:",
"Destination:",
"<step>"
]
}


@@ -0,0 +1,6 @@
{
"stop": [
"User:",
"Assistant:"
]
}


@@ -0,0 +1,6 @@
{
"stop": [
"<start_of_turn>",
"<end_of_turn>"
]
}


@@ -0,0 +1,7 @@
{
"stop": [
"System:",
"Question:",
"Answer:"
]
}


@@ -0,0 +1,8 @@
{
"stop": [
"[INST]",
"[/INST]",
"<<SYS>>",
"<</SYS>>"
]
}


@@ -0,0 +1,7 @@
{
"stop": [
"<|start_header_id|>",
"<|end_header_id|>",
"<|eot_id|>"
]
}

template/magicoder.json Normal file

@@ -0,0 +1,6 @@
{
"stop": [
"@@ Instruction",
"@@ Response"
]
}


@@ -0,0 +1,6 @@
{
"stop": [
"<|im_start|>",
"<|im_end|>"
]
}

template/openchat.json Normal file

@@ -0,0 +1,5 @@
{
"stop": [
"<|end_of_turn|>"
]
}

template/phi-3.json Normal file

@@ -0,0 +1,8 @@
{
"stop": [
"<|end|>",
"<|system|>",
"<|user|>",
"<|assistant|>"
]
}


@@ -0,0 +1,7 @@
{
"stop": [
"### System:",
"### User:",
"### Assistant"
]
}


@@ -0,0 +1,7 @@
{
"stop": [
"### Instruction",
"### Response",
"<|endoftext|>"
]
}


@@ -23,6 +23,7 @@ import (
var indexBytes []byte
//go:embed *.gotmpl
//go:embed *.json
var templatesFS embed.FS
var templatesOnce = sync.OnceValues(func() ([]*named, error) {
@@ -39,6 +40,15 @@ var templatesOnce = sync.OnceValues(func() ([]*named, error) {
// normalize line endings
t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
params, err := templatesFS.ReadFile(t.Name + ".json")
if err != nil {
continue
}
if err := json.Unmarshal(params, &t.Parameters); err != nil {
return nil, err
}
}
return templates, nil
@@ -48,6 +58,10 @@ type named struct {
Name string `json:"name"`
Template string `json:"template"`
Bytes []byte
Parameters *struct {
Stop []string `json:"stop"`
}
}
func (t named) Reader() io.Reader {
@@ -150,9 +164,9 @@ func (t *Template) Vars() []string {
type Values struct {
Messages []api.Message
api.Tools
Prompt string
Suffix string
// forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy bool
@@ -217,6 +231,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
"System": system, "System": system,
"Messages": messages, "Messages": messages,
"Tools": v.Tools, "Tools": v.Tools,
"Response": "",
})
}
@@ -263,6 +278,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
cut = true
return false
}
return cut
@@ -270,8 +286,9 @@ func (t *Template) Execute(w io.Writer, v Values) error {
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
"Response": response,
}); err != nil {
return err
}
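Aside: the template package changes above pair each embedded .gotmpl with an optional same-named .json whose stop array is unmarshalled into the new Parameters field. Below is a minimal, isolated sketch of that unmarshal step with standalone types (not the package's actual ones).

```go
package main

import (
	"encoding/json"
	"fmt"
)

// parameters mirrors the optional per-template JSON shown above,
// e.g. template/chatml.json: {"stop": ["<|im_start|>", "<|im_end|>"]}.
type parameters struct {
	Stop []string `json:"stop"`
}

func main() {
	raw := []byte(`{"stop": ["<|im_start|>", "<|im_end|>"]}`)

	var p parameters
	if err := json.Unmarshal(raw, &p); err != nil {
		panic(err)
	}

	// The detected stop strings can then serve as default stop parameters
	// for models that use the corresponding template.
	fmt.Println(p.Stop) // [<|im_start|> <|im_end|>]
}
```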


@@ -260,6 +260,26 @@ func TestExecuteWithMessages(t *testing.T) {
Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
},
{
"mistral assistant",
[]template{
{"no response", `[INST] {{ .Prompt }}[/INST] `},
{"response", `[INST] {{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `
{{- range $i, $m := .Messages }}
{{- if eq .Role "user" }}[INST] {{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}{{ end }}
{{- end }}`},
},
Values{
Messages: []api.Message{
{Role: "user", Content: "Hello friend!"},
{Role: "assistant", Content: "Hello human!"},
{Role: "user", Content: "What is your name?"},
{Role: "assistant", Content: "My name is Ollama and I"},
},
},
`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] My name is Ollama and I`,
},
{
"chatml",
[]template{

template/vicuna.json Normal file

@@ -0,0 +1,6 @@
{
"stop": [
"USER:",
"ASSISTANT:"
]
}

template/zephyr.json Normal file

@@ -0,0 +1,8 @@
{
"stop": [
"<|system|>",
"</s>",
"<|user|>",
"<|assistant|>"
]
}