Compare commits

Comparing `mattw/faq-`...`mattw/quan` (46 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | fed3843be2 |  |
|  | 01d4047ed3 |  |
|  | b5939008a1 |  |
|  | e9ce91e9a6 |  |
|  | 4ad6c9b11f |  |
|  | c0285158a9 |  |
|  | 77a66df72c |  |
|  | 5b4837f881 |  |
|  | 29340c2e62 |  |
|  | d5ec730354 |  |
|  | 8bed487aba |  |
|  | c1a10a6e9b |  |
|  | ddbfa6fe31 |  |
|  | 2fcd41ef81 |  |
|  | 16f4603b67 |  |
|  | 1184686649 |  |
|  | 2588cb2daa |  |
|  | c7ea8f237e |  |
|  | 0b3118e0af |  |
|  | 05face44ef |  |
|  | a2ad952440 |  |
|  | 5fea4410be |  |
|  | b846eb64d0 |  |
|  | 3c5dd9ed1d |  |
|  | b17ccd0542 |  |
|  | d0409f772f |  |
|  | ec261422af |  |
|  | 0498f7ce56 |  |
|  | 738a8d12eb |  |
|  | d966b730ac |  |
|  | 9a70aecccb |  |
|  | 22cd5eaab6 |  |
|  | 304a8799ca |  |
|  | 2a2fa3c329 |  |
|  | 55978c1dc9 |  |
|  | d4ebdadbe7 |  |
|  | c5f21f73a4 |  |
|  | 371bc73531 |  |
|  | c651d8b824 |  |
|  | cf50ef5b51 |  |
|  | 697bea6939 |  |
|  | 10da41d677 |  |
|  | db356c8519 |  |
|  | b80081022f |  |
|  | 790457398a |  |
|  | 511069a2a5 |  |
@@ -49,6 +49,7 @@ Here are some example open-source models that can be downloaded:

| ------------------ | ---------- | ----- | ------------------------------ |
| Llama 2 | 7B | 3.8GB | `ollama run llama2` |
| Mistral | 7B | 4.1GB | `ollama run mistral` |
| Dolphin Phi | 2.7B | 1.6GB | `ollama run dolphin-phi` |
| Phi-2 | 2.7B | 1.7GB | `ollama run phi` |
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
| Starling | 7B | 4.1GB | `ollama run starling-lm` |

@@ -62,7 +63,7 @@ Here are some example open-source models that can be downloaded:

> Note: You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.

## Customize your own model
## Customize a model

### Import from GGUF

@@ -261,6 +262,8 @@ See the [API documentation](./docs/api.md) for all endpoints.

- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
- [Amica](https://github.com/semperai/amica)
- [chatd](https://github.com/BruceMacD/chatd)
- [Ollama-SwiftUI](https://github.com/kghandour/Ollama-SwiftUI)

### Terminal

@@ -300,7 +303,7 @@ See the [API documentation](./docs/api.md) for all endpoints.

### Mobile

- [Enchanted](https://github.com/AugustDev/enchanted)
- [Maid](https://github.com/danemadsen/Maid)
- [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)

### Extensions & Plugins
28  cmd/cmd.go

@@ -662,10 +662,11 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {

usage := func() {
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set Set session variables")
fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, " /set Set session variables")
fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
fmt.Fprintln(os.Stderr, "")
@@ -687,6 +688,21 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
fmt.Fprintln(os.Stderr, "")
}

usageShortcuts := func() {
fmt.Fprintln(os.Stderr, "Available keyboard shortcuts:")
fmt.Fprintln(os.Stderr, " Ctrl + a Move to the beginning of the line (Home)")
fmt.Fprintln(os.Stderr, " Ctrl + e Move to the end of the line (End)")
fmt.Fprintln(os.Stderr, " Alt + b Move back (left) one word")
fmt.Fprintln(os.Stderr, " Alt + f Move forward (right) one word")
fmt.Fprintln(os.Stderr, " Ctrl + k Delete the sentence after the cursor")
fmt.Fprintln(os.Stderr, " Ctrl + u Delete the sentence before the cursor")
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, " Ctrl + l Clear the screen")
fmt.Fprintln(os.Stderr, " Ctrl + c Stop the model from responding")
fmt.Fprintln(os.Stderr, " Ctrl + d Exit ollama (/bye)")
fmt.Fprintln(os.Stderr, "")
}

usageShow := func() {
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /show license Show model license")
@@ -737,7 +753,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
return nil
case errors.Is(err, readline.ErrInterrupt):
if line == "" {
fmt.Println("\nUse Ctrl-D or /bye to exit.")
fmt.Println("\nUse Ctrl + d or /bye to exit.")
}

scanner.Prompt.UseAlt = false
@@ -940,6 +956,8 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
usageSet()
case "show", "/show":
usageShow()
case "shortcut", "shortcuts":
usageShortcuts()
}
} else {
usage()
@@ -1,6 +1,6 @@
# Documentation

To get started, see the project's **[quicktart](../README.md#quickstart)**.
To get started, see the project's **[quickstart](../README.md#quickstart)**.

Ollama is a tool for running AI models on your hardware. Many users will choose to use the Command Line Interface (CLI) to work with Ollama. Learn more about all the commands in the CLI in the **[Main Readme](../README.md)**.

@@ -22,4 +22,4 @@ Finally for all the questions that don't fit anywhere else, there is the **[FAQ]

[Tutorials](./tutorials.md) apply the documentation to tasks.

For working code examples of using Ollama, see [Examples](../examples).
For working code examples of using Ollama, see [Examples](../examples).
73  docs/api.md

@@ -27,7 +27,6 @@ All durations are returned in nanoseconds.

Certain endpoints stream responses as JSON objects and can optional return non-streamed responses.

## Generate a completion

```shell
@@ -47,7 +46,7 @@ Advanced parameters (optional):

- `format`: the format to return a response in. Currently the only accepted value is `json`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `system`: system message to (overrides what is defined in the `Modelfile`)
- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API.
@@ -104,12 +103,12 @@ To calculate how fast the response is generated in tokens per second (token/s),
"response": "",
"done": true,
"context": [1, 2, 3],
"total_duration":10706818083,
"load_duration":6338219291,
"prompt_eval_count":26,
"prompt_eval_duration":130079000,
"eval_count":259,
"eval_duration":4232710000
"total_duration": 10706818083,
"load_duration": 6338219291,
"prompt_eval_count": 26,
"prompt_eval_duration": 130079000,
"eval_count": 259,
"eval_duration": 4232710000
}
```

@@ -170,7 +169,7 @@ curl http://localhost:11434/api/generate -d '{
"created_at": "2023-11-09T21:07:55.186497Z",
"response": "{\n\"morning\": {\n\"color\": \"blue\"\n},\n\"noon\": {\n\"color\": \"blue-gray\"\n},\n\"afternoon\": {\n\"color\": \"warm gray\"\n},\n\"evening\": {\n\"color\": \"orange\"\n}\n}\n",
"done": true,
"context": [1, 2, 3],
"context": [1, 2, 3],
"total_duration": 4648158584,
"load_duration": 4071084,
"prompt_eval_count": 36,
@@ -235,6 +234,7 @@ curl http://localhost:11434/api/generate -d '{

#### Request (Raw Mode)

In some cases, you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable templating. Also note that raw mode will not return a context.

##### Request

```shell
@@ -319,7 +319,7 @@ curl http://localhost:11434/api/generate -d '{
"created_at": "2023-08-04T19:22:45.499127Z",
"response": "The sky is blue because it is the color of the sky.",
"done": true,
"context": [1, 2, 3],
"context": [1, 2, 3],
"total_duration": 4935886791,
"load_duration": 534986708,
"prompt_eval_count": 26,
@@ -347,10 +347,10 @@ A single JSON object is returned:

```json
{
"model":"llama2",
"created_at":"2023-12-18T19:52:07.071755Z",
"response":"",
"done":true
"model": "llama2",
"created_at": "2023-12-18T19:52:07.071755Z",
"response": "",
"done": true
}
```

@@ -377,7 +377,7 @@ Advanced parameters (optional):

- `format`: the format to return a response in. Currently the only accepted value is `json`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects

### Examples
@@ -410,7 +410,7 @@ A stream of JSON objects is returned:
"created_at": "2023-08-04T08:52:19.385406455-07:00",
"message": {
"role": "assisant",
"content": "The",
"content": "The",
"images": null
},
"done": false
@@ -424,12 +424,12 @@ Final response:
"model": "llama2",
"created_at": "2023-08-04T19:22:45.499127Z",
"done": true,
"total_duration":4883583458,
"load_duration":1334875,
"prompt_eval_count":26,
"prompt_eval_duration":342546000,
"eval_count":282,
"eval_duration":4535599000
"total_duration": 4883583458,
"load_duration": 1334875,
"prompt_eval_count": 26,
"prompt_eval_duration": 342546000,
"eval_count": 282,
"eval_duration": 4535599000
}
```

@@ -445,7 +445,7 @@ curl http://localhost:11434/api/chat -d '{
"role": "user",
"content": "why is the sky blue?"
}
],
],
"stream": false
}'
```
@@ -519,12 +519,12 @@ Final response:
"model": "llama2",
"created_at": "2023-08-04T19:22:45.499127Z",
"done": true,
"total_duration":8113331500,
"load_duration":6396458,
"prompt_eval_count":61,
"prompt_eval_duration":398801000,
"eval_count":468,
"eval_duration":7701267000
"total_duration": 8113331500,
"load_duration": 6396458,
"prompt_eval_count": 61,
"prompt_eval_duration": 398801000,
"eval_count": 468,
"eval_duration": 7701267000
}
```

@@ -559,12 +559,12 @@ curl http://localhost:11434/api/chat -d '{
"images": null
},
"done": true,
"total_duration":1668506709,
"load_duration":1986209,
"prompt_eval_count":26,
"prompt_eval_duration":359682000,
"eval_count":83,
"eval_duration":1303285000
"total_duration": 1668506709,
"load_duration": 1986209,
"prompt_eval_count": 26,
"prompt_eval_duration": 359682000,
"eval_count": 83,
"eval_duration": 1303285000
}
```

@@ -574,7 +574,7 @@ curl http://localhost:11434/api/chat -d '{
POST /api/create
```

Create a model from a [`Modelfile`](./modelfile.md). It is recommended to set `modelfile` to the content of the Modelfile rather than just set `path`. This is a requirement for remote create. Remote model creation must also create any file blobs, fields such as `FROM` and `ADAPTER`, explicitly with the server using [Create a Blob](#create-a-blob) and the value to the path indicated in the response.
Create a model from a [`Modelfile`](./modelfile.md). It is recommended to set `modelfile` to the content of the Modelfile rather than just set `path`. This is a requirement for remote create. Remote model creation must also create any file blobs, fields such as `FROM` and `ADAPTER`, explicitly with the server using [Create a Blob](#create-a-blob) and the value to the path indicated in the response.

### Parameters

@@ -624,7 +624,6 @@ HEAD /api/blobs/:digest

Ensures that the file blob used for a FROM or ADAPTER field exists on the server. This is checking your Ollama server and not Ollama.ai.

#### Query Parameters

- `digest`: the SHA256 digest of the blob
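The `docs/api.md` changes above touch the documented `/api/generate` fields, notably `stream`, `eval_count`, and `eval_duration` (durations are nanoseconds). As a minimal client-side sketch that is not part of this diff, the snippet below assumes a local Ollama server on the default port with `llama2` pulled; it disables streaming and derives tokens per second from the returned counters, as the doc describes.

```go
// Minimal sketch: call /api/generate with streaming disabled and compute
// tokens/second from eval_count and eval_duration (nanoseconds).
// Assumes a local Ollama server at localhost:11434 with `llama2` pulled.
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

type generateResponse struct {
	Response     string `json:"response"`
	Done         bool   `json:"done"`
	EvalCount    int    `json:"eval_count"`
	EvalDuration int64  `json:"eval_duration"` // nanoseconds
}

func main() {
	body, _ := json.Marshal(map[string]any{
		"model":  "llama2",
		"prompt": "Why is the sky blue?",
		"stream": false, // single response object instead of a stream
	})

	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	var out generateResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		log.Fatal(err)
	}

	fmt.Println(out.Response)
	if out.EvalDuration > 0 {
		fmt.Printf("%.1f tokens/s\n", float64(out.EvalCount)/float64(out.EvalDuration)*1e9)
	}
}
```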
25  docs/faq.md

@@ -62,8 +62,6 @@ Refer to the section above for how to use environment variables on your platform

- macOS: `~/.ollama/models`.
- Linux: `/usr/share/ollama/.ollama/models`

See [the CLI Documentation](./cli.md) for more on this.

## How do I set them to a different location?

If a different directory needs to be used, set the environment variable `OLLAMA_MODELS` to the chosen directory. Refer to the section above for how to use environment variables on your platform.
@@ -114,3 +112,26 @@ This can impact both installing Ollama, as well as downloading models.
Open `Control Panel > Networking and Internet > View network status and tasks` and click on `Change adapter settings` on the left panel. Find the `vEthernel (WSL)` adapter, right click and select `Properties`.
Click on `Configure` and open the `Advanced` tab. Search through each of the properties until you find `Large Send Offload Version 2 (IPv4)` and `Large Send Offload Version 2 (IPv6)`. *Disable* both of these
properties.

## What does the q in the model tag mean? What is quantization?

Whenever you pull a model without a tag, Ollama will actually pull the q4_0 quantization of the model. You can verify this on the tags page. On https://ollama.ai/library/llama2/tags you can see that the hash for the latest tag matches the hash for the 7b model.

Looking at the that page for any model, you can see several quantization options available. Quantization is a method of compression that allows the model to fit in less space and thus use less RAM and VRAM on your machine.

At a high level, a model is made of an enormous collection of nodes that determine how to generate text. These nodes are connected at different levels with weights. The training process adjusts these weights to be able to output the right text every time.

Most of the source models that we use start with weights that are 32bit floating-point numbers. Those weights, and another concept called biases, add up to be the parameters. So a source model with 7 billion parameters has 7 billion 32bit floating-point numbers, plus a description of all the nodes and more. That adds up to needing at least 28 Gigabytes of memory to load, if you choose to load one of those source models.

Quantization turns those 32bit floating point weights into much smaller integers. The number next to the q indicates the bit size of the weights. So a q4 model converted those 32bit floats into 4bit integers. A 4bit quantization takes up the space for 7billion 4bit integers, plus a little overhead. That comes out to almost 4 Gigabytes. Obviously, there is some loss of information in this process of going from 30GB to 4GB, but it turns out in most cases it isn't really noticeable. In fact, even the 2bit quantization which fits in less than 3GB can be very useful.

There are three major sets of quantizations you will see in the Ollama Library of models: **fp16**, models with just a q and a number, like **q4_0**, and then models with a **K** in the tag. The **fp16** model is one that has been converted and quantized from the source 32bit to 16bit. This will be about half the size of the 32bit source model and is the largest quantization we deliver in the library. The **q4_0**, **q4_1**, **q5_0**, etc. models use two different quantization methods that were the original methods.

The models with a **K** are often referred to as K Quants. This is a method that allows for models of a similar quality but smaller than the original method used. Essentially, it finds clusters of weights and quantizes those together, allowing for higher precision while using the same bit sizes as the regular quantization options. But this requires a set of maps for the model to figure out the original values which have a computational cost. You may see some impact on the speed of models with K quants compared to the regular quantizations.

## What is context, can I increase it, and why doesn't every model support a huge context?

Context refers to the size of the input you can send to a model and get sensible output back. Many models have a context size of 2048 tokens. It's sometimes possible to give it more using the **num_ctx** parameter, but the answers start to degrade. This is because half of the context is "freed" up to allow for more memory. Newer models have been able to increase that context size using different methods. This increase in context size results in a corresponding increase in memory required, sometimes by orders of magnitude.

> !WARNING]
> Currently, over-allocating context size may result in model quality or stability issues.

@@ -148,6 +148,7 @@ The quantization options are as follow (from highest highest to lowest levels of
- `q5_K_M`
- `q6_K`
- `q8_0`
- `f16`

## Manually converting & quantizing models
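As a rough back-of-the-envelope illustration of the arithmetic in the FAQ text above (not part of this diff): weight storage is approximately parameter count times bits per weight, and real GGUF files add metadata and per-block scales, so actual downloads run somewhat larger than these estimates.

```go
// Rough sketch of the FAQ arithmetic: bytes of weight storage are roughly
// parameters * bits-per-weight / 8. For a 7B model this gives ~28 GB at
// 32-bit, ~14 GB at fp16, ~3.5 GB at q4, and under 2 GB at 2-bit, which
// matches the ballpark figures in the FAQ once overhead is added.
package main

import "fmt"

func weightBytes(params, bitsPerWeight float64) float64 {
	return params * bitsPerWeight / 8
}

func main() {
	const params = 7e9 // a "7B" model
	const gb = 1e9     // decimal gigabytes, as used in the FAQ

	for _, q := range []struct {
		name string
		bits float64
	}{
		{"fp32 source", 32},
		{"fp16", 16},
		{"q4 (approx.)", 4},
		{"q2 (approx.)", 2},
	} {
		fmt.Printf("%-12s ~%.1f GB\n", q.name, weightBytes(params, q.bits)/gb)
	}
}
```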
@@ -18,6 +18,8 @@ func main() {
os.Exit(1)
}

defer resp.Body.Close()

responseData, err := io.ReadAll(resp.Body)
if err != nil {
log.Fatal(err)
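The hunk above appears to add a `defer resp.Body.Close()` to one of the Go examples. A minimal standalone sketch of the resulting pattern follows; it is not part of the diff, and the `/api/tags` URL is only an illustrative target for the request.

```go
// Sketch of the pattern the hunk above patches in: check the request error,
// then defer closing the response body before reading it.
package main

import (
	"fmt"
	"io"
	"log"
	"net/http"
)

func main() {
	resp, err := http.Get("http://localhost:11434/api/tags")
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close() // the call the diff above introduces

	responseData, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(string(responseData))
}
```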
37  gpu/gpu.go

@@ -13,6 +13,7 @@ import "C"
import (
"fmt"
"log"
"runtime"
"sync"
"unsafe"

@@ -65,15 +66,14 @@ func GetGPUInfo() GpuInfo {
}

var memInfo C.mem_info_t
resp := GpuInfo{"", "", 0, 0}
resp := GpuInfo{}
if gpuHandles.cuda != nil {
C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
if memInfo.err != nil {
log.Printf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err))
} else {
resp.Driver = "CUDA"
resp.Library = "cuda_server"
resp.Library = "cuda"
}
} else if gpuHandles.rocm != nil {
C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
@@ -81,15 +81,17 @@ func GetGPUInfo() GpuInfo {
log.Printf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err))
} else {
resp.Driver = "ROCM"
resp.Library = "rocm_server"
resp.Library = "rocm"
}
}
if resp.Driver == "" {
if resp.Library == "" {
C.cpu_check_ram(&memInfo)
resp.Driver = "CPU"
// In the future we may offer multiple CPU variants to tune CPU features
resp.Library = "default"
if runtime.GOOS == "windows" {
resp.Library = "cpu"
} else {
resp.Library = "default"
}
}
if memInfo.err != nil {
log.Printf("error looking up CPU memory: %s", C.GoString(memInfo.err))
@@ -101,9 +103,22 @@ func GetGPUInfo() GpuInfo {
return resp
}

func getCPUMem() (memInfo, error) {
var ret memInfo
var info C.mem_info_t
C.cpu_check_ram(&info)
if info.err != nil {
defer C.free(unsafe.Pointer(info.err))
return ret, fmt.Errorf(C.GoString(info.err))
}
ret.FreeMemory = uint64(info.free)
ret.TotalMemory = uint64(info.total)
return ret, nil
}

func CheckVRAM() (int64, error) {
gpuInfo := GetGPUInfo()
if gpuInfo.FreeMemory > 0 && gpuInfo.Driver != "CPU" {
if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
return int64(gpuInfo.FreeMemory), nil
}
return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
@@ -114,7 +129,7 @@ func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
return opts.NumGPU
}
info := GetGPUInfo()
if info.Driver == "CPU" {
if info.Library == "cpu" || info.Library == "default" {
return 0
}

@@ -128,7 +143,7 @@ func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
// 75% of the absolute max number of layers we can fit in available VRAM, off-loading too many layers to the GPU can cause OOM errors
layers := int(info.FreeMemory/bytesPerLayer) * 3 / 4

log.Printf("%d MB VRAM available, loading up to %d %s GPU layers out of %d", info.FreeMemory/(1024*1024), layers, info.Driver, numLayer)
log.Printf("%d MB VRAM available, loading up to %d %s GPU layers out of %d", info.FreeMemory/(1024*1024), layers, info.Library, numLayer)

return layers
}
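As a standalone restatement of the offload heuristic in the `NumGPU` hunk above (not part of the diff, with made-up numbers): estimate a per-layer cost from the model file size, then offload at most 75% of the layers that would fit in free VRAM to leave headroom and avoid OOM.

```go
// Restatement of the heuristic shown above, isolated as a pure function.
package main

import "fmt"

func layersToOffload(numLayer, fileSizeBytes, freeVRAM int64) int64 {
	// rough per-layer cost derived from the model file size
	bytesPerLayer := fileSizeBytes / numLayer
	// offload at most 75% of what would fit, to leave VRAM headroom
	return freeVRAM / bytesPerLayer * 3 / 4
}

func main() {
	// Made-up numbers: a 35-layer, ~4 GB model file and 2 GB of free VRAM.
	fmt.Println(layersToOffload(35, 4<<30, 2<<30)) // prints 12
}
```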
@@ -18,21 +18,30 @@ func CheckVRAM() (int64, error) {

func GetGPUInfo() GpuInfo {
// TODO - Metal vs. x86 macs...

mem, _ := getCPUMem()
return GpuInfo{
Driver: "METAL",
Library: "default",
TotalMemory: 0,
FreeMemory: 0,
Library: "default",
memInfo: mem,
}
}

func getCPUMem() (memInfo, error) {
return memInfo{
TotalMemory: 0,
FreeMemory: 0,
}, nil
}

func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
if opts.NumGPU != -1 {
return opts.NumGPU
}

// metal only supported on arm64
if runtime.GOARCH == "arm64" {
return 1
}

// metal only supported on arm64
return 0
}
@@ -9,20 +9,21 @@
#include <dlfcn.h>
#define LOAD_LIBRARY(lib, flags) dlopen(lib, flags)
#define LOAD_SYMBOL(handle, sym) dlsym(handle, sym)
#define LOAD_ERR() dlerror()
#define LOAD_ERR() strdup(dlerror())
#define UNLOAD_LIBRARY(handle) dlclose(handle)
#else
#include <windows.h>
#define LOAD_LIBRARY(lib, flags) LoadLibrary(lib)
#define LOAD_SYMBOL(handle, sym) GetProcAddress(handle, sym)
#define UNLOAD_LIBRARY(handle) FreeLibrary(handle)

// TODO - refactor this with proper error message handling on windows
inline static char *LOAD_ERR() {
static char errbuf[8];
snprintf(errbuf, 8, "0x%lx", GetLastError());
return errbuf;
}
#define LOAD_ERR() ({\
LPSTR messageBuffer = NULL; \
size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, \
NULL, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&messageBuffer, 0, NULL); \
char *resp = strdup(messageBuffer); \
LocalFree(messageBuffer); \
resp; \
})

#endif
@@ -6,11 +6,12 @@
void cpu_check_ram(mem_info_t *resp) {
resp->err = NULL;
MEMORYSTATUSEX info;
info.dwLength = sizeof(info);
if (GlobalMemoryStatusEx(&info) != 0) {
resp->total = info.ullTotalPhys;
resp->free = info.ullAvailPhys;
} else {
resp->err = strdup(LOAD_ERR());
resp->err = LOAD_ERR();
}
return;
}
@@ -43,9 +43,11 @@ void cuda_init(cuda_init_resp_t *resp) {
if (!resp->ch.handle) {
// TODO improve error message, as the LOAD_ERR will have typically have the
// final path that was checked which might be confusing.
char *msg = LOAD_ERR();
snprintf(buf, buflen,
"Unable to load %s library to query for Nvidia GPUs: %s",
cuda_lib_paths[0], LOAD_ERR());
cuda_lib_paths[0], msg);
free(msg);
resp->err = strdup(buf);
return;
}
@@ -55,8 +57,10 @@ void cuda_init(cuda_init_resp_t *resp) {
if (!l[i].p) {
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
char *msg = LOAD_ERR();
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
LOAD_ERR());
msg);
free(msg);
resp->err = strdup(buf);
return;
}
@@ -40,9 +40,11 @@ void rocm_init(rocm_init_resp_t *resp) {
resp->rh.handle = LOAD_LIBRARY(rocm_lib_paths[i], RTLD_LAZY);
}
if (!resp->rh.handle) {
char *msg = LOAD_ERR();
snprintf(buf, buflen,
"Unable to load %s library to query for Radeon GPUs: %s\n",
rocm_lib_paths[0], LOAD_ERR());
rocm_lib_paths[0], msg);
free(msg);
resp->err = strdup(buf);
return;
}
@@ -51,8 +53,10 @@ void rocm_init(rocm_init_resp_t *resp) {
*l[i].p = LOAD_SYMBOL(resp->rh.handle, l[i].s);
if (!l[i].p) {
UNLOAD_LIBRARY(resp->rh.handle);
char *msg = LOAD_ERR();
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
LOAD_ERR());
msg);
free(msg);
resp->err = strdup(buf);
return;
}
@@ -9,7 +9,7 @@ import (

func TestBasicGetGPUInfo(t *testing.T) {
info := GetGPUInfo()
assert.Contains(t, "CUDA ROCM CPU METAL", info.Driver)
assert.Contains(t, "cuda rocm cpu default", info.Library)

switch runtime.GOOS {
case "darwin":
@@ -23,4 +23,19 @@ func TestBasicGetGPUInfo(t *testing.T) {
}
}

func TestCPUMemInfo(t *testing.T) {
info, err := getCPUMem()
assert.NoError(t, err)
switch runtime.GOOS {
case "darwin":
t.Skip("CPU memory not populated on darwin")
case "linux", "windows":
assert.Greater(t, info.TotalMemory, uint64(0))
assert.Greater(t, info.FreeMemory, uint64(0))
default:
return
}

}

// TODO - add some logic to figure out card type through other means and actually verify we got back what we expected
11  gpu/types.go

@@ -1,11 +1,14 @@
package gpu

type memInfo struct {
TotalMemory uint64 `json:"total_memory,omitempty"`
FreeMemory uint64 `json:"free_memory,omitempty"`
}

// Beginning of an `ollama info` command
type GpuInfo struct {
Driver string `json:"driver,omitempty"`
Library string `json:"library,omitempty"`
TotalMemory uint64 `json:"total_memory,omitempty"`
FreeMemory uint64 `json:"free_memory,omitempty"`
memInfo
Library string `json:"library,omitempty"`

// TODO add other useful attributes about the card here for discovery information
}
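A side note on the `gpu/types.go` change above, with a minimal sketch that is not part of the diff: because `memInfo` is embedded anonymously, `encoding/json` promotes its exported fields into the enclosing object, so the serialized `GpuInfo` keeps flat `total_memory` and `free_memory` keys even though the memory fields now live in a shared struct.

```go
// Sketch of why the embedded memInfo keeps the JSON shape flat: encoding/json
// flattens the exported fields of an anonymous embedded struct into the
// enclosing object.
package main

import (
	"encoding/json"
	"fmt"
)

type memInfo struct {
	TotalMemory uint64 `json:"total_memory,omitempty"`
	FreeMemory  uint64 `json:"free_memory,omitempty"`
}

type GpuInfo struct {
	memInfo
	Library string `json:"library,omitempty"`
}

func main() {
	b, _ := json.Marshal(GpuInfo{
		memInfo: memInfo{TotalMemory: 8 << 30, FreeMemory: 6 << 30},
		Library: "cuda",
	})
	// {"total_memory":8589934592,"free_memory":6442450944,"library":"cuda"}
	fmt.Println(string(b))
}
```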
@@ -7,24 +7,29 @@
#include <dlfcn.h>
#define LOAD_LIBRARY(lib, flags) dlopen(lib, flags | RTLD_DEEPBIND)
#define LOAD_SYMBOL(handle, sym) dlsym(handle, sym)
#define LOAD_ERR() dlerror()
#define LOAD_ERR() strdup(dlerror())
#define UNLOAD_LIBRARY(handle) dlclose(handle)
#elif _WIN32
#include <windows.h>
#define LOAD_LIBRARY(lib, flags) LoadLibrary(lib)
#define LOAD_SYMBOL(handle, sym) GetProcAddress(handle, sym)
#define UNLOAD_LIBRARY(handle) FreeLibrary(handle)
// TODO - refactor this with proper error message handling on windows
inline static char *LOAD_ERR() {
static char errbuf[8];
snprintf(errbuf, 8, "0x%lx", GetLastError());
return errbuf;
inline char *LOAD_ERR() {
LPSTR messageBuffer = NULL;
size_t size = FormatMessageA(
FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM |
FORMAT_MESSAGE_IGNORE_INSERTS,
NULL, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
(LPSTR)&messageBuffer, 0, NULL);
char *resp = strdup(messageBuffer);
LocalFree(messageBuffer);
return resp;
}
#else
#include <dlfcn.h>
#define LOAD_LIBRARY(lib, flags) dlopen(lib, flags)
#define LOAD_SYMBOL(handle, sym) dlsym(handle, sym)
#define LOAD_ERR() dlerror()
#define LOAD_ERR() strdup(dlerror())
#define UNLOAD_LIBRARY(handle) dlclose(handle)
#endif

@@ -57,8 +62,10 @@ void dynamic_shim_init(const char *libPath, struct dynamic_llama_server *s,
s->handle = LOAD_LIBRARY(libPath, RTLD_NOW);
if (!s->handle) {
err->id = -1;
char *msg = LOAD_ERR();
snprintf(err->msg, err->msg_len,
"Unable to load dynamic server library: %s", LOAD_ERR());
"Unable to load dynamic server library: %s", msg);
free(msg);
return;
}

@@ -67,8 +74,10 @@ void dynamic_shim_init(const char *libPath, struct dynamic_llama_server *s,
if (!l[i].p) {
UNLOAD_LIBRARY(s->handle);
err->id = -1;
char *msg = LOAD_ERR();
snprintf(err->msg, err->msg_len, "symbol lookup for %s failed: %s",
l[i].s, LOAD_ERR());
l[i].s, msg);
free(msg);
return;
}
}
@@ -1,6 +1,6 @@
#include <stdlib.h>

#include "server.h"
#include "ext_server.h"

#ifdef __cplusplus
extern "C" {
@@ -1,7 +1,7 @@
package llm

/*
#cgo CFLAGS: -I${SRCDIR}/llama.cpp/gguf -I${SRCDIR}/llama.cpp/gguf/common -I${SRCDIR}/llama.cpp/gguf/examples/server
#cgo CFLAGS: -I${SRCDIR}/llama.cpp -I${SRCDIR}/llama.cpp/gguf -I${SRCDIR}/llama.cpp/gguf/common -I${SRCDIR}/llama.cpp/gguf/examples/server
#cgo CFLAGS: -DNDEBUG -DLLAMA_SERVER_LIBRARY=1 -D_XOPEN_SOURCE=600 -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
#cgo CFLAGS: -Wmissing-noreturn -Wall -Wextra -Wcast-qual -Wno-unused-function -Wno-array-bounds
#cgo CPPFLAGS: -Ofast -Wall -Wextra -Wno-unused-function -Wno-unused-variable -Wno-deprecated-declarations -Wno-unused-but-set-variable
@@ -10,23 +10,22 @@ package llm
#cgo darwin CPPFLAGS: -DGGML_USE_METAL -DGGML_METAL_NDEBUG
#cgo darwin LDFLAGS: -lc++ -framework Accelerate
#cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
#cgo darwin LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/metal/common/libcommon.a
#cgo darwin LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/metal/examples/server/libext_server.a
#cgo darwin LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/metal/libllama.a
#cgo darwin LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/metal/libggml_static.a
#cgo darwin LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/darwin/metal/lib/libcommon.a
#cgo darwin LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/darwin/metal/lib/libext_server.a
#cgo darwin LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/darwin/metal/lib/libllama.a
#cgo darwin LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/darwin/metal/lib/libggml_static.a
#cgo linux CFLAGS: -D_GNU_SOURCE
#cgo linux windows CFLAGS: -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_MMV_Y=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_USE_CUBLAS
#cgo linux LDFLAGS: -L/usr/local/cuda/targets/x86_64-linux/lib -L/usr/local/cuda/lib64 -L/usr/local/cuda/targets/x86_64-linux/lib/stubs
#cgo linux LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/cpu/examples/server/libext_server.a
#cgo linux LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/cpu/common/libcommon.a
#cgo linux LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/cpu/libllama.a
#cgo linux LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/cpu/libggml_static.a
#cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm
#cgo windows LDFLAGS: -L${SRCDIR}/llama.cpp/gguf/build/wincpu/dist/lib
#cgo windows LDFLAGS: -lcpu_server -lpthread
#cgo linux LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/linux/cpu/lib/libext_server.a
#cgo linux LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/linux/cpu/lib/libcommon.a
#cgo linux LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/linux/cpu/lib/libllama.a
#cgo linux LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/linux/cpu/lib/libggml_static.a
#cgo linux LDFLAGS: -lrt -ldl -lstdc++ -lm
#cgo linux windows LDFLAGS: -lpthread

#include <stdlib.h>
#include "server.h"
#include "ext_server.h"

*/
import "C"
@@ -46,6 +45,24 @@ import (
"github.com/jmorganca/ollama/gpu"
)

type extServer interface {
LLM
llama_server_init(sparams *C.ext_server_params_t, err *C.ext_server_resp_t)
llama_server_start()
llama_server_stop()
llama_server_completion(json_req *C.char, resp *C.ext_server_resp_t)
llama_server_completion_next_result(task_id C.int, resp *C.ext_server_task_result_t)
llama_server_completion_cancel(task_id C.int, err *C.ext_server_resp_t)
llama_server_release_task_result(result *C.ext_server_task_result_t)
llama_server_tokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t)
llama_server_detokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t)
llama_server_embedding(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t)
llama_server_release_json_resp(json_resp **C.char)
}

// Note: current implementation does not support concurrent instantiations
var mutex sync.Mutex

func newExtServerResp(len C.size_t) C.ext_server_resp_t {
var resp C.ext_server_resp_t
resp.msg_len = len
@@ -65,69 +82,6 @@ func extServerResponseToErr(resp C.ext_server_resp_t) error {
return fmt.Errorf(C.GoString(resp.msg))
}

type extServer interface {
LLM
llama_server_init(sparams *C.ext_server_params_t, err *C.ext_server_resp_t)
llama_server_start()
llama_server_stop()
llama_server_completion(json_req *C.char, resp *C.ext_server_resp_t)
llama_server_completion_next_result(task_id C.int, resp *C.ext_server_task_result_t)
llama_server_completion_cancel(task_id C.int, err *C.ext_server_resp_t)
llama_server_release_task_result(result *C.ext_server_task_result_t)
llama_server_tokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t)
llama_server_detokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t)
llama_server_embedding(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t)
llama_server_release_json_resp(json_resp **C.char)
}

type llamaExtServer struct {
api.Options
}

// Note: current implementation does not support concurrent instantiations
var mutex sync.Mutex

func (llm *llamaExtServer) llama_server_init(sparams *C.ext_server_params_t, err *C.ext_server_resp_t) {
C.llama_server_init(sparams, err)
}
func (llm *llamaExtServer) llama_server_start() {
C.llama_server_start()
}
func (llm *llamaExtServer) llama_server_stop() {
C.llama_server_stop()
}

func (llm *llamaExtServer) llama_server_completion(json_req *C.char, resp *C.ext_server_resp_t) {
C.llama_server_completion(json_req, resp)
}
func (llm *llamaExtServer) llama_server_completion_next_result(task_id C.int, resp *C.ext_server_task_result_t) {
C.llama_server_completion_next_result(task_id, resp)
}
func (llm *llamaExtServer) llama_server_completion_cancel(task_id C.int, err *C.ext_server_resp_t) {
C.llama_server_completion_cancel(task_id, err)
}
func (llm *llamaExtServer) llama_server_release_task_result(result *C.ext_server_task_result_t) {
C.llama_server_release_task_result(result)
}

func (llm *llamaExtServer) llama_server_tokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
C.llama_server_tokenize(json_req, json_resp, err)
}
func (llm *llamaExtServer) llama_server_detokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
C.llama_server_detokenize(json_req, json_resp, err)
}
func (llm *llamaExtServer) llama_server_embedding(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
C.llama_server_embedding(json_req, json_resp, err)
}
func (llm *llamaExtServer) llama_server_release_json_resp(json_resp **C.char) {
C.llama_server_release_json_resp(json_resp)
}

func newDefaultExtServer(model string, adapters, projectors []string, numLayers int64, opts api.Options) (extServer, error) {
server := &llamaExtServer{opts}
return newExtServer(server, model, adapters, projectors, numLayers, opts)
}

func newExtServer(server extServer, model string, adapters, projectors []string, numLayers int64, opts api.Options) (extServer, error) {
if !mutex.TryLock() {
log.Printf("concurrent llm servers not yet supported, waiting for prior server to complete")
@@ -199,11 +153,7 @@ func newExtServer(server extServer, model string, adapters, projectors []string,
return server, nil
}

func (llm *llamaExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {
return predict(llm, llm.Options, ctx, pred, fn)
}

func predict(llm extServer, opts api.Options, ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
func predict(ctx context.Context, llm extServer, predict PredictOpts, fn func(PredictResult)) error {
resp := newExtServerResp(128)
defer freeExtServerResp(resp)
var imageData []ImageData
@@ -217,24 +167,25 @@ func predict(llm extServer, opts api.Options, ctx context.Context, predict Predi
request := map[string]any{
"prompt": predict.Prompt,
"stream": true,
"n_predict": opts.NumPredict,
"n_keep": opts.NumKeep,
"temperature": opts.Temperature,
"top_k": opts.TopK,
"top_p": opts.TopP,
"tfs_z": opts.TFSZ,
"typical_p": opts.TypicalP,
"repeat_last_n": opts.RepeatLastN,
"repeat_penalty": opts.RepeatPenalty,
"presence_penalty": opts.PresencePenalty,
"frequency_penalty": opts.FrequencyPenalty,
"mirostat": opts.Mirostat,
"mirostat_tau": opts.MirostatTau,
"mirostat_eta": opts.MirostatEta,
"penalize_nl": opts.PenalizeNewline,
"seed": opts.Seed,
"stop": opts.Stop,
"n_predict": predict.Options.NumPredict,
"n_keep": predict.Options.NumKeep,
"temperature": predict.Options.Temperature,
"top_k": predict.Options.TopK,
"top_p": predict.Options.TopP,
"tfs_z": predict.Options.TFSZ,
"typical_p": predict.Options.TypicalP,
"repeat_last_n": predict.Options.RepeatLastN,
"repeat_penalty": predict.Options.RepeatPenalty,
"presence_penalty": predict.Options.PresencePenalty,
"frequency_penalty": predict.Options.FrequencyPenalty,
"mirostat": predict.Options.Mirostat,
"mirostat_tau": predict.Options.MirostatTau,
"mirostat_eta": predict.Options.MirostatEta,
"penalize_nl": predict.Options.PenalizeNewline,
"seed": predict.Options.Seed,
"stop": predict.Options.Stop,
"image_data": imageData,
"cache_prompt": true,
}

if predict.Format == "json" {
@@ -325,9 +276,6 @@ func predict(llm extServer, opts api.Options, ctx context.Context, predict Predi
// should never reach here ideally
return fmt.Errorf("max retries exceeded")
}
func (llm *llamaExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
return encode(llm, ctx, prompt)
}

func encode(llm extServer, ctx context.Context, prompt string) ([]int, error) {
data, err := json.Marshal(TokenizeRequest{Content: prompt})
@@ -353,10 +301,6 @@ func encode(llm extServer, ctx context.Context, prompt string) ([]int, error) {
return encoded.Tokens, err
}

func (llm *llamaExtServer) Decode(ctx context.Context, tokens []int) (string, error) {
return decode(llm, ctx, tokens)
}

func decode(llm extServer, ctx context.Context, tokens []int) (string, error) {
if len(tokens) == 0 {
return "", nil
@@ -385,9 +329,6 @@ func decode(llm extServer, ctx context.Context, tokens []int) (string, error) {
return decoded.Content, err
}

func (llm *llamaExtServer) Embedding(ctx context.Context, input string) ([]float64, error) {
return embedding(llm, ctx, input)
}
func embedding(llm extServer, ctx context.Context, input string) ([]float64, error) {
data, err := json.Marshal(TokenizeRequest{Content: input})
if err != nil {
@@ -413,10 +354,6 @@ func embedding(llm extServer, ctx context.Context, input string) ([]float64, err
return embedding.Embedding, nil
}

func (llm *llamaExtServer) Close() {
close(llm)
}

func close(llm extServer) {
llm.llama_server_stop()
mutex.Unlock()
80  llm/ext_server_default.go  (new file)

@@ -0,0 +1,80 @@
//go:build !windows

package llm

/*
#include <stdlib.h>
#include "ext_server.h"

*/
import "C"
import (
"context"

"github.com/jmorganca/ollama/api"
)

type llamaExtServer struct {
api.Options
}

func (llm *llamaExtServer) llama_server_init(sparams *C.ext_server_params_t, err *C.ext_server_resp_t) {
C.llama_server_init(sparams, err)
}
func (llm *llamaExtServer) llama_server_start() {
C.llama_server_start()
}
func (llm *llamaExtServer) llama_server_stop() {
C.llama_server_stop()
}

func (llm *llamaExtServer) llama_server_completion(json_req *C.char, resp *C.ext_server_resp_t) {
C.llama_server_completion(json_req, resp)
}
func (llm *llamaExtServer) llama_server_completion_next_result(task_id C.int, resp *C.ext_server_task_result_t) {
C.llama_server_completion_next_result(task_id, resp)
}
func (llm *llamaExtServer) llama_server_completion_cancel(task_id C.int, err *C.ext_server_resp_t) {
C.llama_server_completion_cancel(task_id, err)
}
func (llm *llamaExtServer) llama_server_release_task_result(result *C.ext_server_task_result_t) {
C.llama_server_release_task_result(result)
}

func (llm *llamaExtServer) llama_server_tokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
C.llama_server_tokenize(json_req, json_resp, err)
}
func (llm *llamaExtServer) llama_server_detokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
C.llama_server_detokenize(json_req, json_resp, err)
}
func (llm *llamaExtServer) llama_server_embedding(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
C.llama_server_embedding(json_req, json_resp, err)
}
func (llm *llamaExtServer) llama_server_release_json_resp(json_resp **C.char) {
C.llama_server_release_json_resp(json_resp)
}

func newDefaultExtServer(model string, adapters, projectors []string, numLayers int64, opts api.Options) (extServer, error) {
server := &llamaExtServer{opts}
return newExtServer(server, model, adapters, projectors, numLayers, opts)
}

func (llm *llamaExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {
return predict(ctx, llm, pred, fn)
}

func (llm *llamaExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
return encode(llm, ctx, prompt)
}

func (llm *llamaExtServer) Decode(ctx context.Context, tokens []int) (string, error) {
return decode(llm, ctx, tokens)
}

func (llm *llamaExtServer) Embedding(ctx context.Context, input string) ([]float64, error) {
return embedding(llm, ctx, input)
}

func (llm *llamaExtServer) Close() {
close(llm)
}
12  llm/ext_server_windows.go  (new file)

@@ -0,0 +1,12 @@
package llm

import (
"github.com/jmorganca/ollama/api"
)

func newDefaultExtServer(model string, adapters, projectors []string, numLayers int64, opts api.Options) (extServer, error) {
// On windows we always load the llama.cpp libraries dynamically to avoid startup DLL dependencies
// This ensures we can update the PATH at runtime to get everything loaded

return newDynamicShimExtServer(AvailableShims["cpu"], model, adapters, projectors, numLayers, opts)
}
29  llm/llama.cpp/CMakeLists.txt  (new file)

@@ -0,0 +1,29 @@
# Ollama specific CMakefile to include in llama.cpp/examples/server

set(TARGET ext_server)
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
add_library(${TARGET} STATIC ../../../ext_server.cpp)
target_include_directories(${TARGET} PRIVATE ../../common)
target_include_directories(${TARGET} PRIVATE ../..)
target_include_directories(${TARGET} PRIVATE ../../..)
target_compile_features(${TARGET} PRIVATE cxx_std_11)
target_compile_definitions(${TARGET} PUBLIC LLAMA_SERVER_LIBRARY=1)
target_link_libraries(${TARGET} PRIVATE common llama llava ${CMAKE_THREAD_LIBS_INIT})
target_compile_definitions(${TARGET} PRIVATE
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
)

if (BUILD_SHARED_LIBS)
set_target_properties(ext_server PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(ext_server PRIVATE LLAMA_SHARED LLAMA_BUILD)
add_library(ext_server_shared SHARED $<TARGET_OBJECTS:ext_server>)
target_link_libraries(ext_server_shared PRIVATE ggml llama llava common ${CMAKE_THREAD_LIBS_INIT})
install(TARGETS ext_server_shared LIBRARY)
endif()

if (CUDAToolkit_FOUND)
target_include_directories(${TARGET} PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
if (WIN32)
target_link_libraries(ext_server_shared PRIVATE nvml)
endif()
endif()
281  llm/llama.cpp/ext_server.cpp  (new file)

@@ -0,0 +1,281 @@
#include "ext_server.h"

// Necessary evil since the server types are not defined in a header
#include "server.cpp"

// Expose the llama server as a callable extern "C" API
llama_server_context *llama = NULL;
std::atomic<bool> ext_server_running(false);
std::thread ext_server_thread;

void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) {
#if SERVER_VERBOSE != 1
log_disable();
#endif
assert(err != NULL && sparams != NULL);
err->id = 0;
err->msg[0] = '\0';
try {
llama = new llama_server_context;
log_set_target(stdout);
gpt_params params;
params.n_ctx = sparams->n_ctx;
params.n_batch = sparams->n_batch;
if (sparams->n_threads > 0) {
params.n_threads = sparams->n_threads;
}
params.n_parallel = sparams->n_parallel;
params.rope_freq_base = sparams->rope_freq_base;
params.rope_freq_scale = sparams->rope_freq_scale;

if (sparams->memory_f16) {
params.cache_type_k = "f16";
params.cache_type_v = "f16";
} else {
params.cache_type_k = "f32";
params.cache_type_v = "f32";
}

params.n_gpu_layers = sparams->n_gpu_layers;
params.main_gpu = sparams->main_gpu;
params.use_mlock = sparams->use_mlock;
params.use_mmap = sparams->use_mmap;
params.numa = sparams->numa;
params.embedding = sparams->embedding;
if (sparams->model != NULL) {
params.model = sparams->model;
}

for (ext_server_lora_adapter *la = sparams->lora_adapters; la != NULL;
la = la->next) {
params.lora_adapter.push_back(std::make_tuple(la->adapter, la->scale));
}

if (sparams->mmproj != NULL) {
params.mmproj = std::string(sparams->mmproj);
}

llama_backend_init(params.numa);

// load the model
if (!llama->load_model(params)) {
// TODO - consider modifying the logging logic or patching load_model so
// we can capture more detailed error messages and pass them back to the
// caller for better UX
err->id = -1;
snprintf(err->msg, err->msg_len, "error loading model %s",
params.model.c_str());
return;
}

llama->initialize();
} catch (std::exception &e) {
err->id = -1;
snprintf(err->msg, err->msg_len, "exception %s", e.what());
} catch (...) {
err->id = -1;
snprintf(err->msg, err->msg_len,
"Unknown exception initializing llama server");
}
}

void llama_server_start() {
assert(llama != NULL);
// TODO mutex to protect thread creation
ext_server_thread = std::thread([&]() {
ext_server_running = true;
try {
LOG_TEE("llama server main loop starting\n");
ggml_time_init();
while (ext_server_running.load()) {
if (!llama->update_slots()) {
LOG_TEE(
"unexpected error in llama server update_slots - exiting main "
"loop\n");
break;
}
}
} catch (std::exception &e) {
LOG_TEE("caught exception in llama server main loop: %s\n", e.what());
} catch (...) {
LOG_TEE("caught unknown exception in llama server main loop\n");
}
LOG_TEE("\nllama server shutting down\n");
llama_backend_free();
});
}

void llama_server_stop() {
assert(llama != NULL);
// TODO - too verbose, remove once things are solid
LOG_TEE("requesting llama server shutdown\n");
ext_server_running = false;
ext_server_thread.join();
delete llama;
llama = NULL;
LOG_TEE("llama server shutdown complete\n");
}

void llama_server_completion(const char *json_req, ext_server_resp_t *resp) {
assert(llama != NULL && json_req != NULL && resp != NULL);
resp->id = -1;
resp->msg[0] = '\0';
try {
json data = json::parse(json_req);
resp->id = llama->request_completion(data, false, false, -1);
} catch (std::exception &e) {
snprintf(resp->msg, resp->msg_len, "exception %s", e.what());
} catch (...) {
snprintf(resp->msg, resp->msg_len, "Unknown exception during completion");
}
}

void llama_server_completion_next_result(const int task_id,
ext_server_task_result_t *resp) {
assert(llama != NULL && resp != NULL);
std::string msg;
resp->id = -1;
resp->stop = false;
resp->error = false;
resp->json_resp = NULL;
std::string result_json;
try {
task_result result = llama->next_result(task_id);
result_json =
result.result_json.dump(-1, ' ', false, json::error_handler_t::replace);
resp->id = result.id;
resp->stop = result.stop;
resp->error = result.error;
if (result.error) {
llama->request_cancel(task_id);
} else if (result.stop) {
llama->request_cancel(task_id);
}
} catch (std::exception &e) {
resp->error = true;
resp->id = -1;
result_json = "{\"error\":\"exception " + std::string(e.what()) + "\"}";
LOG_TEE("llama server completion exception %s\n", e.what());
} catch (...) {
resp->error = true;
resp->id = -1;
result_json = "{\"error\":\"Unknown exception during completion\"}";
LOG_TEE("llama server completion unknown exception\n");
}
const std::string::size_type size = result_json.size() + 1;
resp->json_resp = new char[size];
snprintf(resp->json_resp, size, "%s", result_json.c_str());
}

void llama_server_release_task_result(ext_server_task_result_t *result) {
if (result == NULL || result->json_resp == NULL) {
return;
}
delete[] result->json_resp;
}

void llama_server_completion_cancel(const int task_id, ext_server_resp_t *err) {
assert(llama != NULL && err != NULL);
err->id = 0;
err->msg[0] = '\0';
try {
llama->request_cancel(task_id);
} catch (std::exception &e) {
err->id = -1;
snprintf(err->msg, err->msg_len, "exception %s", e.what());
} catch (...) {
err->id = -1;
snprintf(err->msg, err->msg_len,
"Unknown exception completion cancel in llama server");
}
}

void llama_server_tokenize(const char *json_req, char **json_resp,
ext_server_resp_t *err) {
assert(llama != NULL && json_req != NULL && json_resp != NULL && err != NULL);
*json_resp = NULL;
err->id = 0;
err->msg[0] = '\0';
try {
const json body = json::parse(json_req);
std::vector<llama_token> tokens;
if (body.count("content") != 0) {
tokens = llama->tokenize(body["content"], false);
}
const json data = format_tokenizer_response(tokens);
std::string result_json = data.dump();
const std::string::size_type size = result_json.size() + 1;
*json_resp = new char[size];
snprintf(*json_resp, size, "%s", result_json.c_str());
} catch (std::exception &e) {
err->id = -1;
snprintf(err->msg, err->msg_len, "exception %s", e.what());
} catch (...) {
err->id = -1;
snprintf(err->msg, err->msg_len, "Unknown exception during tokenize");
}
}

void llama_server_release_json_resp(char **json_resp) {
if (json_resp == NULL || *json_resp == NULL) {
return;
}
delete[] *json_resp;
}

void llama_server_detokenize(const char *json_req, char **json_resp,
ext_server_resp_t *err) {
assert(llama != NULL && json_req != NULL && json_resp != NULL && err != NULL);
*json_resp = NULL;
err->id = 0;
err->msg[0] = '\0';
try {
const json body = json::parse(json_req);
std::string content;
if (body.count("tokens") != 0) {
const std::vector<llama_token> tokens = body["tokens"];
content = tokens_to_str(llama->ctx, tokens.cbegin(), tokens.cend());
}
const json data = format_detokenized_response(content);
std::string result_json = data.dump();
const std::string::size_type size = result_json.size() + 1;
*json_resp = new char[size];
snprintf(*json_resp, size, "%s", result_json.c_str());
} catch (std::exception &e) {
err->id = -1;
snprintf(err->msg, err->msg_len, "exception %s", e.what());
} catch (...) {
err->id = -1;
snprintf(err->msg, err->msg_len, "Unknown exception during detokenize");
}
}

void llama_server_embedding(const char *json_req, char **json_resp,
ext_server_resp_t *err) {
assert(llama != NULL && json_req != NULL && json_resp != NULL && err != NULL);
*json_resp = NULL;
err->id = 0;
err->msg[0] = '\0';
try {
const json body = json::parse(json_req);
json prompt;
if (body.count("content") != 0) {
prompt = body["content"];
} else {
prompt = "";
}
const int task_id = llama->request_completion(
{{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
task_result result = llama->next_result(task_id);
std::string result_json = result.result_json.dump();
const std::string::size_type size = result_json.size() + 1;
*json_resp = new char[size];
snprintf(*json_resp, size, "%s", result_json.c_str());
} catch (std::exception &e) {
err->id = -1;
snprintf(err->msg, err->msg_len, "exception %s", e.what());
} catch (...) {
err->id = -1;
snprintf(err->msg, err->msg_len, "Unknown exception during embedding");
}
}
94
llm/llama.cpp/ext_server.h
Normal file
@@ -0,0 +1,94 @@
#if defined(LLAMA_SERVER_LIBRARY)
#ifndef LLAMA_SERVER_H
#define LLAMA_SERVER_H
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

int __main(int argc, char **argv);

// This exposes extern C entrypoints into the llama_server
// To enable the server compile with LLAMA_SERVER_LIBRARY

#ifdef __cplusplus
extern "C" {
#endif
typedef struct ext_server_resp {
  int id;         // < 0 on error
  size_t msg_len; // caller must allocate msg and set msg_len
  char *msg;
} ext_server_resp_t;

// Allocated and freed by caller
typedef struct ext_server_lora_adapter {
  char *adapter;
  float scale;
  struct ext_server_lora_adapter *next;
} ext_server_lora_adapter_t;

// Allocated and freed by caller
typedef struct ext_server_params {
  char *model;
  uint32_t n_ctx;        // token context window, 0 = from model
  uint32_t n_batch;      // prompt processing maximum batch size
  uint32_t n_threads;    // number of threads to use for generation
  int32_t n_parallel;    // number of parallel sequences to decode
  float rope_freq_base;  // RoPE base frequency, 0 = from model
  float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
  bool memory_f16;       // use f16 instead of f32 for memory kv
  int32_t n_gpu_layers;  // number of layers to store in VRAM (-1 - use default)
  int32_t main_gpu;      // the GPU that is used for scratch and small tensors
  bool use_mlock;        // force system to keep model in RAM
  bool use_mmap;         // use mmap if possible
  bool numa;             // attempt optimizations that help on some NUMA systems
  bool embedding;        // get only sentence embedding
  ext_server_lora_adapter_t *lora_adapters;
  char *mmproj;
} ext_server_params_t;

typedef struct ext_server_task_result {
  int id;
  bool stop;
  bool error;
  char *json_resp; // null terminated, memory managed by ext_server
} ext_server_task_result_t;

// Initialize the server once per process
// err->id = 0 for success and err->msg[0] = NULL
// err->id != 0 for failure, and err->msg contains error message
void llama_server_init(ext_server_params_t *sparams, ext_server_resp_t *err);

// Run the main loop, called once per init
void llama_server_start();
// Stop the main loop and free up resources allocated in init and start. Init
// must be called again to reuse
void llama_server_stop();

// json_req null terminated string, memory managed by caller
// resp->id >= 0 on success (task ID)
// resp->id < 0 on error, and resp->msg contains error message
void llama_server_completion(const char *json_req, ext_server_resp_t *resp);

// Caller must call llama_server_release_task_result to free resp->json_resp
void llama_server_completion_next_result(const int task_id,
                                         ext_server_task_result_t *result);
void llama_server_completion_cancel(const int task_id, ext_server_resp_t *err);
void llama_server_release_task_result(ext_server_task_result_t *result);

// Caller must call llama_server_release_json_resp to free json_resp if err.id < 0
void llama_server_tokenize(const char *json_req, char **json_resp,
                           ext_server_resp_t *err);
void llama_server_detokenize(const char *json_req, char **json_resp,
                             ext_server_resp_t *err);
void llama_server_embedding(const char *json_req, char **json_resp,
                            ext_server_resp_t *err);
void llama_server_release_json_resp(char **json_resp);

#ifdef __cplusplus
}
#endif

#endif
#endif // LLAMA_SERVER_LIBRARY
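
For orientation, here is a minimal cgo sketch of how a caller is expected to drive this interface: initialize once, start the loop, submit a completion request, drain results, then shut down. It is not code from this change; the include path, link flags, model path, and request JSON are placeholder assumptions, and error handling is trimmed to the essentials.

package main

/*
#cgo CFLAGS: -DLLAMA_SERVER_LIBRARY=1 -I.
#cgo LDFLAGS: -L. -lext_server -lllama -lcommon -lstdc++
#include <stdlib.h>
#include "ext_server.h"
*/
import "C"

import (
    "fmt"
    "unsafe"
)

// newResp allocates the caller-owned error buffer, as the header requires.
func newResp(size C.size_t) C.ext_server_resp_t {
    return C.ext_server_resp_t{msg_len: size, msg: (*C.char)(C.malloc(size))}
}

func main() {
    errResp := newResp(512)
    defer C.free(unsafe.Pointer(errResp.msg))

    model := C.CString("/path/to/model.gguf") // hypothetical model path
    defer C.free(unsafe.Pointer(model))
    sparams := C.ext_server_params_t{
        model:        model,
        n_ctx:        2048,
        n_batch:      512,
        n_threads:    4,
        n_parallel:   1,
        n_gpu_layers: -1,
        use_mmap:     true,
    }

    C.llama_server_init(&sparams, &errResp)
    if errResp.id != 0 {
        fmt.Println("init failed:", C.GoString(errResp.msg))
        return
    }
    C.llama_server_start()
    defer C.llama_server_stop()

    req := C.CString(`{"prompt": "Why is the sky blue?", "n_predict": 32}`) // placeholder request
    defer C.free(unsafe.Pointer(req))
    resp := newResp(512)
    defer C.free(unsafe.Pointer(resp.msg))
    C.llama_server_completion(req, &resp)
    if resp.id < 0 {
        fmt.Println("completion failed:", C.GoString(resp.msg))
        return
    }

    // Stream results until the server reports stop or error,
    // releasing each JSON payload as the header requires.
    for {
        var result C.ext_server_task_result_t
        C.llama_server_completion_next_result(resp.id, &result)
        fmt.Println(C.GoString(result.json_resp))
        done := result.stop || result.error
        C.llama_server_release_task_result(&result)
        if done {
            break
        }
    }
}

The two contract points the sketch exercises are that the caller allocates msg and sets msg_len on every ext_server_resp_t it passes in, and that each task result's json_resp must be freed through llama_server_release_task_result.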
|
||||
@@ -3,14 +3,13 @@
|
||||
init_vars() {
|
||||
LLAMACPP_DIR=gguf
|
||||
PATCHES="0001-Expose-callable-API-for-server.patch"
|
||||
CMAKE_DEFS="-DLLAMA_ACCELERATE=on"
|
||||
# TODO - LLAMA_K_QUANTS is stale and needs to be mapped to newer cmake settings
|
||||
CMAKE_DEFS=""
|
||||
CMAKE_TARGETS="--target ggml --target ggml_static --target llama --target build_info --target common --target ext_server --target llava_static"
|
||||
if echo "${CGO_CFLAGS}" | grep -- '-g' >/dev/null; then
|
||||
CMAKE_DEFS="-DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_VERBOSE_MAKEFILE=on -DLLAMA_GPROF=on -DLLAMA_SERVER_VERBOSE=on ${CMAKE_DEFS}"
|
||||
CMAKE_DEFS="-DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_VERBOSE_MAKEFILE=on -DLLAMA_GPROF=on -DLLAMA_SERVER_VERBOSE=on"
|
||||
else
|
||||
# TODO - add additional optimization flags...
|
||||
CMAKE_DEFS="-DCMAKE_BUILD_TYPE=Release -DLLAMA_SERVER_VERBOSE=off ${CMAKE_DEFS}"
|
||||
CMAKE_DEFS="-DCMAKE_BUILD_TYPE=Release -DLLAMA_SERVER_VERBOSE=off"
|
||||
fi
|
||||
}
|
||||
|
||||
@@ -25,18 +24,30 @@ git_module_setup() {
|
||||
}
|
||||
|
||||
apply_patches() {
|
||||
if [ -n "${OLLAMA_SKIP_PATCHING}" ]; then
|
||||
echo "Skipping submodule patching"
|
||||
return
|
||||
# Wire up our CMakefile
|
||||
if ! grep ollama gguf/examples/server/CMakeLists.txt; then
|
||||
echo 'include (../../../CMakeLists.txt) # ollama' >>gguf/examples/server/CMakeLists.txt
|
||||
fi
|
||||
# Workaround git apply not handling creation well for iteration
|
||||
rm -f gguf/examples/server/server.h
|
||||
for patch in ${PATCHES}; do
|
||||
git -C gguf apply ../patches/${patch}
|
||||
done
|
||||
# Avoid duplicate main symbols when we link into the cgo binary
|
||||
sed -e 's/int main(/int __main(/g' <./gguf/examples/server/server.cpp >./gguf/examples/server/server.cpp.tmp &&
|
||||
mv ./gguf/examples/server/server.cpp.tmp ./gguf/examples/server/server.cpp
|
||||
}
|
||||
|
||||
build() {
|
||||
cmake -S ${LLAMACPP_DIR} -B ${BUILD_DIR} ${CMAKE_DEFS}
|
||||
cmake --build ${BUILD_DIR} ${CMAKE_TARGETS} -j8
|
||||
}
|
||||
|
||||
install() {
|
||||
rm -rf ${BUILD_DIR}/lib
|
||||
mkdir -p ${BUILD_DIR}/lib
|
||||
cp ${BUILD_DIR}/examples/server/libext_server.a ${BUILD_DIR}/lib
|
||||
cp ${BUILD_DIR}/common/libcommon.a ${BUILD_DIR}/lib
|
||||
cp ${BUILD_DIR}/libllama.a ${BUILD_DIR}/lib
|
||||
cp ${BUILD_DIR}/libggml_static.a ${BUILD_DIR}/lib
|
||||
}
|
||||
|
||||
# Keep the local tree clean after we're done with the build
|
||||
cleanup() {
|
||||
(cd gguf/examples/server/ && git checkout CMakeLists.txt server.cpp)
|
||||
}
|
||||
|
||||
@@ -9,22 +9,24 @@ set -o pipefail
|
||||
echo "Starting darwin generate script"
|
||||
source $(dirname $0)/gen_common.sh
|
||||
init_vars
|
||||
CMAKE_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.0 -DLLAMA_METAL=on ${CMAKE_DEFS}"
|
||||
BUILD_DIR="gguf/build/metal"
|
||||
CMAKE_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.0 -DLLAMA_METAL=on -DLLAMA_ACCELERATE=on ${CMAKE_DEFS}"
|
||||
BUILD_DIR="gguf/build/darwin/metal"
|
||||
case "${GOARCH}" in
|
||||
"amd64")
|
||||
CMAKE_DEFS="-DCMAKE_SYSTEM_PROCESSOR=x86_64 -DCMAKE_OSX_ARCHITECTURES=x86_64 ${CMAKE_DEFS}"
|
||||
;;
|
||||
"arm64")
|
||||
CMAKE_DEFS="-DCMAKE_SYSTEM_PROCESSOR=arm64 -DCMAKE_OSX_ARCHITECTURES=arm64 ${CMAKE_DEFS}"
|
||||
;;
|
||||
*)
|
||||
echo "GOARCH must be set"
|
||||
echo "this script is meant to be run from within go generate"
|
||||
exit 1
|
||||
;;
|
||||
"amd64")
|
||||
CMAKE_DEFS="-DCMAKE_SYSTEM_PROCESSOR=x86_64 -DCMAKE_OSX_ARCHITECTURES=x86_64 -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
;;
|
||||
"arm64")
|
||||
CMAKE_DEFS="-DCMAKE_SYSTEM_PROCESSOR=arm64 -DCMAKE_OSX_ARCHITECTURES=arm64 ${CMAKE_DEFS}"
|
||||
;;
|
||||
*)
|
||||
echo "GOARCH must be set"
|
||||
echo "this script is meant to be run from within go generate"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
git_module_setup
|
||||
apply_patches
|
||||
build
|
||||
build
|
||||
install
|
||||
cleanup
|
||||
|
||||
@@ -16,39 +16,63 @@
|
||||
set -ex
|
||||
set -o pipefail
|
||||
|
||||
# See https://llvm.org/docs/AMDGPUUsage.html#processors for reference
|
||||
amdGPUs() {
|
||||
GPU_LIST=(
|
||||
"gfx803"
|
||||
"gfx900"
|
||||
"gfx906:xnack-"
|
||||
"gfx908:xnack-"
|
||||
"gfx90a:xnack+"
|
||||
"gfx90a:xnack-"
|
||||
"gfx1010"
|
||||
"gfx1012"
|
||||
"gfx1030"
|
||||
"gfx1100"
|
||||
"gfx1101"
|
||||
"gfx1102"
|
||||
)
|
||||
(
|
||||
IFS=$';'
|
||||
echo "'${GPU_LIST[*]}'"
|
||||
)
|
||||
}
|
||||
|
||||
echo "Starting linux generate script"
|
||||
if [ -z "${CUDACXX}" -a -x /usr/local/cuda/bin/nvcc ]; then
|
||||
export CUDACXX=/usr/local/cuda/bin/nvcc
|
||||
fi
|
||||
COMMON_CMAKE_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DLLAMA_ACCELERATE=on -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off"
|
||||
OLLAMA_DYN_LIB_DIR="gguf/build/lib"
|
||||
COMMON_CMAKE_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off"
|
||||
source $(dirname $0)/gen_common.sh
|
||||
init_vars
|
||||
git_module_setup
|
||||
apply_patches
|
||||
|
||||
mkdir -p ${OLLAMA_DYN_LIB_DIR}
|
||||
touch ${OLLAMA_DYN_LIB_DIR}/.generated
|
||||
|
||||
#
|
||||
# CPU first for the default library
|
||||
#
|
||||
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}"
|
||||
BUILD_DIR="gguf/build/cpu"
|
||||
BUILD_DIR="gguf/build/linux/cpu"
|
||||
|
||||
build
|
||||
install
|
||||
|
||||
# Placeholder to keep go embed happy until we start building dynamic CPU lib variants
|
||||
touch ${BUILD_DIR}/lib/dummy.so
|
||||
|
||||
if [ -d /usr/local/cuda/lib64/ ]; then
|
||||
echo "CUDA libraries detected - building dynamic CUDA library"
|
||||
init_vars
|
||||
CMAKE_DEFS="-DLLAMA_CUBLAS=on ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}"
|
||||
BUILD_DIR="gguf/build/cuda"
|
||||
BUILD_DIR="gguf/build/linux/cuda"
|
||||
CUDA_LIB_DIR=/usr/local/cuda/lib64
|
||||
build
|
||||
gcc -fPIC -g -shared -o ${OLLAMA_DYN_LIB_DIR}/libcuda_server.so \
|
||||
install
|
||||
gcc -fPIC -g -shared -o ${BUILD_DIR}/lib/libext_server.so \
|
||||
-Wl,--whole-archive \
|
||||
${BUILD_DIR}/examples/server/libext_server.a \
|
||||
${BUILD_DIR}/common/libcommon.a \
|
||||
${BUILD_DIR}/libllama.a \
|
||||
${BUILD_DIR}/lib/libext_server.a \
|
||||
${BUILD_DIR}/lib/libcommon.a \
|
||||
${BUILD_DIR}/lib/libllama.a \
|
||||
-Wl,--no-whole-archive \
|
||||
${CUDA_LIB_DIR}/libcudart_static.a \
|
||||
${CUDA_LIB_DIR}/libcublas_static.a \
|
||||
@@ -73,17 +97,20 @@ fi
|
||||
if [ -d "${ROCM_PATH}" ]; then
|
||||
echo "ROCm libraries detected - building dynamic ROCm library"
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS='gfx803;gfx900;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102' -DGPU_TARGETS='gfx803;gfx900;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102'"
|
||||
BUILD_DIR="gguf/build/rocm"
|
||||
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
|
||||
BUILD_DIR="gguf/build/linux/rocm"
|
||||
build
|
||||
gcc -fPIC -g -shared -o ${OLLAMA_DYN_LIB_DIR}/librocm_server.so \
|
||||
install
|
||||
gcc -fPIC -g -shared -o ${BUILD_DIR}/lib/libext_server.so \
|
||||
-Wl,--whole-archive \
|
||||
${BUILD_DIR}/examples/server/libext_server.a \
|
||||
${BUILD_DIR}/common/libcommon.a \
|
||||
${BUILD_DIR}/libllama.a \
|
||||
${BUILD_DIR}/lib/libext_server.a \
|
||||
${BUILD_DIR}/lib/libcommon.a \
|
||||
${BUILD_DIR}/lib/libllama.a \
|
||||
-Wl,--no-whole-archive \
|
||||
-lrt -lpthread -ldl -lstdc++ -lm \
|
||||
-L/opt/rocm/lib -L/opt/amdgpu/lib/x86_64-linux-gnu/ \
|
||||
-Wl,-rpath,/opt/rocm/lib,-rpath,/opt/amdgpu/lib/x86_64-linux-gnu/ \
|
||||
-lhipblas -lrocblas -lamdhip64 -lrocsolver -lamd_comgr -lhsa-runtime64 -lrocsparse -ldrm -ldrm_amdgpu
|
||||
fi
|
||||
|
||||
cleanup
|
||||
|
||||
@@ -4,8 +4,8 @@ $ErrorActionPreference = "Stop"
|
||||
|
||||
function init_vars {
|
||||
$script:patches = @("0001-Expose-callable-API-for-server.patch")
|
||||
$script:cmakeDefs = @("-DBUILD_SHARED_LIBS=on", "-DLLAMA_NATIVE=off", "-DLLAMA_F16C=off", "-DLLAMA_FMA=off", "-DLLAMA_AVX512=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX=on", "-DLLAMA_K_QUANTS=on", "-DLLAMA_ACCELERATE=on", "-A","x64")
|
||||
|
||||
$script:cmakeDefs = @("-DBUILD_SHARED_LIBS=on", "-DLLAMA_NATIVE=off", "-DLLAMA_F16C=off", "-DLLAMA_FMA=off", "-DLLAMA_AVX512=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX=on", "-A","x64")
|
||||
$script:cmakeTargets = @("ggml", "ggml_static", "llama", "build_info", "common", "ext_server_shared", "llava_static")
|
||||
if ($env:CGO_CFLAGS -contains "-g") {
|
||||
$script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on")
|
||||
$script:config = "RelWithDebInfo"
|
||||
@@ -24,12 +24,14 @@ function git_module_setup {
|
||||
}
|
||||
|
||||
function apply_patches {
|
||||
rm -erroraction ignore -path "gguf/examples/server/server.h"
|
||||
foreach ($patch in $script:patches) {
|
||||
write-host "Applying patch $patch"
|
||||
& git -C gguf apply ../patches/$patch
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
# Wire up our CMakefile
|
||||
if (!(Select-String -Path "gguf/examples/server/CMakeLists.txt" -Pattern 'ollama')) {
|
||||
Add-Content -Path "gguf/examples/server/CMakeLists.txt" -Value 'include (../../../CMakeLists.txt) # ollama'
|
||||
}
|
||||
# Avoid duplicate main symbols when we link into the cgo binary
|
||||
$content = Get-Content -Path "./gguf/examples/server/server.cpp"
|
||||
$content = $content -replace 'int main\(', 'int __main('
|
||||
Set-Content -Path "./gguf/examples/server/server.cpp" -Value $content
|
||||
}
|
||||
|
||||
function build {
|
||||
@@ -37,16 +39,24 @@ function build {
|
||||
& cmake --version
|
||||
& cmake -S gguf -B $script:buildDir $script:cmakeDefs
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
write-host "building with: cmake --build $script:buildDir --config $script:config"
|
||||
& cmake --build $script:buildDir --config $script:config
|
||||
write-host "building with: cmake --build $script:buildDir --config $script:config ($script:cmakeTargets | ForEach-Object { "--target", $_ })"
|
||||
& cmake --build $script:buildDir --config $script:config ($script:cmakeTargets | ForEach-Object { "--target", $_ })
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
}
|
||||
|
||||
function install {
|
||||
rm -erroraction ignore -recurse -force -path $script:installDir
|
||||
& cmake --install $script:buildDir --prefix $script:installDir --config $script:config
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
rm -ea 0 -recurse -force -path "${script:buildDir}/lib"
|
||||
md "${script:buildDir}/lib" -ea 0 > $null
|
||||
cp "${script:buildDir}/bin/${script:config}/ext_server_shared.dll" "${script:buildDir}/lib"
|
||||
cp "${script:buildDir}/bin/${script:config}/llama.dll" "${script:buildDir}/lib"
|
||||
|
||||
# Display the dll dependencies in the build log
|
||||
dumpbin /dependents "${script:buildDir}/bin/${script:config}/ext_server_shared.dll" | select-string ".dll"
|
||||
}
|
||||
|
||||
function cleanup {
|
||||
Set-Location "gguf/examples/server"
|
||||
git checkout CMakeLists.txt server.cpp
|
||||
}
|
||||
|
||||
init_vars
|
||||
@@ -54,40 +64,24 @@ git_module_setup
|
||||
apply_patches
|
||||
|
||||
# first build CPU based
|
||||
$script:buildDir="gguf/build/wincpu"
|
||||
$script:installDir="gguf/build/wincpu/dist"
|
||||
$script:buildDir="gguf/build/windows/cpu"
|
||||
|
||||
build
|
||||
# install
|
||||
|
||||
md gguf/build/lib -ea 0
|
||||
md gguf/build/wincpu/dist/lib -ea 0
|
||||
mv gguf/build/wincpu/bin/$script:config/ext_server_shared.dll gguf/build/wincpu/dist/lib/cpu_server.dll
|
||||
|
||||
|
||||
# Nope, this barfs on lots of symbol problems
|
||||
#mv gguf/build/wincpu/examples/server/$script:config/ext_server_shared.dll gguf/build/wincpu/dist/lib/cpu_server.lib
|
||||
# Nope: this needs lots of include paths to pull in things like msvcprt.lib and other deps
|
||||
# & cl.exe `
|
||||
# gguf/build/wincpu/examples/server/$script:config/ext_server.lib `
|
||||
# gguf/build/wincpu/common/$script:config/common.lib `
|
||||
# gguf/build/wincpu/$script:config/llama.lib `
|
||||
# gguf/build/wincpu/$script:config/ggml_static.lib `
|
||||
# /link /DLL /DEF:cpu_server.def /NOENTRY /MACHINE:X64 /OUT:gguf/build/wincpu/dist/lib/cpu_server.dll
|
||||
# if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
install
|
||||
|
||||
# Then build cuda as a dynamically loaded library
|
||||
init_vars
|
||||
$script:buildDir="gguf/build/wincuda"
|
||||
$script:installDir="gguf/build/wincuda/dist"
|
||||
$script:cmakeDefs += @("-DLLAMA_CUBLAS=ON", "-DBUILD_SHARED_LIBS=on")
|
||||
$script:buildDir="gguf/build/windows/cuda"
|
||||
$script:cmakeDefs += @("-DLLAMA_CUBLAS=ON")
|
||||
build
|
||||
install
|
||||
cp gguf/build/wincuda/dist/bin/ext_server_shared.dll gguf/build/lib/cuda_server.dll
|
||||
|
||||
# TODO - more to do here to create a usable dll
|
||||
# TODO - actually implement ROCm support on windows
|
||||
$script:buildDir="gguf/build/windows/rocm"
|
||||
|
||||
rm -ea 0 -recurse -force -path "${script:buildDir}/lib"
|
||||
md "${script:buildDir}/lib" -ea 0 > $null
|
||||
echo $null >> "${script:buildDir}/lib/.generated"
|
||||
|
||||
# TODO - implement ROCm support on windows
|
||||
md gguf/build/winrocm/lib -ea 0
|
||||
echo $null >> gguf/build/winrocm/lib/.generated
|
||||
cleanup
|
||||
write-host "`ngo generate completed"
|
||||
@@ -1,464 +0,0 @@
|
||||
From 90c332fe2ef61149b38561d02836e66715df214d Mon Sep 17 00:00:00 2001
|
||||
From: Daniel Hiltgen <daniel@ollama.com>
|
||||
Date: Mon, 13 Nov 2023 12:25:58 -0800
|
||||
Subject: [PATCH] Expose callable API for server
|
||||
|
||||
This adds an extern "C" interface within the example server
|
||||
---
|
||||
examples/server/CMakeLists.txt | 27 ++++
|
||||
examples/server/server.cpp | 280 +++++++++++++++++++++++++++++++++
|
||||
examples/server/server.h | 89 +++++++++++
|
||||
ggml-cuda.cu | 1 +
|
||||
4 files changed, 397 insertions(+)
|
||||
create mode 100644 examples/server/server.h
|
||||
|
||||
diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt
|
||||
index 859cd12..da2b9bf 100644
|
||||
--- a/examples/server/CMakeLists.txt
|
||||
+++ b/examples/server/CMakeLists.txt
|
||||
@@ -11,3 +11,30 @@ if (WIN32)
|
||||
TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
|
||||
endif()
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
+
|
||||
+set(TARGET ext_server)
|
||||
+option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
|
||||
+add_library(${TARGET} STATIC server.cpp)
|
||||
+target_include_directories(${TARGET} PRIVATE ../../common)
|
||||
+target_include_directories(${TARGET} PRIVATE ../..)
|
||||
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
+target_compile_definitions(${TARGET} PUBLIC LLAMA_SERVER_LIBRARY=1)
|
||||
+target_link_libraries(${TARGET} PRIVATE common llama llava ${CMAKE_THREAD_LIBS_INIT})
|
||||
+target_compile_definitions(${TARGET} PRIVATE
|
||||
+ SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
|
||||
+)
|
||||
+
|
||||
+if (BUILD_SHARED_LIBS)
|
||||
+ set_target_properties(ext_server PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
+ target_compile_definitions(ext_server PRIVATE LLAMA_SHARED LLAMA_BUILD)
|
||||
+ add_library(ext_server_shared SHARED $<TARGET_OBJECTS:ext_server>)
|
||||
+ target_link_libraries(ext_server_shared PRIVATE ggml llama llava common ${CMAKE_THREAD_LIBS_INIT})
|
||||
+ install(TARGETS ext_server_shared LIBRARY)
|
||||
+endif()
|
||||
+
|
||||
+if (CUDAToolkit_FOUND)
|
||||
+ target_include_directories(${TARGET} PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
|
||||
+ if (WIN32)
|
||||
+ target_link_libraries(ext_server_shared PRIVATE nvml)
|
||||
+ endif()
|
||||
+endif()
|
||||
\ No newline at end of file
|
||||
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
|
||||
index 0403853..07fb05c 100644
|
||||
--- a/examples/server/server.cpp
|
||||
+++ b/examples/server/server.cpp
|
||||
@@ -5,6 +5,9 @@
|
||||
#include "../llava/clip.h"
|
||||
|
||||
#include "stb_image.h"
|
||||
+#if defined(LLAMA_SERVER_LIBRARY)
|
||||
+#include "server.h"
|
||||
+#endif
|
||||
|
||||
#ifndef NDEBUG
|
||||
// crash the server in debug mode, otherwise send an http 500 error
|
||||
@@ -2643,6 +2646,7 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con
|
||||
}
|
||||
}
|
||||
|
||||
+#ifndef LLAMA_SERVER_LIBRARY
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
#if SERVER_VERBOSE != 1
|
||||
@@ -3123,3 +3127,279 @@ int main(int argc, char **argv)
|
||||
llama_backend_free();
|
||||
return 0;
|
||||
}
|
||||
+
|
||||
+#else // LLAMA_SERVER_LIBRARY
|
||||
+// Expose the llama server as a callable extern "C" API
|
||||
+llama_server_context *llama = NULL;
|
||||
+std::atomic<bool> ext_server_running(false);
|
||||
+std::thread ext_server_thread;
|
||||
+
|
||||
+void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err)
|
||||
+{
|
||||
+#if SERVER_VERBOSE != 1
|
||||
+ LOG_TEE("disabling verbose llm logging\n");
|
||||
+ log_disable();
|
||||
+#endif
|
||||
+ assert(err != NULL && sparams != NULL);
|
||||
+ err->id = 0;
|
||||
+ err->msg[0] = '\0';
|
||||
+ try {
|
||||
+ llama = new llama_server_context;
|
||||
+ log_set_target(stdout);
|
||||
+ gpt_params params;
|
||||
+ params.n_ctx = sparams->n_ctx;
|
||||
+ params.n_batch = sparams->n_batch;
|
||||
+ if (sparams->n_threads > 0) {
|
||||
+ params.n_threads = sparams->n_threads;
|
||||
+ }
|
||||
+ params.n_parallel = sparams->n_parallel;
|
||||
+ params.rope_freq_base = sparams->rope_freq_base;
|
||||
+ params.rope_freq_scale = sparams->rope_freq_scale;
|
||||
+
|
||||
+ if (sparams->memory_f16) {
|
||||
+ params.cache_type_k = "f16";
|
||||
+ params.cache_type_v = "f16";
|
||||
+ } else {
|
||||
+ params.cache_type_k = "f32";
|
||||
+ params.cache_type_v = "f32";
|
||||
+ }
|
||||
+
|
||||
+ params.n_gpu_layers = sparams->n_gpu_layers;
|
||||
+ params.main_gpu = sparams->main_gpu;
|
||||
+ params.use_mlock = sparams->use_mlock;
|
||||
+ params.use_mmap = sparams->use_mmap;
|
||||
+ params.numa = sparams->numa;
|
||||
+ params.embedding = sparams->embedding;
|
||||
+ if (sparams->model != NULL) {
|
||||
+ params.model = sparams->model;
|
||||
+ }
|
||||
+
|
||||
+ for (ext_server_lora_adapter *la = sparams->lora_adapters; la != NULL; la = la->next) {
|
||||
+ params.lora_adapter.push_back(std::make_tuple(la->adapter, la->scale));
|
||||
+ }
|
||||
+
|
||||
+ if (sparams->mmproj != NULL) {
|
||||
+ params.mmproj = std::string(sparams->mmproj);
|
||||
+ }
|
||||
+
|
||||
+ llama_backend_init(params.numa);
|
||||
+
|
||||
+ // load the model
|
||||
+ if (!llama->load_model(params))
|
||||
+ {
|
||||
+ // TODO - consider modifying the logging logic or patching load_model so we can capture more detailed error messages
|
||||
+ // and pass them back to the caller for better UX
|
||||
+ err->id = -1;
|
||||
+ snprintf(err->msg, err->msg_len, "error loading model %s", params.model.c_str());
|
||||
+ return;
|
||||
+ }
|
||||
+
|
||||
+ llama->initialize();
|
||||
+ } catch (std::exception &e) {
|
||||
+ err->id = -1;
|
||||
+ snprintf(err->msg, err->msg_len, "exception %s", e.what());
|
||||
+ } catch (...) {
|
||||
+ err->id = -1;
|
||||
+ snprintf(err->msg, err->msg_len, "Unknown exception initializing llama server");
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+void llama_server_start()
|
||||
+{
|
||||
+ assert(llama != NULL);
|
||||
+ // TODO mutex to protect thread creation
|
||||
+ ext_server_thread = std::thread([&]()
|
||||
+ {
|
||||
+ ext_server_running = true;
|
||||
+ try {
|
||||
+ LOG_TEE("llama server main loop starting\n");
|
||||
+ ggml_time_init();
|
||||
+ while (ext_server_running.load())
|
||||
+ {
|
||||
+ if (!llama->update_slots()) {
|
||||
+ LOG_TEE("unexpected error in llama server update_slots - exiting main loop\n");
|
||||
+ break;
|
||||
+ }
|
||||
+ }
|
||||
+ } catch (std::exception &e) {
|
||||
+ LOG_TEE("caught exception in llama server main loop: %s\n", e.what());
|
||||
+ } catch (...) {
|
||||
+ LOG_TEE("caught unknown exception in llama server main loop\n");
|
||||
+ }
|
||||
+ LOG_TEE("\nllama server shutting down\n");
|
||||
+ llama_backend_free();
|
||||
+ });
|
||||
+}
|
||||
+
|
||||
+void llama_server_stop() {
|
||||
+ assert(llama != NULL);
|
||||
+ // TODO - too verbose, remove once things are solid
|
||||
+ LOG_TEE("requesting llama server shutdown\n");
|
||||
+ ext_server_running = false;
|
||||
+ ext_server_thread.join();
|
||||
+ delete llama;
|
||||
+ llama = NULL;
|
||||
+ LOG_TEE("llama server shutdown complete\n");
|
||||
+}
|
||||
+
|
||||
+void llama_server_completion(const char *json_req, ext_server_resp_t *resp) {
|
||||
+ assert(llama != NULL && json_req != NULL && resp != NULL);
|
||||
+ resp->id = -1;
|
||||
+ resp->msg[0] = '\0';
|
||||
+ try {
|
||||
+ json data = json::parse(json_req);
|
||||
+ resp->id = llama->request_completion(data, false, false, -1);
|
||||
+ } catch (std::exception &e) {
|
||||
+ snprintf(resp->msg, resp->msg_len, "exception %s", e.what());
|
||||
+ } catch (...) {
|
||||
+ snprintf(resp->msg, resp->msg_len, "Unknown exception during completion");
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+void llama_server_completion_next_result(const int task_id, ext_server_task_result_t *resp) {
|
||||
+ assert(llama != NULL && resp != NULL);
|
||||
+ std::string msg;
|
||||
+ resp->id = -1;
|
||||
+ resp->stop = false;
|
||||
+ resp->error = false;
|
||||
+ resp->json_resp = NULL;
|
||||
+ std::string result_json;
|
||||
+ try {
|
||||
+ task_result result = llama->next_result(task_id);
|
||||
+ result_json = result.result_json.dump(-1, ' ', false, json::error_handler_t::replace);
|
||||
+ resp->id = result.id;
|
||||
+ resp->stop = result.stop;
|
||||
+ resp->error = result.error;
|
||||
+ if (result.error) {
|
||||
+ llama->request_cancel(task_id);
|
||||
+ } else if (result.stop) {
|
||||
+ llama->request_cancel(task_id);
|
||||
+ }
|
||||
+ } catch (std::exception &e) {
|
||||
+ resp->error = true;
|
||||
+ resp->id = -1;
|
||||
+ result_json = "{\"error\":\"exception " + std::string(e.what()) + "\"}";
|
||||
+ } catch (...) {
|
||||
+ resp->error = true;
|
||||
+ resp->id = -1;
|
||||
+ result_json = "{\"error\":\"Unknown exception during completion\"}";
|
||||
+ }
|
||||
+ const std::string::size_type size = result_json.size() + 1;
|
||||
+ resp->json_resp = new char[size];
|
||||
+ snprintf(resp->json_resp, size, "%s", result_json.c_str());
|
||||
+}
|
||||
+
|
||||
+void llama_server_release_task_result(ext_server_task_result_t *result) {
|
||||
+ if (result == NULL || result->json_resp == NULL) {
|
||||
+ return;
|
||||
+ }
|
||||
+ delete[] result->json_resp;
|
||||
+}
|
||||
+
|
||||
+void llama_server_completion_cancel(const int task_id, ext_server_resp_t *err) {
|
||||
+ assert(llama != NULL && err != NULL);
|
||||
+ err->id = 0;
|
||||
+ err->msg[0] = '\0';
|
||||
+ try {
|
||||
+ llama->request_cancel(task_id);
|
||||
+ } catch (std::exception &e) {
|
||||
+ err->id = -1;
|
||||
+ snprintf(err->msg, err->msg_len, "exception %s", e.what());
|
||||
+ } catch (...) {
|
||||
+ err->id = -1;
|
||||
+ snprintf(err->msg, err->msg_len, "Unknown exception completion cancel in llama server");
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+void llama_server_tokenize(const char *json_req, char **json_resp, ext_server_resp_t *err) {
|
||||
+ assert(llama != NULL && json_req != NULL && json_resp != NULL && err != NULL);
|
||||
+ *json_resp = NULL;
|
||||
+ err->id = 0;
|
||||
+ err->msg[0] = '\0';
|
||||
+ try {
|
||||
+ const json body = json::parse(json_req);
|
||||
+ std::vector<llama_token> tokens;
|
||||
+ if (body.count("content") != 0)
|
||||
+ {
|
||||
+ tokens = llama->tokenize(body["content"], false);
|
||||
+ }
|
||||
+ const json data = format_tokenizer_response(tokens);
|
||||
+ std::string result_json = data.dump();
|
||||
+ const std::string::size_type size = result_json.size() + 1;
|
||||
+ *json_resp = new char[size];
|
||||
+ snprintf(*json_resp, size, "%s", result_json.c_str());
|
||||
+ } catch (std::exception &e) {
|
||||
+ err->id = -1;
|
||||
+ snprintf(err->msg, err->msg_len, "exception %s", e.what());
|
||||
+ } catch (...) {
|
||||
+ err->id = -1;
|
||||
+ snprintf(err->msg, err->msg_len, "Unknown exception during tokenize");
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+void llama_server_release_json_resp(char **json_resp) {
|
||||
+ if (json_resp == NULL || *json_resp == NULL) {
|
||||
+ return;
|
||||
+ }
|
||||
+ delete[] *json_resp;
|
||||
+}
|
||||
+
|
||||
+void llama_server_detokenize(const char *json_req, char **json_resp, ext_server_resp_t *err) {
|
||||
+ assert(llama != NULL && json_req != NULL && json_resp != NULL && err != NULL);
|
||||
+ *json_resp = NULL;
|
||||
+ err->id = 0;
|
||||
+ err->msg[0] = '\0';
|
||||
+ try {
|
||||
+ const json body = json::parse(json_req);
|
||||
+ std::string content;
|
||||
+ if (body.count("tokens") != 0)
|
||||
+ {
|
||||
+ const std::vector<llama_token> tokens = body["tokens"];
|
||||
+ content = tokens_to_str(llama->ctx, tokens.cbegin(), tokens.cend());
|
||||
+ }
|
||||
+ const json data = format_detokenized_response(content);
|
||||
+ std::string result_json = data.dump();
|
||||
+ const std::string::size_type size = result_json.size() + 1;
|
||||
+ *json_resp = new char[size];
|
||||
+ snprintf(*json_resp, size, "%s", result_json.c_str());
|
||||
+ } catch (std::exception &e) {
|
||||
+ err->id = -1;
|
||||
+ snprintf(err->msg, err->msg_len, "exception %s", e.what());
|
||||
+ } catch (...) {
|
||||
+ err->id = -1;
|
||||
+ snprintf(err->msg, err->msg_len, "Unknown exception during detokenize");
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+void llama_server_embedding(const char *json_req, char** json_resp, ext_server_resp_t *err) {
|
||||
+ assert(llama != NULL && json_req != NULL && json_resp != NULL && err != NULL);
|
||||
+ *json_resp = NULL;
|
||||
+ err->id = 0;
|
||||
+ err->msg[0] = '\0';
|
||||
+ try {
|
||||
+ const json body = json::parse(json_req);
|
||||
+ json prompt;
|
||||
+ if (body.count("content") != 0)
|
||||
+ {
|
||||
+ prompt = body["content"];
|
||||
+ }
|
||||
+ else
|
||||
+ {
|
||||
+ prompt = "";
|
||||
+ }
|
||||
+ const int task_id = llama->request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true, -1);
|
||||
+ task_result result = llama->next_result(task_id);
|
||||
+ std::string result_json = result.result_json.dump();
|
||||
+ const std::string::size_type size = result_json.size() + 1;
|
||||
+ *json_resp = new char[size];
|
||||
+ snprintf(*json_resp, size, "%s", result_json.c_str());
|
||||
+ } catch (std::exception &e) {
|
||||
+ err->id = -1;
|
||||
+ snprintf(err->msg, err->msg_len, "exception %s", e.what());
|
||||
+ } catch (...) {
|
||||
+ err->id = -1;
|
||||
+ snprintf(err->msg, err->msg_len, "Unknown exception during embedding");
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+#endif // LLAMA_SERVER_LIBRARY
|
||||
\ No newline at end of file
|
||||
diff --git a/examples/server/server.h b/examples/server/server.h
|
||||
new file mode 100644
|
||||
index 0000000..d22f1b6
|
||||
--- /dev/null
|
||||
+++ b/examples/server/server.h
|
||||
@@ -0,0 +1,89 @@
|
||||
+#if defined(LLAMA_SERVER_LIBRARY)
|
||||
+#ifndef LLAMA_SERVER_H
|
||||
+#define LLAMA_SERVER_H
|
||||
+#include <stddef.h>
|
||||
+#include <stdint.h>
|
||||
+#include <stdio.h>
|
||||
+#include <stdbool.h>
|
||||
+
|
||||
+// This exposes extern C entrypoints into the llama_server
|
||||
+// To enable the server compile with LLAMA_SERVER_LIBRARY
|
||||
+
|
||||
+#ifdef __cplusplus
|
||||
+extern "C"
|
||||
+{
|
||||
+#endif
|
||||
+ typedef struct ext_server_resp {
|
||||
+ int id; // < 0 on error
|
||||
+ size_t msg_len; // caller must allocate msg and set msg_len
|
||||
+ char *msg;
|
||||
+ } ext_server_resp_t;
|
||||
+
|
||||
+ // Allocated and freed by caller
|
||||
+ typedef struct ext_server_lora_adapter {
|
||||
+ char *adapter;
|
||||
+ float scale;
|
||||
+ struct ext_server_lora_adapter *next;
|
||||
+ } ext_server_lora_adapter_t;
|
||||
+
|
||||
+ // Allocated and freed by caller
|
||||
+ typedef struct ext_server_params
|
||||
+ {
|
||||
+ char *model;
|
||||
+ uint32_t n_ctx; // text context, 0 = from model
|
||||
+ uint32_t n_batch; // prompt processing maximum batch size
|
||||
+ uint32_t n_threads; // number of threads to use for generation
|
||||
+ int32_t n_parallel; // number of parallel sequences to decodewra
|
||||
+ float rope_freq_base; // RoPE base frequency, 0 = from model
|
||||
+ float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
|
||||
+ bool memory_f16; // use f16 instead of f32 for memory kv
|
||||
+ int32_t n_gpu_layers; // number of layers to store in VRAM (-1 - use default)
|
||||
+ int32_t main_gpu; // the GPU that is used for scratch and small tensors
|
||||
+ bool use_mlock; // force system to keep model in RAM
|
||||
+ bool use_mmap; // use mmap if possible
|
||||
+ bool numa; // attempt optimizations that help on some NUMA systems
|
||||
+ bool embedding; // get only sentence embedding
|
||||
+ ext_server_lora_adapter_t* lora_adapters;
|
||||
+ char *mmproj;
|
||||
+ } ext_server_params_t;
|
||||
+
|
||||
+ typedef struct ext_server_task_result
|
||||
+ {
|
||||
+ int id;
|
||||
+ bool stop;
|
||||
+ bool error;
|
||||
+ char* json_resp; // null terminated, memory managed by ext_server
|
||||
+ } ext_server_task_result_t;
|
||||
+
|
||||
+ // Initialize the server once per process
|
||||
+ // err->id = 0 for success and err->msg[0] = NULL
|
||||
+ // err->id != 0 for failure, and err->msg contains error message
|
||||
+ void llama_server_init(ext_server_params_t *sparams, ext_server_resp_t *err);
|
||||
+
|
||||
+ // Run the main loop, called once per init
|
||||
+ void llama_server_start();
|
||||
+ // Stop the main loop and free up resources allocated in init and start. Init must be called again to reuse
|
||||
+ void llama_server_stop();
|
||||
+
|
||||
+ // json_req null terminated string, memory managed by caller
|
||||
+ // resp->id >= 0 on success (task ID)
|
||||
+ // resp->id < 0 on error, and resp->msg contains error message
|
||||
+ void llama_server_completion(const char *json_req, ext_server_resp_t *resp);
|
||||
+
|
||||
+ // Caller must call llama_server_release_task_result to free resp->json_resp
|
||||
+ void llama_server_completion_next_result(const int task_id, ext_server_task_result_t *result);
|
||||
+ void llama_server_completion_cancel(const int task_id, ext_server_resp_t *err);
|
||||
+ void llama_server_release_task_result(ext_server_task_result_t *result);
|
||||
+
|
||||
+ // Caller must call llama_server_releaes_json_resp to free json_resp if err.id < 0
|
||||
+ void llama_server_tokenize(const char *json_req, char **json_resp, ext_server_resp_t *err);
|
||||
+ void llama_server_detokenize(const char *json_req, char **json_resp, ext_server_resp_t *err);
|
||||
+ void llama_server_embedding(const char *json_req, char** json_resp, ext_server_resp_t *err);
|
||||
+ void llama_server_release_json_resp(char **json_resp);
|
||||
+
|
||||
+#ifdef __cplusplus
|
||||
+}
|
||||
+#endif
|
||||
+
|
||||
+#endif
|
||||
+#endif // LLAMA_SERVER_LIBRARY
|
||||
\ No newline at end of file
|
||||
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
|
||||
index f20846f..9640cf3 100644
|
||||
--- a/ggml-cuda.cu
|
||||
+++ b/ggml-cuda.cu
|
||||
@@ -6757,6 +6757,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
|
||||
CUDA_CHECK(cudaGetDevice(&id));
|
||||
src_ptr = (char *) extra->data_device[id];
|
||||
} else {
|
||||
+ fprintf(stderr, "ggml_cuda_cpy_tensor_2d assert: backend: %d\n", src->backend);
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
char * dst_ptr = (char *) dst;
|
||||
--
|
||||
2.39.3 (Apple Git-145)
|
||||
|
||||
48
llm/llama.go
@@ -6,11 +6,8 @@ import (
|
||||
_ "embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -169,9 +166,10 @@ const maxRetries = 3
|
||||
const retryDelay = 1 * time.Second
|
||||
|
||||
type PredictOpts struct {
|
||||
Prompt string
|
||||
Format string
|
||||
Images []api.ImageData
|
||||
Prompt string
|
||||
Format string
|
||||
Images []api.ImageData
|
||||
Options api.Options
|
||||
}
|
||||
|
||||
type PredictResult struct {
|
||||
@@ -206,41 +204,3 @@ type EmbeddingRequest struct {
|
||||
type EmbeddingResponse struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
|
||||
func extractDynamicLibs(workDir, glob string) ([]string, error) {
|
||||
files, err := fs.Glob(libEmbed, glob)
|
||||
if err != nil || len(files) == 0 {
|
||||
return nil, payloadMissing
|
||||
}
|
||||
libs := make([]string, len(files))
|
||||
|
||||
for i, file := range files {
|
||||
srcFile, err := libEmbed.Open(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read payload %s: %v", file, err)
|
||||
}
|
||||
defer srcFile.Close()
|
||||
if err := os.MkdirAll(workDir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create payload temp dir %s: %v", workDir, err)
|
||||
}
|
||||
|
||||
destFile := filepath.Join(workDir, filepath.Base(file))
|
||||
libs[i] = destFile
|
||||
|
||||
_, err = os.Stat(destFile)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("write payload %s: %v", file, err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
if _, err := io.Copy(destFile, srcFile); err != nil {
|
||||
return nil, fmt.Errorf("copy payload %s: %v", file, err)
|
||||
}
|
||||
case err != nil:
|
||||
return nil, fmt.Errorf("stat payload %s: %v", file, err)
|
||||
}
|
||||
}
|
||||
return libs, nil
|
||||
}
|
||||
|
||||
19
llm/llm.go
@@ -41,16 +41,6 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
switch ggml.FileType() {
|
||||
case "F32", "Q5_0", "Q5_1", "Q8_0":
|
||||
if ggml.Name() != "gguf" && opts.NumGPU != 0 {
|
||||
// GGML Q8_0 do not support Metal API and will
|
||||
// cause the runner to segmentation fault so disable GPU
|
||||
log.Printf("WARNING: GPU disabled for F32, Q5_0, Q5_1, and Q8_0")
|
||||
opts.NumGPU = 0
|
||||
}
|
||||
}
|
||||
|
||||
var requiredMemory int64
|
||||
var f16Multiplier int64 = 2
|
||||
|
||||
@@ -61,6 +51,8 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
|
||||
requiredMemory = 16 * format.GigaByte
|
||||
case "30B", "34B", "40B":
|
||||
requiredMemory = 32 * format.GigaByte
|
||||
case "47B":
|
||||
requiredMemory = 48 * format.GigaByte
|
||||
case "65B", "70B":
|
||||
requiredMemory = 64 * format.GigaByte
|
||||
case "180B":
|
||||
@@ -71,9 +63,9 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
|
||||
systemMemory := int64(memory.TotalMemory())
|
||||
|
||||
if ggml.FileType() == "F16" && requiredMemory*f16Multiplier > systemMemory {
|
||||
return nil, fmt.Errorf("F16 model requires at least %s of total memory", format.HumanBytes(requiredMemory))
|
||||
return nil, fmt.Errorf("F16 model requires at least %s of memory", format.HumanBytes(requiredMemory))
|
||||
} else if requiredMemory > systemMemory {
|
||||
return nil, fmt.Errorf("model requires at least %s of total memory", format.HumanBytes(requiredMemory))
|
||||
return nil, fmt.Errorf("model requires at least %s of memory", format.HumanBytes(requiredMemory))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,7 +87,8 @@ func newLlmServer(library, model string, adapters, projectors []string, numLayer
|
||||
if err == nil {
|
||||
return srv, nil
|
||||
}
|
||||
log.Printf("Failed to load dynamic library - falling back to CPU mode %s", err)
|
||||
log.Printf("Failed to load dynamic library %s - falling back to CPU mode %s", library, err)
|
||||
// TODO - update some state to indicate we were unable to load the GPU library for future "info" ux
|
||||
}
|
||||
|
||||
return newDefaultExtServer(model, adapters, projectors, numLayers, opts)
|
||||
|
||||
@@ -2,9 +2,13 @@ package llm
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
)
|
||||
@@ -18,7 +22,7 @@ func newDynamicShimExtServer(library, model string, adapters, projectors []strin
|
||||
}
|
||||
|
||||
func nativeInit(workdir string) error {
|
||||
_, err := extractDynamicLibs(workdir, "llama.cpp/gguf/ggml-metal.metal")
|
||||
err := extractPayloadFiles(workdir, "llama.cpp/gguf/ggml-metal.metal")
|
||||
if err != nil {
|
||||
if err == payloadMissing {
|
||||
// TODO perhaps consider this a hard failure on arm macs?
|
||||
@@ -30,3 +34,38 @@ func nativeInit(workdir string) error {
|
||||
os.Setenv("GGML_METAL_PATH_RESOURCES", workdir)
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractPayloadFiles(workDir, glob string) error {
|
||||
files, err := fs.Glob(libEmbed, glob)
|
||||
if err != nil || len(files) == 0 {
|
||||
return payloadMissing
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
srcFile, err := libEmbed.Open(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read payload %s: %v", file, err)
|
||||
}
|
||||
defer srcFile.Close()
|
||||
if err := os.MkdirAll(workDir, 0o755); err != nil {
|
||||
return fmt.Errorf("create payload temp dir %s: %v", workDir, err)
|
||||
}
|
||||
|
||||
destFile := filepath.Join(workDir, filepath.Base(file))
|
||||
_, err = os.Stat(destFile)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write payload %s: %v", file, err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
if _, err := io.Copy(destFile, srcFile); err != nil {
|
||||
return fmt.Errorf("copy payload %s: %v", file, err)
|
||||
}
|
||||
case err != nil:
|
||||
return fmt.Errorf("stat payload %s: %v", file, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -11,9 +11,9 @@ package llm
|
||||
import "C"
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log"
|
||||
"os"
|
||||
@@ -25,11 +25,6 @@ import (
|
||||
"github.com/jmorganca/ollama/api"
|
||||
)
|
||||
|
||||
//go:embed llama.cpp/gguf/build/lib/*
|
||||
var libEmbed embed.FS
|
||||
|
||||
var RocmShimMissing = fmt.Errorf("ROCm shim library not included in this build of ollama. Radeon GPUs are not supported")
|
||||
|
||||
type shimExtServer struct {
|
||||
s C.struct_dynamic_llama_server
|
||||
options api.Options
|
||||
@@ -78,6 +73,7 @@ func (llm *shimExtServer) llama_server_release_json_resp(json_resp **C.char) {
|
||||
func newDynamicShimExtServer(library, model string, adapters, projectors []string, numLayers int64, opts api.Options) (extServer, error) {
|
||||
shimMutex.Lock()
|
||||
defer shimMutex.Unlock()
|
||||
updatePath(filepath.Dir(library))
|
||||
libPath := C.CString(library)
|
||||
defer C.free(unsafe.Pointer(libPath))
|
||||
resp := newExtServerResp(128)
|
||||
@@ -96,7 +92,7 @@ func newDynamicShimExtServer(library, model string, adapters, projectors []strin
|
||||
}
|
||||
|
||||
func (llm *shimExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {
|
||||
return predict(llm, llm.options, ctx, pred, fn)
|
||||
return predict(ctx, llm, pred, fn)
|
||||
}
|
||||
|
||||
func (llm *shimExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
|
||||
@@ -116,7 +112,7 @@ func (llm *shimExtServer) Close() {
|
||||
}
|
||||
|
||||
func nativeInit(workdir string) error {
|
||||
libs, err := extractDynamicLibs(workdir, "llama.cpp/gguf/build/lib/*server*")
|
||||
libs, err := extractDynamicLibs(workdir, "llama.cpp/gguf/build/*/*/lib/*")
|
||||
if err != nil {
|
||||
if err == payloadMissing {
|
||||
log.Printf("%s", payloadMissing)
|
||||
@@ -125,28 +121,71 @@ func nativeInit(workdir string) error {
|
||||
return err
|
||||
}
|
||||
for _, lib := range libs {
|
||||
libName := strings.Split(strings.TrimPrefix(filepath.Base(lib), "lib"), ".")[0]
|
||||
AvailableShims[libName] = lib
|
||||
// The last dir component is the variant name
|
||||
variant := filepath.Base(filepath.Dir(lib))
|
||||
AvailableShims[variant] = lib
|
||||
}
|
||||
|
||||
// Only check ROCm access if we have the dynamic lib loaded
|
||||
if _, rocmPresent := AvailableShims["rocm_server"]; rocmPresent {
|
||||
// Verify we have permissions - either running as root, or we have group access to the driver
|
||||
fd, err := os.OpenFile("/dev/kfd", os.O_RDWR, 0666)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrPermission) {
|
||||
log.Fatalf("Radeon card detected, but permissions not set up properly. Either run ollama as root, or add you user account to the render group.")
|
||||
return err
|
||||
} else if errors.Is(err, fs.ErrNotExist) {
|
||||
// expected behavior without a radeon card
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to check permission on /dev/kfd: %w", err)
|
||||
}
|
||||
fd.Close()
|
||||
|
||||
if err := verifyDriverAccess(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Report which dynamic libraries we have loaded to assist troubleshooting
|
||||
variants := make([]string, len(AvailableShims))
|
||||
i := 0
|
||||
for variant := range AvailableShims {
|
||||
variants[i] = variant
|
||||
i++
|
||||
}
|
||||
log.Printf("Dynamic LLM variants %v", variants)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractDynamicLibs(workDir, glob string) ([]string, error) {
|
||||
files, err := fs.Glob(libEmbed, glob)
|
||||
if err != nil || len(files) == 0 {
|
||||
return nil, payloadMissing
|
||||
}
|
||||
libs := []string{}
|
||||
|
||||
for _, file := range files {
|
||||
pathComps := strings.Split(file, "/")
|
||||
if len(pathComps) != 7 {
|
||||
log.Printf("unexpected payload components: %v", pathComps)
|
||||
continue
|
||||
}
|
||||
// llama.cpp/gguf/build/$OS/$VARIANT/lib/$LIBRARY
|
||||
// Include the variant in the path to avoid conflicts between multiple server libs
|
||||
targetDir := filepath.Join(workDir, pathComps[4])
|
||||
srcFile, err := libEmbed.Open(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read payload %s: %v", file, err)
|
||||
}
|
||||
defer srcFile.Close()
|
||||
if err := os.MkdirAll(targetDir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create payload temp dir %s: %v", workDir, err)
|
||||
}
|
||||
|
||||
destFile := filepath.Join(targetDir, filepath.Base(file))
|
||||
if strings.Contains(destFile, "server") {
|
||||
libs = append(libs, destFile)
|
||||
}
|
||||
|
||||
_, err = os.Stat(destFile)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("write payload %s: %v", file, err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
if _, err := io.Copy(destFile, srcFile); err != nil {
|
||||
return nil, fmt.Errorf("copy payload %s: %v", file, err)
|
||||
}
|
||||
case err != nil:
|
||||
return nil, fmt.Errorf("stat payload %s: %v", file, err)
|
||||
}
|
||||
}
|
||||
return libs, nil
|
||||
}
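
As a small standalone illustration of the path handling above, the following sketch uses a hypothetical embedded payload path in the llama.cpp/gguf/build/$OS/$VARIANT/lib/$LIBRARY layout and shows why nativeInit can later recover the variant name with filepath.Base(filepath.Dir(lib)):

package main

import (
    "fmt"
    "path/filepath"
    "strings"
)

func main() {
    // hypothetical embedded payload path following the layout above
    file := "llama.cpp/gguf/build/linux/cuda/lib/libext_server.so"

    pathComps := strings.Split(file, "/")
    fmt.Println(len(pathComps), pathComps[4]) // 7 cuda

    // The library is extracted into workDir/$VARIANT/, so the variant
    // is simply the last directory component of the destination path.
    workDir := "/tmp/ollama123" // hypothetical temp dir
    destFile := filepath.Join(workDir, pathComps[4], filepath.Base(file))
    fmt.Println(filepath.Base(filepath.Dir(destFile))) // cuda
}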
|
||||
|
||||
46
llm/shim_ext_server_linux.go
Normal file
@@ -0,0 +1,46 @@
package llm

import (
    "embed"
    "errors"
    "fmt"
    "io/fs"
    "log"
    "os"
    "strings"
)

//go:embed llama.cpp/gguf/build/*/*/lib/*.so
var libEmbed embed.FS

func updatePath(dir string) {
    pathComponents := strings.Split(os.Getenv("PATH"), ":")
    for _, comp := range pathComponents {
        if comp == dir {
            return
        }
    }
    newPath := strings.Join(append(pathComponents, dir), ":")
    log.Printf("Updating PATH to %s", newPath)
    os.Setenv("PATH", newPath)
}

func verifyDriverAccess() error {
    // Only check ROCm access if we have the dynamic lib loaded
    if _, rocmPresent := AvailableShims["rocm"]; rocmPresent {
        // Verify we have permissions - either running as root, or we have group access to the driver
        fd, err := os.OpenFile("/dev/kfd", os.O_RDWR, 0666)
        if err != nil {
            if errors.Is(err, fs.ErrPermission) {
                return fmt.Errorf("Radeon card detected, but permissions not set up properly. Either run ollama as root, or add your user account to the render group.")
            } else if errors.Is(err, fs.ErrNotExist) {
                // expected behavior without a radeon card
                return nil
            }

            return fmt.Errorf("failed to check permission on /dev/kfd: %w", err)
        }
        fd.Close()
    }
    return nil
}
36
llm/shim_ext_server_windows.go
Normal file
@@ -0,0 +1,36 @@
package llm

import (
    "embed"
    "log"
    "os"
    "path/filepath"
    "strings"
)

//go:embed llama.cpp/gguf/build/windows/*/lib/*.dll
var libEmbed embed.FS

func updatePath(dir string) {
    tmpDir := filepath.Dir(dir)
    pathComponents := strings.Split(os.Getenv("PATH"), ";")
    i := 0
    for _, comp := range pathComponents {
        if strings.EqualFold(comp, dir) {
            return
        }
        // Remove any other prior paths to our temp dir
        if !strings.HasPrefix(strings.ToLower(comp), strings.ToLower(tmpDir)) {
            pathComponents[i] = comp
            i++
        }
    }
    newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
    log.Printf("Updating PATH to %s", newPath)
    os.Setenv("PATH", newPath)
}

func verifyDriverAccess() error {
    // TODO if applicable
    return nil
}
|
||||
@@ -5,13 +5,11 @@ set -eu
|
||||
export VERSION=${VERSION:-0.0.0}
|
||||
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version=$VERSION\" \"-X=github.com/jmorganca/ollama/server.mode=release\"'"
|
||||
|
||||
docker buildx build \
|
||||
docker build \
|
||||
--load \
|
||||
--platform=linux/arm64,linux/amd64 \
|
||||
--build-arg=VERSION \
|
||||
--build-arg=GOFLAGS \
|
||||
--cache-from type=local,src=.cache \
|
||||
--cache-to type=local,dest=.cache \
|
||||
-f Dockerfile \
|
||||
-t ollama/ollama:$VERSION \
|
||||
.
|
||||
|
||||
@@ -8,7 +8,7 @@ export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version
|
||||
mkdir -p dist
|
||||
|
||||
for TARGETARCH in amd64 arm64; do
|
||||
docker buildx build --load --platform=linux/$TARGETARCH --build-arg=VERSION --build-arg=GOFLAGS --build-arg=CGO_CFLAGS -f Dockerfile.build -t builder:$TARGETARCH .
|
||||
docker build --platform=linux/$TARGETARCH --build-arg=VERSION --build-arg=GOFLAGS --build-arg=CGO_CFLAGS -f Dockerfile.build -t builder:$TARGETARCH .
|
||||
docker create --platform linux/$TARGETARCH --name builder-$TARGETARCH builder:$TARGETARCH
|
||||
docker cp builder-$TARGETARCH:/go/src/github.com/jmorganca/ollama/ollama ./dist/ollama-linux-$TARGETARCH
|
||||
docker rm builder-$TARGETARCH
|
||||
|
||||
@@ -33,6 +33,14 @@ case "$ARCH" in
|
||||
*) error "Unsupported architecture: $ARCH" ;;
|
||||
esac
|
||||
|
||||
KERN=$(uname -r)
|
||||
case "$KERN" in
|
||||
*icrosoft*WSL2 | *icrosoft*wsl2) ;;
|
||||
*icrosoft) error "Microsoft WSL1 is not currently supported. Please upgrade to WSL2 with 'wsl --set-version <distro> 2'" ;;
|
||||
*) ;;
|
||||
esac
|
||||
|
||||
|
||||
SUDO=
|
||||
if [ "$(id -u)" -ne 0 ]; then
|
||||
# Running as root, no need for sudo
|
||||
@@ -76,6 +84,10 @@ configure_systemd() {
|
||||
status "Creating ollama user..."
|
||||
$SUDO useradd -r -s /bin/false -m -d /usr/share/ollama ollama
|
||||
fi
|
||||
if getent group render >/dev/null 2>&1; then
|
||||
status "Adding ollama user to render group..."
|
||||
$SUDO usermod -a -G render ollama
|
||||
fi
|
||||
|
||||
status "Adding current user to ollama group..."
|
||||
$SUDO usermod -a -G ollama $(whoami)
|
||||
|
||||
@@ -5,12 +5,11 @@ set -eu
|
||||
export VERSION=${VERSION:-0.0.0}
|
||||
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version=$VERSION\" \"-X=github.com/jmorganca/ollama/server.mode=release\"'"
|
||||
|
||||
docker buildx build \
|
||||
docker build \
|
||||
--push \
|
||||
--platform=linux/arm64,linux/amd64 \
|
||||
--build-arg=VERSION \
|
||||
--build-arg=GOFLAGS \
|
||||
--cache-from type=local,src=.cache \
|
||||
-f Dockerfile \
|
||||
-t ollama/ollama -t ollama/ollama:$VERSION \
|
||||
.
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
|
||||
# This script sets up integration tests which run the full stack to verify
|
||||
# inference locally
|
||||
#
|
||||
# To run the relevant tests use
|
||||
# go test -tags=integration ./server
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
@@ -21,15 +24,15 @@ for model in ${TEST_MODELS[@]}; do
|
||||
echo "Pulling manifest for ${TEST_MODEL}:${TEST_MODEL_TAG}"
|
||||
curl -s --header "${ACCEPT_HEADER}" \
|
||||
-o ${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/${TEST_MODEL_TAG} \
|
||||
${REGISTRY_SCHEME}://${REGISTRY}/v2/${TEST_MODEL}/manifests/${TEST_MODEL_TAG}
|
||||
${REGISTRY_SCHEME}://${REGISTRY}/v2/${TEST_MODEL}/manifests/${TEST_MODEL_TAG}
|
||||
|
||||
CFG_HASH=$(cat ${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/${TEST_MODEL_TAG} | jq -r ".config.digest")
|
||||
echo "Pulling config blob ${CFG_HASH}"
|
||||
curl -L -C - --header "${ACCEPT_HEADER}" \
|
||||
-o ${OLLAMA_MODELS}/blobs/${CFG_HASH} \
|
||||
${REGISTRY_SCHEME}://${REGISTRY}/v2/${TEST_MODEL}/blobs/${CFG_HASH}
|
||||
-o ${OLLAMA_MODELS}/blobs/${CFG_HASH} \
|
||||
${REGISTRY_SCHEME}://${REGISTRY}/v2/${TEST_MODEL}/blobs/${CFG_HASH}
|
||||
|
||||
for LAYER in $(cat ${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/${TEST_MODEL_TAG} | jq -r ".layers[].digest" ) ; do
|
||||
for LAYER in $(cat ${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/${TEST_MODEL_TAG} | jq -r ".layers[].digest"); do
|
||||
echo "Pulling blob ${LAYER}"
|
||||
curl -L -C - --header "${ACCEPT_HEADER}" \
|
||||
-o ${OLLAMA_MODELS}/blobs/${LAYER} \
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
    "strconv"
    "strings"
    "text/template"
    "text/template/parse"

    "golang.org/x/exp/slices"

@@ -57,17 +58,35 @@ type PromptVars struct {
    First    bool
}

func (m *Model) Prompt(p PromptVars) (string, error) {
    var prompt strings.Builder
    // Use the "missingkey=zero" option to handle missing variables without panicking
    tmpl, err := template.New("").Option("missingkey=zero").Parse(m.Template)
// extractParts extracts the parts of the template before and after the {{.Response}} node.
func extractParts(tmplStr string) (pre string, post string, err error) {
    tmpl, err := template.New("").Parse(tmplStr)
    if err != nil {
        return "", err
        return "", "", err
    }

    if p.System == "" {
        // use the default system message for this model if one is not specified
        p.System = m.System
    var foundResponse bool

    for _, node := range tmpl.Tree.Root.Nodes {
        if node.Type() == parse.NodeAction && node.String() == "{{.Response}}" {
            foundResponse = true
        }
        if !foundResponse {
            pre += node.String()
        } else {
            post += node.String()
        }
    }

    return pre, post, nil
}

func Prompt(promptTemplate string, p PromptVars) (string, error) {
    var prompt strings.Builder
    // Use the "missingkey=zero" option to handle missing variables without panicking
    tmpl, err := template.New("").Option("missingkey=zero").Parse(promptTemplate)
    if err != nil {
        return "", err
    }

    vars := map[string]any{
@@ -82,20 +101,59 @@ func (m *Model) Prompt(p PromptVars) (string, error) {
        return "", err
    }
    prompt.WriteString(sb.String())
    prompt.WriteString(p.Response)

    if !strings.Contains(prompt.String(), p.Response) {
        // if the response is not in the prompt template, append it to the end
        prompt.WriteString(p.Response)
    }

    return prompt.String(), nil
}

// PreResponsePrompt returns the prompt before the response tag
func (m *Model) PreResponsePrompt(p PromptVars) (string, error) {
    if p.System == "" {
        // use the default system prompt for this model if one is not specified
        p.System = m.System
    }
    pre, _, err := extractParts(m.Template)
    if err != nil {
        return "", err
    }

    return Prompt(pre, p)
}

// PostResponseTemplate returns the template after the response tag
func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
    if p.System == "" {
        // use the default system prompt for this model if one is not specified
        p.System = m.System
    }
    _, post, err := extractParts(m.Template)
    if err != nil {
        return "", err
    }

    if post == "" {
        // if there is no post-response template, return the provided response
        return p.Response, nil
    }

    return Prompt(post, p)
}

func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
    // build the prompt from the list of messages
    var prompt strings.Builder
    var currentImages []api.ImageData
    currentVars := PromptVars{
        First: true,
        First:  true,
        System: m.System,
    }

    writePrompt := func() error {
        p, err := m.Prompt(currentVars)
        p, err := Prompt(m.Template, currentVars)
        if err != nil {
            return err
        }
@@ -133,9 +191,11 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error)

    // Append the last set of vars if they are non-empty
    if currentVars.Prompt != "" || currentVars.System != "" {
        if err := writePrompt(); err != nil {
            return "", nil, err
        p, err := m.PreResponsePrompt(currentVars)
        if err != nil {
            return "", nil, fmt.Errorf("pre-response template: %w", err)
        }
        prompt.WriteString(p)
    }

    return prompt.String(), currentImages, nil
@@ -425,9 +485,15 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
    if err != nil {
        return err
    }
    if err := PullModel(ctx, parent.OriginalModel, &RegistryOptions{}, fn); err != nil {
        log.Printf("error pulling model: %v", err)

    originalModel := parent.OriginalModel
    if originalModel == "" {
        originalModel = parent.ShortName
    }
    if err := PullModel(ctx, originalModel, &RegistryOptions{}, fn); err != nil {
        log.Printf("error pulling parent model: %v", err)
    }

    // Reset the file pointer to the beginning of the file
    _, err = fromConfigFile.Seek(0, 0)
    if err != nil {
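The heart of the change above is splitting a prompt template at the `{{.Response}}` action, so the server can send everything before it to the model and append everything after it once generation finishes. A self-contained sketch of that splitting step (function and variable names here are illustrative, not the repo's exports):

```go
package main

import (
	"fmt"
	"text/template"
	"text/template/parse"
)

// splitAtResponse returns the template text before and after the {{.Response}} action,
// walking the parsed node list the same way extractParts does in the diff above.
func splitAtResponse(tmplStr string) (pre, post string, err error) {
	tmpl, err := template.New("").Parse(tmplStr)
	if err != nil {
		return "", "", err
	}
	var found bool
	for _, node := range tmpl.Tree.Root.Nodes {
		if node.Type() == parse.NodeAction && node.String() == "{{.Response}}" {
			found = true
		}
		if !found {
			pre += node.String()
		} else {
			post += node.String()
		}
	}
	return pre, post, nil
}

func main() {
	pre, post, err := splitAtResponse("<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>")
	if err != nil {
		panic(err)
	}
	fmt.Printf("pre:  %q\npost: %q\n", pre, post)
	// pre:  "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n"
	// post: "{{.Response}}<|im_end|>"
}
```

In the server flow, the "pre" half is rendered and handed to the runner, the model's output is inserted where `{{.Response}}` sat, and the "post" half (for example a trailing `<|im_end|>`) is rendered before the turn is encoded back into the context.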
@@ -7,6 +7,232 @@ import (
    "github.com/jmorganca/ollama/api"
)

func TestPrompt(t *testing.T) {
    tests := []struct {
        name     string
        template string
        vars     PromptVars
        want     string
        wantErr  bool
    }{
        {
            name:     "System Prompt",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            vars: PromptVars{
                System: "You are a Wizard.",
                Prompt: "What are the potion ingredients?",
            },
            want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
        },
        {
            name:     "System Prompt with Response",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
            vars: PromptVars{
                System:   "You are a Wizard.",
                Prompt:   "What are the potion ingredients?",
                Response: "I don't know.",
            },
            want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
        },
        {
            name:     "Conditional Logic Nodes",
            template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
            vars: PromptVars{
                First:    true,
                System:   "You are a Wizard.",
                Prompt:   "What are the potion ingredients?",
                Response: "I don't know.",
            },
            want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            got, err := Prompt(tt.template, tt.vars)
            if (err != nil) != tt.wantErr {
                t.Errorf("Prompt() error = %v, wantErr %v", err, tt.wantErr)
                return
            }
            if got != tt.want {
                t.Errorf("Prompt() got = %v, want %v", got, tt.want)
            }
        })
    }
}

func TestModel_PreResponsePrompt(t *testing.T) {
    tests := []struct {
        name     string
        template string
        vars     PromptVars
        want     string
        wantErr  bool
    }{
        {
            name:     "No Response in Template",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            vars: PromptVars{
                System: "You are a Wizard.",
                Prompt: "What are the potion ingredients?",
            },
            want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
        },
        {
            name:     "Response in Template",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
            vars: PromptVars{
                System: "You are a Wizard.",
                Prompt: "What are the potion ingredients?",
            },
            want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] ",
        },
        {
            name:     "Response in Template with Trailing Formatting",
            template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
            vars: PromptVars{
                Prompt: "What are the potion ingredients?",
            },
            want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
        },
        {
            name:     "Response in Template with Alternative Formatting",
            template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
            vars: PromptVars{
                Prompt: "What are the potion ingredients?",
            },
            want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
        },
    }

    for _, tt := range tests {
        m := Model{Template: tt.template}
        t.Run(tt.name, func(t *testing.T) {
            got, err := m.PreResponsePrompt(tt.vars)
            if (err != nil) != tt.wantErr {
                t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
                return
            }
            if got != tt.want {
                t.Errorf("PreResponsePrompt() got = %v, want %v", got, tt.want)
            }
        })
    }
}

func TestModel_PostResponsePrompt(t *testing.T) {
    tests := []struct {
        name     string
        template string
        vars     PromptVars
        want     string
        wantErr  bool
    }{
        {
            name:     "No Response in Template",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            vars: PromptVars{
                Response: "I don't know.",
            },
            want: "I don't know.",
        },
        {
            name:     "Response in Template",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
            vars: PromptVars{
                Response: "I don't know.",
            },
            want: "I don't know.",
        },
        {
            name:     "Response in Template with Trailing Formatting",
            template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
            vars: PromptVars{
                Response: "I don't know.",
            },
            want: "I don't know.<|im_end|>",
        },
        {
            name:     "Response in Template with Alternative Formatting",
            template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
            vars: PromptVars{
                Response: "I don't know.",
            },
            want: "I don't know.<|im_end|>",
        },
    }

    for _, tt := range tests {
        m := Model{Template: tt.template}
        t.Run(tt.name, func(t *testing.T) {
            got, err := m.PostResponseTemplate(tt.vars)
            if (err != nil) != tt.wantErr {
                t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
                return
            }
            if got != tt.want {
                t.Errorf("PostResponseTemplate() got = %v, want %v", got, tt.want)
            }
        })
    }
}

func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
    tests := []struct {
        name     string
        template string
        preVars  PromptVars
        postVars PromptVars
        want     string
        wantErr  bool
    }{
        {
            name:     "Response in Template",
            template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
            preVars: PromptVars{
                Prompt: "What are the potion ingredients?",
            },
            postVars: PromptVars{
                Prompt:   "What are the potion ingredients?",
                Response: "Sugar.",
            },
            want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSugar.<|im_end|>",
        },
        {
            name:     "No Response in Template",
            template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n",
            preVars: PromptVars{
                Prompt: "What are the potion ingredients?",
            },
            postVars: PromptVars{
                Prompt:   "What are the potion ingredients?",
                Response: "Spice.",
            },
            want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSpice.",
        },
    }

    for _, tt := range tests {
        m := Model{Template: tt.template}
        t.Run(tt.name, func(t *testing.T) {
            pre, err := m.PreResponsePrompt(tt.preVars)
            if (err != nil) != tt.wantErr {
                t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
                return
            }
            post, err := m.PostResponseTemplate(tt.postVars)
            if err != nil {
                t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
                return
            }
            result := pre + post
            if result != tt.want {
                t.Errorf("Prompt() got = %v, want %v", result, tt.want)
            }
        })
    }
}

func TestChat(t *testing.T) {
    tests := []struct {
        name string
@@ -30,6 +256,29 @@ func TestChat(t *testing.T) {
            },
            want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
        },
        {
            name:     "First Message",
            template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
            msgs: []api.Message{
                {
                    Role:    "system",
                    Content: "You are a Wizard.",
                },
                {
                    Role:    "user",
                    Content: "What are the potion ingredients?",
                },
                {
                    Role:    "assistant",
                    Content: "eye of newt",
                },
                {
                    Role:    "user",
                    Content: "Anything else?",
                },
            },
            want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
        },
        {
            name:     "Message History",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
@@ -1,3 +1,5 @@
//go:build integration

package server

import (

@@ -1,3 +1,5 @@
//go:build integration

package server

import (

@@ -1,3 +1,5 @@
//go:build integration

package server

import (
@@ -38,7 +40,7 @@ func PrepareModelForPrompts(t *testing.T, modelName string, opts api.Options) (*
}

func OneShotPromptResponse(t *testing.T, ctx context.Context, req api.GenerateRequest, model *Model, runner llm.LLM) string {
    prompt, err := model.Prompt(PromptVars{
    prompt, err := model.PreResponsePrompt(PromptVars{
        System: req.System,
        Prompt: req.Prompt,
        First:  len(req.Context) == 0,
@@ -54,6 +56,7 @@ func OneShotPromptResponse(t *testing.T, ctx context.Context, req api.GenerateRe
            success <- true
        }
    }

    predictReq := llm.PredictOpts{
        Prompt: prompt,
        Format: req.Format,
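The `//go:build integration` constraint added at the top of these files keeps them out of ordinary `go test` runs; they only compile when the tag is supplied, e.g. `go test -tags=integration ./server`, as the setup script's comment notes. A minimal standalone illustration of the mechanism (hypothetical file, not from the repo):

```go
//go:build integration

// This test file is skipped by a plain `go test`; it only compiles when the
// "integration" build tag is passed, e.g. `go test -tags=integration ./...`.
package example

import "testing"

func TestRequiresRealBackend(t *testing.T) {
	// Hypothetical slow end-to-end check that should not run in unit-test CI.
	t.Log("running integration-only test")
}
```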
145
server/routes.go
@@ -64,24 +64,9 @@ var loaded struct {
var defaultSessionDuration = 5 * time.Minute

// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error) {
    model, err := GetModel(modelName)
    if err != nil {
        return nil, err
    }

func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
    workDir := c.GetString("workDir")

    opts := api.DefaultOptions()
    if err := opts.FromMap(model.Options); err != nil {
        log.Printf("could not load model options: %v", err)
        return nil, err
    }

    if err := opts.FromMap(reqOpts); err != nil {
        return nil, err
    }

    needLoad := loaded.runner == nil || // is there a model loaded?
        loaded.ModelPath != model.ModelPath || // has the base model changed?
        !reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
@@ -105,7 +90,7 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess
            err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
        }

        return nil, err
        return err
    }

    loaded.Model = model
@@ -135,7 +120,20 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess
    }

    loaded.expireTimer.Reset(sessionDuration)
    return model, nil
    return nil
}

func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
    opts := api.DefaultOptions()
    if err := opts.FromMap(model.Options); err != nil {
        return api.Options{}, err
    }

    if err := opts.FromMap(requestOpts); err != nil {
        return api.Options{}, err
    }

    return opts, nil
}

func GenerateHandler(c *gin.Context) {
@@ -168,18 +166,30 @@ func GenerateHandler(c *gin.Context) {
        return
    }

    sessionDuration := defaultSessionDuration
    model, err := load(c, req.Model, req.Options, sessionDuration)
    model, err := GetModel(req.Model)
    if err != nil {
        var pErr *fs.PathError
        switch {
        case errors.As(err, &pErr):
        if errors.As(err, &pErr) {
            c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
        case errors.Is(err, api.ErrInvalidOpts):
            c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
        default:
            c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
            return
        }
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
    }

    opts, err := modelOptions(model, req.Options)
    if err != nil {
        if errors.Is(err, api.ErrInvalidOpts) {
            c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
            return
        }
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
    }

    sessionDuration := defaultSessionDuration
    if err := load(c, model, opts, sessionDuration); err != nil {
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
    }
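The handler refactor above separates the failure modes: a missing model surfaces as a `*fs.PathError` from GetModel and becomes a 404, invalid options become a 400, and anything else a 500. A standalone sketch of that mapping (errInvalidOpts stands in for api.ErrInvalidOpts; nothing here is the repo's exact code):

```go
package main

import (
	"errors"
	"fmt"
	"io/fs"
	"net/http"
)

// errInvalidOpts stands in for api.ErrInvalidOpts from the diff above.
var errInvalidOpts = errors.New("invalid options")

// statusFor classifies an error into the HTTP status the handlers return.
func statusFor(err error) int {
	var pErr *fs.PathError
	switch {
	case errors.As(err, &pErr):
		return http.StatusNotFound // model manifest missing on disk
	case errors.Is(err, errInvalidOpts):
		return http.StatusBadRequest // caller sent unusable options
	default:
		return http.StatusInternalServerError
	}
}

func main() {
	fmt.Println(statusFor(&fs.PathError{Op: "open", Path: "missing-model", Err: fs.ErrNotExist})) // 404
	fmt.Println(statusFor(fmt.Errorf("bad request: %w", errInvalidOpts)))                         // 400
	fmt.Println(statusFor(errors.New("boom")))                                                    // 500
}
```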
@@ -195,6 +205,7 @@ func GenerateHandler(c *gin.Context) {
    checkpointLoaded := time.Now()

    var prompt string
    var promptVars PromptVars
    switch {
    case req.Raw:
        prompt = req.Prompt
@@ -217,11 +228,12 @@ func GenerateHandler(c *gin.Context) {
            prevCtx = strings.TrimPrefix(prevCtx, " ")
            rebuild.WriteString(prevCtx)
        }
        p, err := model.Prompt(PromptVars{
        promptVars = PromptVars{
            System: req.System,
            Prompt: req.Prompt,
            First:  len(req.Context) == 0,
        })
        }
        p, err := model.PreResponsePrompt(promptVars)
        if err != nil {
            c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
            return
@@ -264,7 +276,14 @@ func GenerateHandler(c *gin.Context) {
            resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)

            if !req.Raw {
                embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
                // append the generated text to the history and template it if needed
                promptVars.Response = generated.String()
                result, err := model.PostResponseTemplate(promptVars)
                if err != nil {
                    ch <- gin.H{"error": err.Error()}
                    return
                }
                embd, err := loaded.runner.Encode(c.Request.Context(), prompt+result)
                if err != nil {
                    ch <- gin.H{"error": err.Error()}
                    return
@@ -278,9 +297,10 @@ func GenerateHandler(c *gin.Context) {

        // Start prediction
        predictReq := llm.PredictOpts{
            Prompt: prompt,
            Format: req.Format,
            Images: req.Images,
            Prompt:  prompt,
            Format:  req.Format,
            Images:  req.Images,
            Options: opts,
        }
        if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
            ch <- gin.H{"error": err.Error()}
@@ -338,18 +358,29 @@ func EmbeddingHandler(c *gin.Context) {
        return
    }

    sessionDuration := defaultSessionDuration
    _, err = load(c, req.Model, req.Options, sessionDuration)
    model, err := GetModel(req.Model)
    if err != nil {
        var pErr *fs.PathError
        switch {
        case errors.As(err, &pErr):
        if errors.As(err, &pErr) {
            c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
        case errors.Is(err, api.ErrInvalidOpts):
            c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
        default:
            c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
            return
        }
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
    }

    opts, err := modelOptions(model, req.Options)
    if err != nil {
        if errors.Is(err, api.ErrInvalidOpts) {
            c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
            return
        }
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
    }
    sessionDuration := defaultSessionDuration
    if err := load(c, model, opts, sessionDuration); err != nil {
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
    }

@@ -982,18 +1013,29 @@ func ChatHandler(c *gin.Context) {
        return
    }

    sessionDuration := defaultSessionDuration
    model, err := load(c, req.Model, req.Options, sessionDuration)
    model, err := GetModel(req.Model)
    if err != nil {
        var pErr *fs.PathError
        switch {
        case errors.As(err, &pErr):
        if errors.As(err, &pErr) {
            c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
        case errors.Is(err, api.ErrInvalidOpts):
            c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
        default:
            c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
            return
        }
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
    }

    opts, err := modelOptions(model, req.Options)
    if err != nil {
        if errors.Is(err, api.ErrInvalidOpts) {
            c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
            return
        }
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
    }
    sessionDuration := defaultSessionDuration
    if err := load(c, model, opts, sessionDuration); err != nil {
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
    }

@@ -1044,9 +1086,10 @@ func ChatHandler(c *gin.Context) {

    // Start prediction
    predictReq := llm.PredictOpts{
        Prompt: prompt,
        Format: req.Format,
        Images: images,
        Prompt:  prompt,
        Format:  req.Format,
        Images:  images,
        Options: opts,
    }
    if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
        ch <- gin.H{"error": err.Error()}
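One more note on the new modelOptions helper used by all three handlers: options are resolved in layers, defaults first, then the values stored with the model, then the per-request overrides, so request values win. A toy sketch of that precedence using a hypothetical map type (api.Options works on concrete struct fields, not a map, so this only shows the shape of the merge):

```go
package main

import "fmt"

type options map[string]any

// defaultOptions plays the role of api.DefaultOptions().
func defaultOptions() options {
	return options{"temperature": 0.8, "num_ctx": 2048}
}

// fromMap overlays src onto o, mimicking opts.FromMap in spirit.
func (o options) fromMap(src map[string]any) {
	for k, v := range src {
		o[k] = v
	}
}

// modelOptions merges in the same order as the helper in the diff above.
func modelOptions(modelOpts, requestOpts map[string]any) options {
	opts := defaultOptions()
	opts.fromMap(modelOpts)   // values baked into the Modelfile
	opts.fromMap(requestOpts) // values sent with the API request win last
	return opts
}

func main() {
	got := modelOptions(
		map[string]any{"temperature": 0.2},
		map[string]any{"num_ctx": 4096},
	)
	fmt.Println(got) // map[num_ctx:4096 temperature:0.2]
}
```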