Compare commits
1 Commits
remove-fir
...
mattw/howt
Author | SHA1 | Date | |
---|---|---|---|
![]() |
4522109b11 |
@@ -5,8 +5,8 @@ ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
|
||||
WORKDIR /go/src/github.com/jmorganca/ollama
|
||||
RUN apt-get update && apt-get install -y git build-essential cmake
|
||||
ADD https://dl.google.com/go/go1.21.3.linux-$TARGETARCH.tar.gz /tmp/go1.21.3.tar.gz
|
||||
RUN mkdir -p /usr/local && tar xz -C /usr/local </tmp/go1.21.3.tar.gz
|
||||
ADD https://dl.google.com/go/go1.21.1.linux-$TARGETARCH.tar.gz /tmp/go1.21.1.tar.gz
|
||||
RUN mkdir -p /usr/local && tar xz -C /usr/local </tmp/go1.21.1.tar.gz
|
||||
|
||||
COPY . .
|
||||
ENV GOARCH=$TARGETARCH
|
||||
|
@@ -1,5 +1,6 @@
|
||||
|
||||
# centos7 amd64 dependencies
|
||||
FROM --platform=linux/amd64 nvidia/cuda:11.3.1-devel-centos7 AS base-amd64
|
||||
FROM --platform=linux/amd64 nvidia/cuda:11.8.0-devel-centos7 AS base-amd64
|
||||
RUN yum install -y https://repo.ius.io/ius-release-el7.rpm centos-release-scl && \
|
||||
yum update -y && \
|
||||
yum install -y devtoolset-10-gcc devtoolset-10-gcc-c++ git236 wget
|
||||
@@ -7,7 +8,7 @@ RUN wget "https://github.com/Kitware/CMake/releases/download/v3.27.6/cmake-3.27.
|
||||
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
|
||||
|
||||
# centos8 arm64 dependencies
|
||||
FROM --platform=linux/arm64 nvidia/cuda-arm64:11.3.1-devel-centos8 AS base-arm64
|
||||
FROM --platform=linux/arm64 nvidia/cuda:11.4.3-devel-centos8 AS base-arm64
|
||||
RUN sed -i -e 's/mirrorlist/#mirrorlist/g' -e 's|#baseurl=http://mirror.centos.org|baseurl=http://vault.centos.org|g' /etc/yum.repos.d/CentOS-*
|
||||
RUN yum install -y git cmake
|
||||
|
||||
@@ -16,8 +17,8 @@ ARG TARGETARCH
|
||||
ARG GOFLAGS="'-ldflags -w -s'"
|
||||
|
||||
# install go
|
||||
ADD https://dl.google.com/go/go1.21.3.linux-$TARGETARCH.tar.gz /tmp/go1.21.3.tar.gz
|
||||
RUN mkdir -p /usr/local && tar xz -C /usr/local </tmp/go1.21.3.tar.gz
|
||||
ADD https://dl.google.com/go/go1.21.1.linux-$TARGETARCH.tar.gz /tmp/go1.21.1.tar.gz
|
||||
RUN mkdir -p /usr/local && tar xz -C /usr/local </tmp/go1.21.1.tar.gz
|
||||
|
||||
# build the final binary
|
||||
WORKDIR /go/src/github.com/jmorganca/ollama
|
||||
|
75
README.md
75
README.md
@@ -15,10 +15,6 @@ Get up and running with large language models locally.
|
||||
|
||||
[Download](https://ollama.ai/download/Ollama-darwin.zip)
|
||||
|
||||
### Windows
|
||||
|
||||
Coming soon!
|
||||
|
||||
### Linux & WSL2
|
||||
|
||||
```
|
||||
@@ -27,9 +23,9 @@ curl https://ollama.ai/install.sh | sh
|
||||
|
||||
[Manual install instructions](https://github.com/jmorganca/ollama/blob/main/docs/linux.md)
|
||||
|
||||
### Docker
|
||||
### Windows
|
||||
|
||||
The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `ollama/ollama` is available on Docker Hub.
|
||||
coming soon
|
||||
|
||||
## Quickstart
|
||||
|
||||
@@ -60,11 +56,11 @@ Here are some example open-source models that can be downloaded:
|
||||
|
||||
## Customize your own model
|
||||
|
||||
### Import from GGUF
|
||||
### Import from GGUF or GGML
|
||||
|
||||
Ollama supports importing GGUF models in the Modelfile:
|
||||
Ollama supports importing GGUF and GGML file formats in the Modelfile. This means if you have a model that is not in the Ollama library, you can create it, iterate on it, and upload it to the Ollama library to share with others when you are ready.
|
||||
|
||||
1. Create a file named `Modelfile`, with a `FROM` instruction with the local filepath to the model you want to import.
|
||||
1. Create a file named Modelfile, and add a `FROM` instruction with the local filepath to the model you want to import.
|
||||
|
||||
```
|
||||
FROM ./vicuna-33b.Q4_0.gguf
|
||||
@@ -73,22 +69,18 @@ Ollama supports importing GGUF models in the Modelfile:
|
||||
2. Create the model in Ollama
|
||||
|
||||
```
|
||||
ollama create example -f Modelfile
|
||||
ollama create name -f path_to_modelfile
|
||||
```
|
||||
|
||||
3. Run the model
|
||||
|
||||
```
|
||||
ollama run example
|
||||
ollama run name
|
||||
```
|
||||
|
||||
### Import from PyTorch or Safetensors
|
||||
|
||||
See the [guide](docs/import.md) on importing models for more information.
|
||||
|
||||
### Customize a prompt
|
||||
|
||||
Models from the Ollama library can be customized with a prompt. For example, to customize the `llama2` model:
|
||||
Models from the Ollama library can be customized with a prompt. The example
|
||||
|
||||
```
|
||||
ollama pull llama2
|
||||
@@ -159,7 +151,7 @@ I'm a basic program that prints the famous "Hello, world!" message to the consol
|
||||
### Pass in prompt as arguments
|
||||
|
||||
```
|
||||
$ ollama run llama2 "Summarize this file: $(cat README.md)"
|
||||
$ ollama run llama2 "summarize this file:" "$(cat README.md)"
|
||||
Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
|
||||
```
|
||||
|
||||
@@ -178,7 +170,8 @@ ollama list
|
||||
Install `cmake` and `go`:
|
||||
|
||||
```
|
||||
brew install cmake go
|
||||
brew install cmake
|
||||
brew install go
|
||||
```
|
||||
|
||||
Then generate dependencies and build:
|
||||
@@ -202,8 +195,9 @@ Finally, in a separate shell, run a model:
|
||||
|
||||
## REST API
|
||||
|
||||
Ollama has a REST API for running and managing models.
|
||||
For example, to generate text from a model:
|
||||
> See the [API documentation](docs/api.md) for all endpoints.
|
||||
|
||||
Ollama has an API for running and managing models. For example to generate text from a model:
|
||||
|
||||
```
|
||||
curl -X POST http://localhost:11434/api/generate -d '{
|
||||
@@ -212,48 +206,19 @@ curl -X POST http://localhost:11434/api/generate -d '{
|
||||
}'
|
||||
```
|
||||
|
||||
See the [API documentation](./docs/api.md) for all endpoints.
|
||||
|
||||
## Community Integrations
|
||||
|
||||
### Web & Desktop
|
||||
|
||||
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
|
||||
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
|
||||
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
|
||||
- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
|
||||
- [Web UI](https://github.com/ollama-webui/ollama-webui)
|
||||
- [Ollamac](https://github.com/kevinhermawan/Ollamac)
|
||||
- [big-AGI](https://github.com/enricoros/big-agi/blob/main/docs/config-ollama.md)
|
||||
|
||||
### Terminal
|
||||
|
||||
- [oterm](https://github.com/ggozad/oterm)
|
||||
- [Ellama Emacs client](https://github.com/s-kostyaev/ellama)
|
||||
- [Emacs client](https://github.com/zweifisch/ollama)
|
||||
- [gen.nvim](https://github.com/David-Kunz/gen.nvim)
|
||||
- [ollama.nvim](https://github.com/nomnivore/ollama.nvim)
|
||||
- [gptel Emacs client](https://github.com/karthink/gptel)
|
||||
|
||||
### Libraries
|
||||
|
||||
- [LangChain](https://python.langchain.com/docs/integrations/llms/ollama) and [LangChain.js](https://js.langchain.com/docs/modules/model_io/models/llms/integrations/ollama) with [example](https://js.langchain.com/docs/use_cases/question_answering/local_retrieval_qa)
|
||||
- [LlamaIndex](https://gpt-index.readthedocs.io/en/stable/examples/llm/ollama.html)
|
||||
- [LiteLLM](https://github.com/BerriAI/litellm)
|
||||
- [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp)
|
||||
- [Ollama-rs for Rust](https://github.com/pepperoni21/ollama-rs)
|
||||
- [Ollama4j for Java](https://github.com/amithkoujalgi/ollama4j)
|
||||
- [ModelFusion Typescript Library](https://modelfusion.dev/integration/model-provider/ollama)
|
||||
- [OllamaKit for Swift](https://github.com/kevinhermawan/OllamaKit)
|
||||
- [Ollama for Dart](https://github.com/breitburg/dart-ollama)
|
||||
|
||||
### Extensions & Plugins
|
||||
|
||||
- [Raycast extension](https://github.com/MassimilianoPasquini97/raycast_ollama)
|
||||
- [Discollama](https://github.com/mxyng/discollama) (Discord bot inside the Ollama discord channel)
|
||||
- [Continue](https://github.com/continuedev/continue)
|
||||
- [Obsidian Ollama plugin](https://github.com/hinterdupfinger/obsidian-ollama)
|
||||
- [Logseq Ollama plugin](https://github.com/omagdy7/ollama-logseq)
|
||||
- [Dagger Chatbot](https://github.com/samalba/dagger-chatbot)
|
||||
- [LiteLLM](https://github.com/BerriAI/litellm)
|
||||
- [Discord AI Bot](https://github.com/mekb-turtle/discord-ai-bot)
|
||||
- [Hass Ollama Conversation](https://github.com/ej52/hass-ollama-conversation)
|
||||
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
|
||||
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
|
||||
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
|
||||
- [Dumbar](https://github.com/JerrySievert/Dumbar)
|
||||
- [Emacs client](https://github.com/zweifisch/ollama)
|
||||
|
@@ -14,10 +14,13 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/jmorganca/ollama/format"
|
||||
"github.com/jmorganca/ollama/version"
|
||||
)
|
||||
|
||||
const DefaultHost = "127.0.0.1:11434"
|
||||
|
||||
var envHost = os.Getenv("OLLAMA_HOST")
|
||||
|
||||
type Client struct {
|
||||
base *url.URL
|
||||
http http.Client
|
||||
@@ -40,28 +43,16 @@ func checkError(resp *http.Response, body []byte) error {
|
||||
}
|
||||
|
||||
func ClientFromEnvironment() (*Client, error) {
|
||||
defaultPort := "11434"
|
||||
|
||||
scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
|
||||
switch {
|
||||
case !ok:
|
||||
if !ok {
|
||||
scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
|
||||
case scheme == "http":
|
||||
defaultPort = "80"
|
||||
case scheme == "https":
|
||||
defaultPort = "443"
|
||||
}
|
||||
|
||||
// trim trailing slashes
|
||||
hostport = strings.TrimRight(hostport, "/")
|
||||
|
||||
host, port, err := net.SplitHostPort(hostport)
|
||||
if err != nil {
|
||||
host, port = "127.0.0.1", defaultPort
|
||||
if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
|
||||
host, port = "127.0.0.1", "11434"
|
||||
if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
|
||||
host = ip.String()
|
||||
} else if hostport != "" {
|
||||
host = hostport
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,7 +63,7 @@ func ClientFromEnvironment() (*Client, error) {
|
||||
},
|
||||
}
|
||||
|
||||
mockRequest, err := http.NewRequest(http.MethodHead, client.base.String(), nil)
|
||||
mockRequest, err := http.NewRequest("HEAD", client.base.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -136,7 +127,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxBufferSize = 512 * format.KiloByte
|
||||
const maxBufferSize = 512 * 1000 // 512KB
|
||||
|
||||
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
|
||||
var buf *bytes.Buffer
|
||||
|
@@ -7,7 +7,7 @@ BASE_URL = os.environ.get('OLLAMA_HOST', 'http://localhost:11434')
|
||||
# Generate a response for a given prompt with a provided model. This is a streaming endpoint, so will be a series of responses.
|
||||
# The final response object will include statistics and additional data from the request. Use the callback function to override
|
||||
# the default handler.
|
||||
def generate(model_name, prompt, system=None, template=None, format="", context=None, options=None, callback=None):
|
||||
def generate(model_name, prompt, system=None, template=None, context=None, options=None, callback=None):
|
||||
try:
|
||||
url = f"{BASE_URL}/api/generate"
|
||||
payload = {
|
||||
@@ -16,8 +16,7 @@ def generate(model_name, prompt, system=None, template=None, format="", context=
|
||||
"system": system,
|
||||
"template": template,
|
||||
"context": context,
|
||||
"options": options,
|
||||
"format": format,
|
||||
"options": options
|
||||
}
|
||||
|
||||
# Remove keys with None values
|
||||
|
@@ -1,43 +0,0 @@
|
||||
package api
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestClientFromEnvironment(t *testing.T) {
|
||||
type testCase struct {
|
||||
value string
|
||||
expect string
|
||||
err error
|
||||
}
|
||||
|
||||
testCases := map[string]*testCase{
|
||||
"empty": {value: "", expect: "http://127.0.0.1:11434"},
|
||||
"only address": {value: "1.2.3.4", expect: "http://1.2.3.4:11434"},
|
||||
"only port": {value: ":1234", expect: "http://:1234"},
|
||||
"address and port": {value: "1.2.3.4:1234", expect: "http://1.2.3.4:1234"},
|
||||
"scheme http and address": {value: "http://1.2.3.4", expect: "http://1.2.3.4:80"},
|
||||
"scheme https and address": {value: "https://1.2.3.4", expect: "https://1.2.3.4:443"},
|
||||
"scheme, address, and port": {value: "https://1.2.3.4:1234", expect: "https://1.2.3.4:1234"},
|
||||
"hostname": {value: "example.com", expect: "http://example.com:11434"},
|
||||
"hostname and port": {value: "example.com:1234", expect: "http://example.com:1234"},
|
||||
"scheme http and hostname": {value: "http://example.com", expect: "http://example.com:80"},
|
||||
"scheme https and hostname": {value: "https://example.com", expect: "https://example.com:443"},
|
||||
"scheme, hostname, and port": {value: "https://example.com:1234", expect: "https://example.com:1234"},
|
||||
"trailing slash": {value: "example.com/", expect: "http://example.com:11434"},
|
||||
"trailing slash port": {value: "example.com:1234/", expect: "http://example.com:1234"},
|
||||
}
|
||||
|
||||
for k, v := range testCases {
|
||||
t.Run(k, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", v.value)
|
||||
|
||||
client, err := ClientFromEnvironment()
|
||||
if err != v.err {
|
||||
t.Fatalf("expected %s, got %s", v.err, err)
|
||||
}
|
||||
|
||||
if client.base.String() != v.expect {
|
||||
t.Fatalf("expected %s, got %s", v.expect, client.base.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
139
api/types.go
139
api/types.go
@@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"os"
|
||||
"reflect"
|
||||
@@ -37,56 +38,10 @@ type GenerateRequest struct {
|
||||
Template string `json:"template"`
|
||||
Context []int `json:"context,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Raw bool `json:"raw,omitempty"`
|
||||
Format string `json:"format"`
|
||||
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
|
||||
// Options specfied in GenerateRequest, if you add a new option here add it to the API docs also
|
||||
type Options struct {
|
||||
Runner
|
||||
|
||||
// Predict options used at runtime
|
||||
NumKeep int `json:"num_keep,omitempty"`
|
||||
Seed int `json:"seed,omitempty"`
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
TopP float32 `json:"top_p,omitempty"`
|
||||
TFSZ float32 `json:"tfs_z,omitempty"`
|
||||
TypicalP float32 `json:"typical_p,omitempty"`
|
||||
RepeatLastN int `json:"repeat_last_n,omitempty"`
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
RepeatPenalty float32 `json:"repeat_penalty,omitempty"`
|
||||
PresencePenalty float32 `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
|
||||
Mirostat int `json:"mirostat,omitempty"`
|
||||
MirostatTau float32 `json:"mirostat_tau,omitempty"`
|
||||
MirostatEta float32 `json:"mirostat_eta,omitempty"`
|
||||
PenalizeNewline bool `json:"penalize_newline,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
}
|
||||
|
||||
// Runner options which must be set when the model is loaded into memory
|
||||
type Runner struct {
|
||||
UseNUMA bool `json:"numa,omitempty"`
|
||||
NumCtx int `json:"num_ctx,omitempty"`
|
||||
NumBatch int `json:"num_batch,omitempty"`
|
||||
NumGQA int `json:"num_gqa,omitempty"`
|
||||
NumGPU int `json:"num_gpu,omitempty"`
|
||||
MainGPU int `json:"main_gpu,omitempty"`
|
||||
LowVRAM bool `json:"low_vram,omitempty"`
|
||||
F16KV bool `json:"f16_kv,omitempty"`
|
||||
LogitsAll bool `json:"logits_all,omitempty"`
|
||||
VocabOnly bool `json:"vocab_only,omitempty"`
|
||||
UseMMap bool `json:"use_mmap,omitempty"`
|
||||
UseMLock bool `json:"use_mlock,omitempty"`
|
||||
EmbeddingOnly bool `json:"embedding_only,omitempty"`
|
||||
RopeFrequencyBase float32 `json:"rope_frequency_base,omitempty"`
|
||||
RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
|
||||
NumThread int `json:"num_thread,omitempty"`
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
@@ -207,6 +162,49 @@ func (r *GenerateResponse) Summary() {
|
||||
}
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
Seed int `json:"seed,omitempty"`
|
||||
|
||||
// Backend options
|
||||
UseNUMA bool `json:"numa,omitempty"`
|
||||
|
||||
// Model options
|
||||
NumCtx int `json:"num_ctx,omitempty"`
|
||||
NumKeep int `json:"num_keep,omitempty"`
|
||||
NumBatch int `json:"num_batch,omitempty"`
|
||||
NumGQA int `json:"num_gqa,omitempty"`
|
||||
NumGPU int `json:"num_gpu,omitempty"`
|
||||
MainGPU int `json:"main_gpu,omitempty"`
|
||||
LowVRAM bool `json:"low_vram,omitempty"`
|
||||
F16KV bool `json:"f16_kv,omitempty"`
|
||||
LogitsAll bool `json:"logits_all,omitempty"`
|
||||
VocabOnly bool `json:"vocab_only,omitempty"`
|
||||
UseMMap bool `json:"use_mmap,omitempty"`
|
||||
UseMLock bool `json:"use_mlock,omitempty"`
|
||||
EmbeddingOnly bool `json:"embedding_only,omitempty"`
|
||||
RopeFrequencyBase float32 `json:"rope_frequency_base,omitempty"`
|
||||
RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
|
||||
|
||||
// Predict options
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
TopP float32 `json:"top_p,omitempty"`
|
||||
TFSZ float32 `json:"tfs_z,omitempty"`
|
||||
TypicalP float32 `json:"typical_p,omitempty"`
|
||||
RepeatLastN int `json:"repeat_last_n,omitempty"`
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
RepeatPenalty float32 `json:"repeat_penalty,omitempty"`
|
||||
PresencePenalty float32 `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
|
||||
Mirostat int `json:"mirostat,omitempty"`
|
||||
MirostatTau float32 `json:"mirostat_tau,omitempty"`
|
||||
MirostatEta float32 `json:"mirostat_eta,omitempty"`
|
||||
PenalizeNewline bool `json:"penalize_newline,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
|
||||
NumThread int `json:"num_thread,omitempty"`
|
||||
}
|
||||
|
||||
var ErrInvalidOpts = fmt.Errorf("invalid options")
|
||||
|
||||
func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||
@@ -240,39 +238,44 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||
// when JSON unmarshals numbers, it uses float64, not int
|
||||
field.SetInt(int64(t))
|
||||
default:
|
||||
return fmt.Errorf("option %q must be of type integer", key)
|
||||
log.Printf("could not convert model parameter %v of type %T to int, skipped", key, val)
|
||||
}
|
||||
case reflect.Bool:
|
||||
val, ok := val.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("option %q must be of type boolean", key)
|
||||
log.Printf("could not convert model parameter %v of type %T to bool, skipped", key, val)
|
||||
continue
|
||||
}
|
||||
field.SetBool(val)
|
||||
case reflect.Float32:
|
||||
// JSON unmarshals to float64
|
||||
val, ok := val.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("option %q must be of type float32", key)
|
||||
log.Printf("could not convert model parameter %v of type %T to float32, skipped", key, val)
|
||||
continue
|
||||
}
|
||||
field.SetFloat(val)
|
||||
case reflect.String:
|
||||
val, ok := val.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("option %q must be of type string", key)
|
||||
log.Printf("could not convert model parameter %v of type %T to string, skipped", key, val)
|
||||
continue
|
||||
}
|
||||
field.SetString(val)
|
||||
case reflect.Slice:
|
||||
// JSON unmarshals to []interface{}, not []string
|
||||
val, ok := val.([]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf("option %q must be of type array", key)
|
||||
log.Printf("could not convert model parameter %v of type %T to slice, skipped", key, val)
|
||||
continue
|
||||
}
|
||||
// convert []interface{} to []string
|
||||
slice := make([]string, len(val))
|
||||
for i, item := range val {
|
||||
str, ok := item.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("option %q must be of an array of strings", key)
|
||||
log.Printf("could not convert model parameter %v of type %T to slice of strings, skipped", key, item)
|
||||
continue
|
||||
}
|
||||
slice[i] = str
|
||||
}
|
||||
@@ -296,7 +299,7 @@ func DefaultOptions() Options {
|
||||
return Options{
|
||||
// options set on request to runner
|
||||
NumPredict: -1,
|
||||
NumKeep: 0,
|
||||
NumKeep: -1,
|
||||
Temperature: 0.8,
|
||||
TopK: 40,
|
||||
TopP: 0.9,
|
||||
@@ -312,22 +315,20 @@ func DefaultOptions() Options {
|
||||
PenalizeNewline: true,
|
||||
Seed: -1,
|
||||
|
||||
Runner: Runner{
|
||||
// options set when the model is loaded
|
||||
NumCtx: 2048,
|
||||
RopeFrequencyBase: 10000.0,
|
||||
RopeFrequencyScale: 1.0,
|
||||
NumBatch: 512,
|
||||
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
|
||||
NumGQA: 1,
|
||||
NumThread: 0, // let the runtime decide
|
||||
LowVRAM: false,
|
||||
F16KV: true,
|
||||
UseMLock: false,
|
||||
UseMMap: true,
|
||||
UseNUMA: false,
|
||||
EmbeddingOnly: true,
|
||||
},
|
||||
// options set when the model is loaded
|
||||
NumCtx: 2048,
|
||||
RopeFrequencyBase: 10000.0,
|
||||
RopeFrequencyScale: 1.0,
|
||||
NumBatch: 512,
|
||||
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
|
||||
NumGQA: 1,
|
||||
NumThread: 0, // let the runtime decide
|
||||
LowVRAM: false,
|
||||
F16KV: true,
|
||||
UseMLock: false,
|
||||
UseMMap: true,
|
||||
UseNUMA: false,
|
||||
EmbeddingOnly: true,
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -47,6 +47,16 @@ const config: ForgeConfig = {
|
||||
},
|
||||
rebuildConfig: {},
|
||||
makers: [new MakerSquirrel({}), new MakerZIP({}, ['darwin'])],
|
||||
publishers: [
|
||||
new PublisherGithub({
|
||||
repository: {
|
||||
name: 'ollama',
|
||||
owner: 'jmorganca',
|
||||
},
|
||||
draft: false,
|
||||
prerelease: true,
|
||||
}),
|
||||
],
|
||||
hooks: {
|
||||
readPackageJson: async (_, packageJson) => {
|
||||
return { ...packageJson, version: process.env.VERSION || packageJson.version }
|
||||
|
990
app/package-lock.json
generated
990
app/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -46,7 +46,7 @@
|
||||
"chmodr": "^1.2.0",
|
||||
"copy-webpack-plugin": "^11.0.0",
|
||||
"css-loader": "^6.8.1",
|
||||
"electron": "25.9.2",
|
||||
"electron": "25.2.0",
|
||||
"eslint": "^8.43.0",
|
||||
"eslint-plugin-import": "^2.27.5",
|
||||
"fork-ts-checker-webpack-plugin": "^7.3.0",
|
||||
|
@@ -162,56 +162,13 @@ app.on('before-quit', () => {
|
||||
}
|
||||
})
|
||||
|
||||
const updateURL = `https://ollama.ai/api/update?os=${process.platform}&arch=${
|
||||
process.arch
|
||||
}&version=${app.getVersion()}&id=${id()}`
|
||||
|
||||
let latest = ''
|
||||
async function isNewReleaseAvailable() {
|
||||
try {
|
||||
const response = await fetch(updateURL)
|
||||
|
||||
if (!response.ok) {
|
||||
return false
|
||||
}
|
||||
|
||||
if (response.status === 204) {
|
||||
return false
|
||||
}
|
||||
|
||||
const data = await response.json()
|
||||
|
||||
const url = data?.url
|
||||
if (!url) {
|
||||
return false
|
||||
}
|
||||
|
||||
if (latest === url) {
|
||||
return false
|
||||
}
|
||||
|
||||
latest = url
|
||||
|
||||
return true
|
||||
} catch (error) {
|
||||
logger.error(`update check failed - ${error}`)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
async function checkUpdate() {
|
||||
const available = await isNewReleaseAvailable()
|
||||
if (available) {
|
||||
logger.info('checking for update')
|
||||
autoUpdater.checkForUpdates()
|
||||
}
|
||||
}
|
||||
|
||||
function init() {
|
||||
if (app.isPackaged) {
|
||||
checkUpdate()
|
||||
autoUpdater.checkForUpdates()
|
||||
setInterval(() => {
|
||||
checkUpdate()
|
||||
if (!updateAvailable) {
|
||||
autoUpdater.checkForUpdates()
|
||||
}
|
||||
}, 60 * 60 * 1000)
|
||||
}
|
||||
|
||||
@@ -289,7 +246,11 @@ function id(): string {
|
||||
return uuid
|
||||
}
|
||||
|
||||
autoUpdater.setFeedURL({ url: updateURL })
|
||||
autoUpdater.setFeedURL({
|
||||
url: `https://ollama.ai/api/update?os=${process.platform}&arch=${
|
||||
process.arch
|
||||
}&version=${app.getVersion()}&id=${id()}`,
|
||||
})
|
||||
|
||||
autoUpdater.on('error', e => {
|
||||
logger.error(`update check failed - ${e.message}`)
|
||||
|
267
cmd/cmd.go
267
cmd/cmd.go
@@ -1,6 +1,7 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
@@ -10,7 +11,6 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
@@ -20,7 +20,9 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/dustin/go-humanize"
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/pdevine/readline"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/term"
|
||||
@@ -28,11 +30,30 @@ import (
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/format"
|
||||
"github.com/jmorganca/ollama/progressbar"
|
||||
"github.com/jmorganca/ollama/readline"
|
||||
"github.com/jmorganca/ollama/server"
|
||||
"github.com/jmorganca/ollama/version"
|
||||
)
|
||||
|
||||
type Painter struct {
|
||||
IsMultiLine bool
|
||||
}
|
||||
|
||||
func (p Painter) Paint(line []rune, _ int) []rune {
|
||||
termType := os.Getenv("TERM")
|
||||
if termType == "xterm-256color" && len(line) == 0 {
|
||||
var prompt string
|
||||
if p.IsMultiLine {
|
||||
prompt = "Use \"\"\" to end multi-line input"
|
||||
} else {
|
||||
prompt = "Send a message (/? for help)"
|
||||
}
|
||||
return []rune(fmt.Sprintf("\033[38;5;245m%s\033[%dD\033[0m", prompt, len(prompt)))
|
||||
}
|
||||
// add a space and a backspace to prevent the cursor from walking up the screen
|
||||
line = append(line, []rune(" \b")...)
|
||||
return line
|
||||
}
|
||||
|
||||
func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
filename, _ := cmd.Flags().GetString("file")
|
||||
filename, err := filepath.Abs(filename)
|
||||
@@ -57,12 +78,18 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
spinner.Stop()
|
||||
}
|
||||
currentDigest = resp.Digest
|
||||
// pulling
|
||||
bar = progressbar.DefaultBytes(
|
||||
resp.Total,
|
||||
resp.Status,
|
||||
)
|
||||
bar.Set64(resp.Completed)
|
||||
switch {
|
||||
case strings.Contains(resp.Status, "embeddings"):
|
||||
bar = progressbar.Default(resp.Total, resp.Status)
|
||||
bar.Set64(resp.Completed)
|
||||
default:
|
||||
// pulling
|
||||
bar = progressbar.DefaultBytes(
|
||||
resp.Total,
|
||||
resp.Status,
|
||||
)
|
||||
bar.Set64(resp.Completed)
|
||||
}
|
||||
} else if resp.Digest == currentDigest && resp.Digest != "" {
|
||||
bar.Set64(resp.Completed)
|
||||
} else {
|
||||
@@ -97,16 +124,19 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
// check if the model exists on the server
|
||||
_, err = client.Show(context.Background(), &api.ShowRequest{Name: name})
|
||||
var statusError api.StatusError
|
||||
switch {
|
||||
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
|
||||
if err := PullHandler(cmd, args); err != nil {
|
||||
return err
|
||||
models, err := client.List(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
canonicalModelPath := server.ParseModelPath(args[0])
|
||||
for _, model := range models.Models {
|
||||
if model.Name == canonicalModelPath.GetShortTagname() {
|
||||
return RunGenerate(cmd, args)
|
||||
}
|
||||
case err != nil:
|
||||
}
|
||||
|
||||
if err := PullHandler(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -172,7 +202,7 @@ func ListHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
for _, m := range models.Models {
|
||||
if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) {
|
||||
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")})
|
||||
data = append(data, []string{m.Name, m.Digest[:12], humanize.Bytes(uint64(m.Size)), format.HumanTime(m.ModifiedAt, "Never")})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -348,49 +378,34 @@ func pull(model string, insecure bool) error {
|
||||
}
|
||||
|
||||
func RunGenerate(cmd *cobra.Command, args []string) error {
|
||||
format, err := cmd.Flags().GetString("format")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(args) > 1 {
|
||||
// join all args into a single prompt
|
||||
wordWrap := false
|
||||
if term.IsTerminal(int(os.Stdout.Fd())) {
|
||||
wordWrap = true
|
||||
}
|
||||
|
||||
prompts := args[1:]
|
||||
|
||||
// prepend stdin to the prompt if provided
|
||||
if !term.IsTerminal(int(os.Stdin.Fd())) {
|
||||
in, err := io.ReadAll(os.Stdin)
|
||||
nowrap, err := cmd.Flags().GetBool("nowordwrap")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if nowrap {
|
||||
wordWrap = false
|
||||
}
|
||||
|
||||
prompts = append([]string{string(in)}, prompts...)
|
||||
return generate(cmd, args[0], strings.Join(args[1:], " "), wordWrap)
|
||||
}
|
||||
|
||||
// output is being piped
|
||||
if !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||
return generate(cmd, args[0], strings.Join(prompts, " "), false, format)
|
||||
if readline.IsTerminal(int(os.Stdin.Fd())) {
|
||||
return generateInteractive(cmd, args[0])
|
||||
}
|
||||
|
||||
wordWrap := os.Getenv("TERM") == "xterm-256color"
|
||||
|
||||
nowrap, err := cmd.Flags().GetBool("nowordwrap")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if nowrap {
|
||||
wordWrap = false
|
||||
}
|
||||
|
||||
// prompts are provided via stdin or args so don't enter interactive mode
|
||||
if len(prompts) > 0 {
|
||||
return generate(cmd, args[0], strings.Join(prompts, " "), wordWrap, format)
|
||||
}
|
||||
|
||||
return generateInteractive(cmd, args[0], wordWrap, format)
|
||||
return generateBatch(cmd, args[0])
|
||||
}
|
||||
|
||||
type generateContextKey string
|
||||
|
||||
func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format string) error {
|
||||
func generate(cmd *cobra.Command, model, prompt string, wordWrap bool) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -406,7 +421,7 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
|
||||
generateContext = []int{}
|
||||
}
|
||||
|
||||
termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
|
||||
termWidth, _, err := term.GetSize(int(0))
|
||||
if err != nil {
|
||||
wordWrap = false
|
||||
}
|
||||
@@ -427,7 +442,7 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
|
||||
var currentLineLength int
|
||||
var wordBuffer string
|
||||
|
||||
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, Format: format}
|
||||
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
|
||||
fn := func(response api.GenerateResponse) error {
|
||||
if !spinner.IsFinished() {
|
||||
spinner.Finish()
|
||||
@@ -498,12 +513,39 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format string) error {
|
||||
// load the model
|
||||
if err := generate(cmd, model, "", false, ""); err != nil {
|
||||
func generateInteractive(cmd *cobra.Command, model string) error {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// load the model
|
||||
if err := generate(cmd, model, "", false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
completer := readline.NewPrefixCompleter(
|
||||
readline.PcItem("/help"),
|
||||
readline.PcItem("/list"),
|
||||
readline.PcItem("/set",
|
||||
readline.PcItem("history"),
|
||||
readline.PcItem("nohistory"),
|
||||
readline.PcItem("wordwrap"),
|
||||
readline.PcItem("nowordwrap"),
|
||||
readline.PcItem("verbose"),
|
||||
readline.PcItem("quiet"),
|
||||
),
|
||||
readline.PcItem("/show",
|
||||
readline.PcItem("license"),
|
||||
readline.PcItem("modelfile"),
|
||||
readline.PcItem("parameters"),
|
||||
readline.PcItem("system"),
|
||||
readline.PcItem("template"),
|
||||
),
|
||||
readline.PcItem("/exit"),
|
||||
readline.PcItem("/bye"),
|
||||
)
|
||||
|
||||
usage := func() {
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /set Set session variables")
|
||||
@@ -521,8 +563,6 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
|
||||
fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
|
||||
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
|
||||
fmt.Fprintln(os.Stderr, " /set nowordwrap Disable wordwrap")
|
||||
fmt.Fprintln(os.Stderr, " /set format json Enable JSON mode")
|
||||
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
|
||||
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
|
||||
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
@@ -538,32 +578,47 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
}
|
||||
|
||||
prompt := readline.Prompt{
|
||||
Prompt: ">>> ",
|
||||
AltPrompt: "... ",
|
||||
Placeholder: "Send a message (/? for help)",
|
||||
AltPlaceholder: `Use """ to end multi-line input`,
|
||||
var painter Painter
|
||||
|
||||
config := readline.Config{
|
||||
Painter: &painter,
|
||||
Prompt: ">>> ",
|
||||
HistoryFile: filepath.Join(home, ".ollama", "history"),
|
||||
AutoComplete: completer,
|
||||
}
|
||||
|
||||
scanner, err := readline.New(prompt)
|
||||
scanner, err := readline.NewEx(&config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer scanner.Close()
|
||||
|
||||
fmt.Print(readline.StartBracketedPaste)
|
||||
defer fmt.Printf(readline.EndBracketedPaste)
|
||||
var wordWrap bool
|
||||
termType := os.Getenv("TERM")
|
||||
if termType == "xterm-256color" {
|
||||
wordWrap = true
|
||||
}
|
||||
|
||||
// override wrapping if the user turned it off
|
||||
nowrap, err := cmd.Flags().GetBool("nowordwrap")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if nowrap {
|
||||
wordWrap = false
|
||||
}
|
||||
|
||||
var multiLineBuffer string
|
||||
var isMultiLine bool
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
fmt.Println()
|
||||
return nil
|
||||
case errors.Is(err, readline.ErrInterrupt):
|
||||
if line == "" {
|
||||
fmt.Println("\nUse Ctrl-D or /bye to exit.")
|
||||
fmt.Println("Use Ctrl-D or /bye to exit.")
|
||||
}
|
||||
|
||||
continue
|
||||
@@ -574,19 +629,23 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
switch {
|
||||
case scanner.Prompt.UseAlt:
|
||||
case isMultiLine:
|
||||
if strings.HasSuffix(line, `"""`) {
|
||||
scanner.Prompt.UseAlt = false
|
||||
isMultiLine = false
|
||||
painter.IsMultiLine = isMultiLine
|
||||
multiLineBuffer += strings.TrimSuffix(line, `"""`)
|
||||
line = multiLineBuffer
|
||||
multiLineBuffer = ""
|
||||
scanner.SetPrompt(">>> ")
|
||||
} else {
|
||||
multiLineBuffer += line + " "
|
||||
continue
|
||||
}
|
||||
case strings.HasPrefix(line, `"""`):
|
||||
scanner.Prompt.UseAlt = true
|
||||
isMultiLine = true
|
||||
painter.IsMultiLine = isMultiLine
|
||||
multiLineBuffer = strings.TrimPrefix(line, `"""`) + " "
|
||||
scanner.SetPrompt("... ")
|
||||
continue
|
||||
case strings.HasPrefix(line, "/list"):
|
||||
args := strings.Fields(line)
|
||||
@@ -613,16 +672,19 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
|
||||
case "quiet":
|
||||
cmd.Flags().Set("verbose", "false")
|
||||
fmt.Println("Set 'quiet' mode.")
|
||||
case "format":
|
||||
if len(args) < 3 || args[2] != "json" {
|
||||
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
|
||||
case "mode":
|
||||
if len(args) > 2 {
|
||||
switch args[2] {
|
||||
case "vim":
|
||||
scanner.SetVimMode(true)
|
||||
case "emacs", "default":
|
||||
scanner.SetVimMode(false)
|
||||
default:
|
||||
usage()
|
||||
}
|
||||
} else {
|
||||
format = args[2]
|
||||
fmt.Printf("Set format to '%s' mode.\n", args[2])
|
||||
usage()
|
||||
}
|
||||
case "noformat":
|
||||
format = ""
|
||||
fmt.Println("Disabled format.")
|
||||
default:
|
||||
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
|
||||
}
|
||||
@@ -632,12 +694,7 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
|
||||
case strings.HasPrefix(line, "/show"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
return err
|
||||
}
|
||||
resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model})
|
||||
resp, err := server.GetModelInfo(model)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't get model")
|
||||
return err
|
||||
@@ -696,13 +753,26 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
|
||||
}
|
||||
|
||||
if len(line) > 0 && line[0] != '/' {
|
||||
if err := generate(cmd, model, line, wordWrap, format); err != nil {
|
||||
if err := generate(cmd, model, line, wordWrap); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func generateBatch(cmd *cobra.Command, model string) error {
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
for scanner.Scan() {
|
||||
prompt := scanner.Text()
|
||||
fmt.Printf(">>> %s\n", prompt)
|
||||
if err := generate(cmd, model, prompt, false); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunServer(cmd *cobra.Command, _ []string) error {
|
||||
host, port, err := net.SplitHostPort(os.Getenv("OLLAMA_HOST"))
|
||||
if err != nil {
|
||||
@@ -726,6 +796,21 @@ func RunServer(cmd *cobra.Command, _ []string) error {
|
||||
origins = strings.Split(o, ",")
|
||||
}
|
||||
|
||||
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
|
||||
if err := server.PruneLayers(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manifestsPath, err := server.GetManifestPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := server.PruneDirectory(manifestsPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return server.Serve(ln, origins)
|
||||
}
|
||||
|
||||
@@ -848,7 +933,7 @@ func NewCLI() *cobra.Command {
|
||||
createCmd := &cobra.Command{
|
||||
Use: "create MODEL",
|
||||
Short: "Create a model from a Modelfile",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: CreateHandler,
|
||||
}
|
||||
@@ -858,7 +943,7 @@ func NewCLI() *cobra.Command {
|
||||
showCmd := &cobra.Command{
|
||||
Use: "show MODEL",
|
||||
Short: "Show information for a model",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: ShowHandler,
|
||||
}
|
||||
@@ -880,20 +965,18 @@ func NewCLI() *cobra.Command {
|
||||
runCmd.Flags().Bool("verbose", false, "Show timings for response")
|
||||
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
|
||||
runCmd.Flags().String("format", "", "Response format (e.g. json)")
|
||||
|
||||
serveCmd := &cobra.Command{
|
||||
Use: "serve",
|
||||
Aliases: []string{"start"},
|
||||
Short: "Start ollama",
|
||||
Args: cobra.ExactArgs(0),
|
||||
RunE: RunServer,
|
||||
}
|
||||
|
||||
pullCmd := &cobra.Command{
|
||||
Use: "pull MODEL",
|
||||
Short: "Pull a model from a registry",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: PullHandler,
|
||||
}
|
||||
@@ -903,7 +986,7 @@ func NewCLI() *cobra.Command {
|
||||
pushCmd := &cobra.Command{
|
||||
Use: "push MODEL",
|
||||
Short: "Push a model to a registry",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: PushHandler,
|
||||
}
|
||||
@@ -919,15 +1002,15 @@ func NewCLI() *cobra.Command {
|
||||
}
|
||||
|
||||
copyCmd := &cobra.Command{
|
||||
Use: "cp SOURCE TARGET",
|
||||
Use: "cp",
|
||||
Short: "Copy a model",
|
||||
Args: cobra.ExactArgs(2),
|
||||
Args: cobra.MinimumNArgs(2),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: CopyHandler,
|
||||
}
|
||||
|
||||
deleteCmd := &cobra.Command{
|
||||
Use: "rm MODEL [MODEL...]",
|
||||
Use: "rm",
|
||||
Short: "Remove a model",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
|
314
docs/api.md
314
docs/api.md
@@ -41,36 +41,28 @@ Generate a response for a given prompt with a provided model. This is a streamin
|
||||
|
||||
Advanced parameters (optional):
|
||||
|
||||
- `format`: the format to return a response in. Currently the only accepted value is `json`
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
- `system`: system prompt to (overrides what is defined in the `Modelfile`)
|
||||
- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
|
||||
- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
|
||||
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||
- `raw`: if `true` no formatting will be applied to the prompt and no context will be returned. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API, and are managing history yourself.
|
||||
- `stream`: if `false` the response will be be returned as a single response object, rather than a stream of objects
|
||||
|
||||
### JSON mode
|
||||
|
||||
Enable JSON mode by setting the `format` parameter to `json` and specifying the model should use JSON in the `prompt`. This will structure the response as valid JSON. See the JSON mode [example](#request-json-mode) below.
|
||||
|
||||
### Examples
|
||||
|
||||
#### Request
|
||||
### Request
|
||||
|
||||
```shell
|
||||
curl -X POST http://localhost:11434/api/generate -d '{
|
||||
"model": "llama2",
|
||||
"model": "llama2:7b",
|
||||
"prompt": "Why is the sky blue?"
|
||||
}'
|
||||
```
|
||||
|
||||
#### Response
|
||||
### Response
|
||||
|
||||
A stream of JSON objects is returned:
|
||||
A stream of JSON objects:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama2",
|
||||
"model": "llama2:7b",
|
||||
"created_at": "2023-08-04T08:52:19.385406455-07:00",
|
||||
"response": "The",
|
||||
"done": false
|
||||
@@ -94,7 +86,7 @@ To calculate how fast the response is generated in tokens per second (token/s),
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama2",
|
||||
"model": "llama2:7b",
|
||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||
"response": "",
|
||||
"context": [1, 2, 3],
|
||||
@@ -110,182 +102,6 @@ To calculate how fast the response is generated in tokens per second (token/s),
|
||||
}
|
||||
```
|
||||
|
||||
#### Request (No streaming)
|
||||
|
||||
```shell
|
||||
curl -X POST http://localhost:11434/api/generate -d '{
|
||||
"model": "llama2:7b",
|
||||
"prompt": "Why is the sky blue?",
|
||||
"stream": false
|
||||
}'
|
||||
```
|
||||
|
||||
#### Response
|
||||
|
||||
If `stream` is set to `false`, the response will be a single JSON object:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama2:7b",
|
||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||
"response": "The sky is blue because it is the color of the sky.",
|
||||
"context": [1, 2, 3],
|
||||
"done": true,
|
||||
"total_duration": 5589157167,
|
||||
"load_duration": 3013701500,
|
||||
"sample_count": 114,
|
||||
"sample_duration": 81442000,
|
||||
"prompt_eval_count": 46,
|
||||
"prompt_eval_duration": 1160282000,
|
||||
"eval_count": 13,
|
||||
"eval_duration": 1325948000
|
||||
}
|
||||
```
|
||||
|
||||
#### Request (Raw mode)
|
||||
|
||||
In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting and context.
|
||||
|
||||
```shell
|
||||
curl -X POST http://localhost:11434/api/generate -d '{
|
||||
"model": "mistral",
|
||||
"prompt": "[INST] why is the sky blue? [/INST]",
|
||||
"raw": true,
|
||||
"stream": false
|
||||
}'
|
||||
```
|
||||
|
||||
#### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "mistral",
|
||||
"created_at": "2023-11-03T15:36:02.583064Z",
|
||||
"response": " The sky appears blue because of a phenomenon called Rayleigh scattering.",
|
||||
"done": true,
|
||||
"total_duration": 14648695333,
|
||||
"load_duration": 3302671417,
|
||||
"prompt_eval_count": 14,
|
||||
"prompt_eval_duration": 286243000,
|
||||
"eval_count": 129,
|
||||
"eval_duration": 10931424000
|
||||
}
|
||||
```
|
||||
|
||||
#### Request (JSON mode)
|
||||
|
||||
```shell
|
||||
curl -X POST http://localhost:11434/api/generate -d '{
|
||||
"model": "llama2",
|
||||
"prompt": "What color is the sky at different times of the day? Respond using JSON",
|
||||
"format": "json",
|
||||
"stream": false
|
||||
}'
|
||||
```
|
||||
|
||||
#### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama2",
|
||||
"created_at": "2023-11-09T21:07:55.186497Z",
|
||||
"response": "{\n\"morning\": {\n\"color\": \"blue\"\n},\n\"noon\": {\n\"color\": \"blue-gray\"\n},\n\"afternoon\": {\n\"color\": \"warm gray\"\n},\n\"evening\": {\n\"color\": \"orange\"\n}\n}\n",
|
||||
"done": true,
|
||||
"total_duration": 4661289125,
|
||||
"load_duration": 1714434500,
|
||||
"prompt_eval_count": 36,
|
||||
"prompt_eval_duration": 264132000,
|
||||
"eval_count": 75,
|
||||
"eval_duration": 2112149000
|
||||
}
|
||||
```
|
||||
|
||||
The value of `response` will be a string containing JSON similar to:
|
||||
|
||||
```json
|
||||
{
|
||||
"morning": {
|
||||
"color": "blue"
|
||||
},
|
||||
"noon": {
|
||||
"color": "blue-gray"
|
||||
},
|
||||
"afternoon": {
|
||||
"color": "warm gray"
|
||||
},
|
||||
"evening": {
|
||||
"color": "orange"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Request (With options)
|
||||
|
||||
If you want to set custom options for the model at runtime rather than in the Modelfile, you can do so with the `options` parameter. This example sets every available option, but you can set any of them individually and omit the ones you do not want to override.
|
||||
|
||||
```shell
|
||||
curl -X POST http://localhost:11434/api/generate -d '{
|
||||
"model": "llama2:7b",
|
||||
"prompt": "Why is the sky blue?",
|
||||
"stream": false,
|
||||
"options": {
|
||||
"num_keep": 5,
|
||||
"seed": 42,
|
||||
"num_predict": 100,
|
||||
"top_k": 20,
|
||||
"top_p": 0.9,
|
||||
"tfs_z": 0.5,
|
||||
"typical_p": 0.7,
|
||||
"repeat_last_n": 33,
|
||||
"temperature": 0.8,
|
||||
"repeat_penalty": 1.2,
|
||||
"presence_penalty": 1.5,
|
||||
"frequency_penalty": 1.0,
|
||||
"mirostat": 1,
|
||||
"mirostat_tau": 0.8,
|
||||
"mirostat_eta": 0.6,
|
||||
"penalize_newline": true,
|
||||
"stop": ["\n", "user:"],
|
||||
"numa": false,
|
||||
"num_ctx": 4,
|
||||
"num_batch": 2,
|
||||
"num_gqa": 1,
|
||||
"num_gpu": 1,
|
||||
"main_gpu": 0,
|
||||
"low_vram": false,
|
||||
"f16_kv": true,
|
||||
"logits_all": false,
|
||||
"vocab_only": false,
|
||||
"use_mmap": true,
|
||||
"use_mlock": false,
|
||||
"embedding_only": false,
|
||||
"rope_frequency_base": 1.1,
|
||||
"rope_frequency_scale": 0.8,
|
||||
"num_thread": 8
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
#### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama2:7b",
|
||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||
"response": "The sky is blue because it is the color of the sky.",
|
||||
"context": [1, 2, 3],
|
||||
"done": true,
|
||||
"total_duration": 5589157167,
|
||||
"load_duration": 3013701500,
|
||||
"sample_count": 114,
|
||||
"sample_duration": 81442000,
|
||||
"prompt_eval_count": 46,
|
||||
"prompt_eval_duration": 1160282000,
|
||||
"eval_count": 13,
|
||||
"eval_duration": 1325948000
|
||||
}
|
||||
```
|
||||
|
||||
## Create a Model
|
||||
|
||||
```shell
|
||||
@@ -298,11 +114,9 @@ Create a model from a [`Modelfile`](./modelfile.md)
|
||||
|
||||
- `name`: name of the model to create
|
||||
- `path`: path to the Modelfile
|
||||
- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||
- `stream`: (optional) if `false` the response will be be returned as a single response object, rather than a stream of objects
|
||||
|
||||
### Examples
|
||||
|
||||
#### Request
|
||||
### Request
|
||||
|
||||
```shell
|
||||
curl -X POST http://localhost:11434/api/create -d '{
|
||||
@@ -311,7 +125,7 @@ curl -X POST http://localhost:11434/api/create -d '{
|
||||
}'
|
||||
```
|
||||
|
||||
#### Response
|
||||
### Response
|
||||
|
||||
A stream of JSON objects. When finished, `status` is `success`.
|
||||
|
||||
@@ -329,17 +143,13 @@ GET /api/tags
|
||||
|
||||
List models that are available locally.
|
||||
|
||||
### Examples
|
||||
|
||||
#### Request
|
||||
### Request
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/tags
|
||||
```
|
||||
|
||||
#### Response
|
||||
|
||||
A single JSON object will be returned.
|
||||
### Response
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -370,9 +180,7 @@ Show details about a model including modelfile, template, parameters, license, a
|
||||
|
||||
- `name`: name of the model to show
|
||||
|
||||
### Examples
|
||||
|
||||
#### Request
|
||||
### Request
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/show -d '{
|
||||
@@ -380,14 +188,14 @@ curl http://localhost:11434/api/show -d '{
|
||||
}'
|
||||
```
|
||||
|
||||
#### Response
|
||||
### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"license": "<contents of license block>",
|
||||
"modelfile": "# Modelfile generated by \"ollama show\"\n# To build a new Modelfile based on this one, replace the FROM line with:\n# FROM llama2:latest\n\nFROM /Users/username/.ollama/models/blobs/sha256:8daa9615cce30c259a9555b1cc250d461d1bc69980a274b44d7eda0be78076d8\nTEMPLATE \"\"\"[INST] <<SYS>>{{ .System }}<</SYS>>\n\n{{ .Prompt }} [/INST] \"\"\"\nSYSTEM \"\"\"\"\"\"\nPARAMETER stop [INST]\nPARAMETER stop [/INST]\nPARAMETER stop <<SYS>>\nPARAMETER stop <</SYS>>\n",
|
||||
"modelfile": "# Modelfile generated by \"ollama show\"\n# To build a new Modelfile based on this one, replace the FROM line with:\n# FROM llama2:latest\n\nFROM /Users/username/.ollama/models/blobs/sha256:8daa9615cce30c259a9555b1cc250d461d1bc69980a274b44d7eda0be78076d8\nTEMPLATE \"\"\"[INST] {{ if and .First .System }}<<SYS>>{{ .System }}<</SYS>>\n\n{{ end }}{{ .Prompt }} [/INST] \"\"\"\nSYSTEM \"\"\"\"\"\"\nPARAMETER stop [INST]\nPARAMETER stop [/INST]\nPARAMETER stop <<SYS>>\nPARAMETER stop <</SYS>>\n",
|
||||
"parameters": "stop [INST]\nstop [/INST]\nstop <<SYS>>\nstop <</SYS>>",
|
||||
"template": "[INST] <<SYS>>{{ .System }}<</SYS>>\n\n{{ .Prompt }} [/INST] "
|
||||
"template": "[INST] {{ if and .First .System }}<<SYS>>{{ .System }}<</SYS>>\n\n{{ end }}{{ .Prompt }} [/INST] "
|
||||
}
|
||||
```
|
||||
|
||||
@@ -399,9 +207,7 @@ POST /api/copy
|
||||
|
||||
Copy a model. Creates a model with another name from an existing model.
|
||||
|
||||
### Examples
|
||||
|
||||
#### Request
|
||||
### Request
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/copy -d '{
|
||||
@@ -410,10 +216,6 @@ curl http://localhost:11434/api/copy -d '{
|
||||
}'
|
||||
```
|
||||
|
||||
#### Response
|
||||
|
||||
The only response is a 200 OK if successful.
|
||||
|
||||
## Delete a Model
|
||||
|
||||
```shell
|
||||
@@ -424,11 +226,9 @@ Delete a model and its data.
|
||||
|
||||
### Parameters
|
||||
|
||||
- `name`: model name to delete
|
||||
- `model`: model name to delete
|
||||
|
||||
### Examples
|
||||
|
||||
#### Request
|
||||
### Request
|
||||
|
||||
```shell
|
||||
curl -X DELETE http://localhost:11434/api/delete -d '{
|
||||
@@ -436,10 +236,6 @@ curl -X DELETE http://localhost:11434/api/delete -d '{
|
||||
}'
|
||||
```
|
||||
|
||||
#### Response
|
||||
|
||||
If successful, the only response is a 200 OK.
|
||||
|
||||
## Pull a Model
|
||||
|
||||
```shell
|
||||
@@ -452,11 +248,9 @@ Download a model from the ollama library. Cancelled pulls are resumed from where
|
||||
|
||||
- `name`: name of the model to pull
|
||||
- `insecure`: (optional) allow insecure connections to the library. Only use this if you are pulling from your own library during development.
|
||||
- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||
- `stream`: (optional) if `false` the response will be be returned as a single response object, rather than a stream of objects
|
||||
|
||||
### Examples
|
||||
|
||||
#### Request
|
||||
### Request
|
||||
|
||||
```shell
|
||||
curl -X POST http://localhost:11434/api/pull -d '{
|
||||
@@ -464,51 +258,13 @@ curl -X POST http://localhost:11434/api/pull -d '{
|
||||
}'
|
||||
```
|
||||
|
||||
#### Response
|
||||
|
||||
If `stream` is not specified, or set to `true`, a stream of JSON objects is returned:
|
||||
|
||||
The first object is the manifest:
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "pulling manifest"
|
||||
}
|
||||
```
|
||||
|
||||
Then there is a series of downloading responses. Until any of the download is completed, the `completed` key may not be included. The number of files to be downloaded depends on the number of layers specified in the manifest.
|
||||
### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "downloading digestname",
|
||||
"digest": "digestname",
|
||||
"total": 2142590208,
|
||||
"completed": 241970
|
||||
}
|
||||
```
|
||||
|
||||
After all the files are downloaded, the final responses are:
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "verifying sha256 digest"
|
||||
}
|
||||
{
|
||||
"status": "writing manifest"
|
||||
}
|
||||
{
|
||||
"status": "removing any unused layers"
|
||||
}
|
||||
{
|
||||
"status": "success"
|
||||
}
|
||||
```
|
||||
|
||||
if `stream` is set to false, then the response is a single JSON object:
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "success"
|
||||
"total": 2142590208
|
||||
}
|
||||
```
|
||||
|
||||
@@ -524,11 +280,9 @@ Upload a model to a model library. Requires registering for ollama.ai and adding
|
||||
|
||||
- `name`: name of the model to push in the form of `<namespace>/<model>:<tag>`
|
||||
- `insecure`: (optional) allow insecure connections to the library. Only use this if you are pushing to your library during development.
|
||||
- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||
- `stream`: (optional) if `false` the response will be be returned as a single response object, rather than a stream of objects
|
||||
|
||||
### Examples
|
||||
|
||||
#### Request
|
||||
### Request
|
||||
|
||||
```shell
|
||||
curl -X POST http://localhost:11434/api/push -d '{
|
||||
@@ -536,9 +290,9 @@ curl -X POST http://localhost:11434/api/push -d '{
|
||||
}'
|
||||
```
|
||||
|
||||
#### Response
|
||||
### Response
|
||||
|
||||
If `stream` is not specified, or set to `true`, a stream of JSON objects is returned:
|
||||
Streaming response that starts with:
|
||||
|
||||
```json
|
||||
{ "status": "retrieving manifest" }
|
||||
@@ -571,12 +325,6 @@ Finally, when the upload is complete:
|
||||
{"status":"success"}
|
||||
```
|
||||
|
||||
If `stream` is set to `false`, then the response is a single JSON object:
|
||||
|
||||
```json
|
||||
{ "status": "success" }
|
||||
```
|
||||
|
||||
## Generate Embeddings
|
||||
|
||||
```shell
|
||||
@@ -594,9 +342,7 @@ Advanced parameters:
|
||||
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
|
||||
### Examples
|
||||
|
||||
#### Request
|
||||
### Request
|
||||
|
||||
```shell
|
||||
curl -X POST http://localhost:11434/api/embeddings -d '{
|
||||
@@ -605,11 +351,11 @@ curl -X POST http://localhost:11434/api/embeddings -d '{
|
||||
}'
|
||||
```
|
||||
|
||||
#### Response
|
||||
### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"embedding": [
|
||||
"embeddings": [
|
||||
0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
|
||||
0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
|
||||
]
|
||||
|
88
docs/faq.md
88
docs/faq.md
@@ -1,98 +1,18 @@
|
||||
# FAQ
|
||||
|
||||
## How can I view the logs?
|
||||
|
||||
On macOS:
|
||||
|
||||
```
|
||||
cat ~/.ollama/logs/server.log
|
||||
```
|
||||
|
||||
On Linux:
|
||||
|
||||
```
|
||||
journalctl -u ollama
|
||||
```
|
||||
|
||||
If you're running `ollama serve` directly, the logs will be printed to the console.
|
||||
|
||||
## How can I expose Ollama on my network?
|
||||
|
||||
Ollama binds to 127.0.0.1 port 11434 by default. Change the bind address with the `OLLAMA_HOST` environment variable.
|
||||
|
||||
On macOS:
|
||||
## How can I expose the Ollama server?
|
||||
|
||||
```bash
|
||||
OLLAMA_HOST=0.0.0.0:11435 ollama serve
|
||||
```
|
||||
|
||||
On Linux:
|
||||
|
||||
Create a `systemd` drop-in directory and set `Environment=OLLAMA_HOST`
|
||||
|
||||
```bash
|
||||
mkdir -p /etc/systemd/system/ollama.service.d
|
||||
echo "[Service]" >>/etc/systemd/system/ollama.service.d/environment.conf
|
||||
```
|
||||
|
||||
```bash
|
||||
echo "Environment=OLLAMA_HOST=0.0.0.0:11434" >>/etc/systemd/system/ollama.service.d/environment.conf
|
||||
```
|
||||
|
||||
Reload `systemd` and restart Ollama:
|
||||
|
||||
```bash
|
||||
systemctl daemon-reload
|
||||
systemctl restart ollama
|
||||
```
|
||||
|
||||
## How can I allow additional web origins to access Ollama?
|
||||
|
||||
Ollama allows cross origin requests from `127.0.0.1` and `0.0.0.0` by default. Add additional origins with the `OLLAMA_ORIGINS` environment variable:
|
||||
|
||||
On macOS:
|
||||
By default, Ollama allows cross origin requests from `127.0.0.1` and `0.0.0.0`. To support more origins, you can use the `OLLAMA_ORIGINS` environment variable:
|
||||
|
||||
```bash
|
||||
OLLAMA_ORIGINS=http://192.168.1.1:*,https://example.com ollama serve
|
||||
```
|
||||
|
||||
On Linux:
|
||||
|
||||
```bash
|
||||
echo "Environment=OLLAMA_ORIGINS=http://129.168.1.1:*,https://example.com" >>/etc/systemd/system/ollama.service.d/environment.conf
|
||||
```
|
||||
|
||||
Reload `systemd` and restart Ollama:
|
||||
|
||||
```bash
|
||||
systemctl daemon-reload
|
||||
systemctl restart ollama
|
||||
```
|
||||
|
||||
## Where are models stored?
|
||||
|
||||
- macOS: Raw model data is stored under `~/.ollama/models`.
|
||||
- Linux: Raw model data is stored under `/usr/share/ollama/.ollama/models`
|
||||
|
||||
|
||||
|
||||
Below the models directory you will find a structure similar to the following:
|
||||
|
||||
```shell
|
||||
.
|
||||
├── blobs
|
||||
└── manifests
|
||||
└── registry.ollama.ai
|
||||
├── f0rodo
|
||||
├── library
|
||||
├── mattw
|
||||
└── saikatkumardey
|
||||
```
|
||||
|
||||
There is a `manifests/registry.ollama.ai/namespace` path. In example above, the user has downloaded models from the official `library`, `f0rodo`, `mattw`, and `saikatkumardey` namespaces. Within each of those directories, you will find directories for each of the models downloaded. And in there you will find a file name representing each tag. Each tag file is the manifest for the model.
|
||||
|
||||
The manifest lists all the layers used in this model. You will see a `media type` for each layer, along with a digest. That digest corresponds with a file in the `models/blobs directory`.
|
||||
|
||||
### How can I change where Ollama stores models?
|
||||
|
||||
To modify where models are stored, you can use the `OLLAMA_MODELS` environment variable. Note that on Linux this means defining `OLLAMA_MODELS` in a drop-in `/etc/systemd/system/ollama.service.d` service file, reloading systemd, and restarting the ollama service.
|
||||
* macOS: Raw model data is stored under `~/.ollama/models`.
|
||||
* Linux: Raw model data is stored under `/usr/share/ollama/.ollama/models`
|
||||
|
198
docs/import.md
198
docs/import.md
@@ -1,198 +0,0 @@
|
||||
# Import a model
|
||||
|
||||
This guide walks through importing a GGUF, PyTorch or Safetensors model.
|
||||
|
||||
## Importing (GGUF)
|
||||
|
||||
### Step 1: Write a `Modelfile`
|
||||
|
||||
Start by creating a `Modelfile`. This file is the blueprint for your model, specifying weights, parameters, prompt templates and more.
|
||||
|
||||
```
|
||||
FROM ./mistral-7b-v0.1.Q4_0.gguf
|
||||
```
|
||||
|
||||
(Optional) many chat models require a prompt template in order to answer correctly. A default prompt template can be specified with the `TEMPLATE` instruction in the `Modelfile`:
|
||||
|
||||
```
|
||||
FROM ./q4_0.bin
|
||||
TEMPLATE "[INST] {{ .Prompt }} [/INST]"
|
||||
```
|
||||
|
||||
### Step 2: Create the Ollama model
|
||||
|
||||
Finally, create a model from your `Modelfile`:
|
||||
|
||||
```
|
||||
ollama create example -f Modelfile
|
||||
```
|
||||
|
||||
### Step 3: Run your model
|
||||
|
||||
Next, test the model with `ollama run`:
|
||||
|
||||
```
|
||||
ollama run example "What is your favourite condiment?"
|
||||
```
|
||||
|
||||
## Importing (PyTorch & Safetensors)
|
||||
|
||||
### Supported models
|
||||
|
||||
Ollama supports a set of model architectures, with support for more coming soon:
|
||||
|
||||
- Llama & Mistral
|
||||
- Falcon & RW
|
||||
- GPT-NeoX
|
||||
- BigCode
|
||||
|
||||
To view a model's architecture, check the `config.json` file in its HuggingFace repo. You should see an entry under `architectures` (e.g. `LlamaForCausalLM`).
|
||||
|
||||
### Step 1: Clone the HuggingFace repository (optional)
|
||||
|
||||
If the model is currently hosted in a HuggingFace repository, first clone that repository to download the raw model.
|
||||
|
||||
```
|
||||
git lfs install
|
||||
git clone https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
|
||||
cd Mistral-7B-Instruct-v0.1
|
||||
```
|
||||
|
||||
### Step 2: Convert and quantize to a `.bin` file (optional, for PyTorch and Safetensors)
|
||||
|
||||
If the model is in PyTorch or Safetensors format, a [Docker image](https://hub.docker.com/r/ollama/quantize) with the tooling required to convert and quantize models is available.
|
||||
|
||||
First, Install [Docker](https://www.docker.com/get-started/).
|
||||
|
||||
Next, to convert and quantize your model, run:
|
||||
|
||||
```
|
||||
docker run --rm -v .:/model ollama/quantize -q q4_0 /model
|
||||
```
|
||||
|
||||
This will output two files into the directory:
|
||||
|
||||
- `f16.bin`: the model converted to GGUF
|
||||
- `q4_0.bin` the model quantized to a 4-bit quantization (we will use this file to create the Ollama model)
|
||||
|
||||
### Step 3: Write a `Modelfile`
|
||||
|
||||
Next, create a `Modelfile` for your model:
|
||||
|
||||
```
|
||||
FROM ./q4_0.bin
|
||||
```
|
||||
|
||||
(Optional) many chat models require a prompt template in order to answer correctly. A default prompt template can be specified with the `TEMPLATE` instruction in the `Modelfile`:
|
||||
|
||||
```
|
||||
FROM ./q4_0.bin
|
||||
TEMPLATE "[INST] {{ .Prompt }} [/INST]"
|
||||
```
|
||||
|
||||
### Step 4: Create the Ollama model
|
||||
|
||||
Finally, create a model from your `Modelfile`:
|
||||
|
||||
```
|
||||
ollama create example -f Modelfile
|
||||
```
|
||||
|
||||
### Step 5: Run your model
|
||||
|
||||
Next, test the model with `ollama run`:
|
||||
|
||||
```
|
||||
ollama run example "What is your favourite condiment?"
|
||||
```
|
||||
|
||||
## Publishing your model (optional – early alpha)
|
||||
|
||||
Publishing models is in early alpha. If you'd like to publish your model to share with others, follow these steps:
|
||||
|
||||
1. Create [an account](https://ollama.ai/signup)
|
||||
2. Run `cat ~/.ollama/id_ed25519.pub` to view your Ollama public key. Copy this to the clipboard.
|
||||
3. Add your public key to your [Ollama account](https://ollama.ai/settings/keys)
|
||||
|
||||
Next, copy your model to your username's namespace:
|
||||
|
||||
```
|
||||
ollama cp example <your username>/example
|
||||
```
|
||||
|
||||
Then push the model:
|
||||
|
||||
```
|
||||
ollama push <your username>/example
|
||||
```
|
||||
|
||||
After publishing, your model will be available at `https://ollama.ai/<your username>/example`.
|
||||
|
||||
## Quantization reference
|
||||
|
||||
The quantization options are as follow (from highest highest to lowest levels of quantization). Note: some architectures such as Falcon do not support K quants.
|
||||
|
||||
- `q2_K`
|
||||
- `q3_K`
|
||||
- `q3_K_S`
|
||||
- `q3_K_M`
|
||||
- `q3_K_L`
|
||||
- `q4_0` (recommended)
|
||||
- `q4_1`
|
||||
- `q4_K`
|
||||
- `q4_K_S`
|
||||
- `q4_K_M`
|
||||
- `q5_0`
|
||||
- `q5_1`
|
||||
- `q5_K`
|
||||
- `q5_K_S`
|
||||
- `q5_K_M`
|
||||
- `q6_K`
|
||||
- `q8_0`
|
||||
|
||||
## Manually converting & quantizing models
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Start by cloning the `llama.cpp` repo to your machine in another directory:
|
||||
|
||||
```
|
||||
git clone https://github.com/ggerganov/llama.cpp.git
|
||||
cd llama.cpp
|
||||
```
|
||||
|
||||
Next, install the Python dependencies:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Finally, build the `quantize` tool:
|
||||
|
||||
```
|
||||
make quantize
|
||||
```
|
||||
|
||||
### Convert the model
|
||||
|
||||
Run the correct conversion script for your model architecture:
|
||||
|
||||
```shell
|
||||
# LlamaForCausalLM or MistralForCausalLM
|
||||
python convert.py <path to model directory>
|
||||
|
||||
# FalconForCausalLM
|
||||
python convert-falcon-hf-to-gguf.py <path to model directory>
|
||||
|
||||
# GPTNeoXForCausalLM
|
||||
python convert-gptneox-hf-to-gguf.py <path to model directory>
|
||||
|
||||
# GPTBigCodeForCausalLM
|
||||
python convert-starcoder-hf-to-gguf.py <path to model directory>
|
||||
```
|
||||
|
||||
### Quantize the model
|
||||
|
||||
```
|
||||
quantize <path to model dir>/ggml-model-f32.bin <path to model dir>/q4_0.bin q4_0
|
||||
```
|
@@ -1,16 +1,12 @@
|
||||
# Ollama on Linux
|
||||
# Installing Ollama on Linux
|
||||
|
||||
## Install
|
||||
|
||||
Install Ollama running this one-liner:
|
||||
> Note: A one line installer for Ollama is available by running:
|
||||
>
|
||||
```bash
|
||||
curl https://ollama.ai/install.sh | sh
|
||||
```
|
||||
> ```bash
|
||||
> curl https://ollama.ai/install.sh | sh
|
||||
> ```
|
||||
|
||||
## Manual install
|
||||
|
||||
### Download the `ollama` binary
|
||||
## Download the `ollama` binary
|
||||
|
||||
Ollama is distributed as a self-contained binary. Download it to a directory in your PATH:
|
||||
|
||||
@@ -19,7 +15,31 @@ sudo curl -L https://ollama.ai/download/ollama-linux-amd64 -o /usr/bin/ollama
|
||||
sudo chmod +x /usr/bin/ollama
|
||||
```
|
||||
|
||||
### Adding Ollama as a startup service (recommended)
|
||||
## Start Ollama
|
||||
|
||||
Start Ollama by running `ollama serve`:
|
||||
|
||||
```bash
|
||||
ollama serve
|
||||
```
|
||||
|
||||
Once Ollama is running, run a model in another terminal session:
|
||||
|
||||
```bash
|
||||
ollama run llama2
|
||||
```
|
||||
|
||||
## Install CUDA drivers (optional – for Nvidia GPUs)
|
||||
|
||||
[Download and install](https://developer.nvidia.com/cuda-downloads) CUDA.
|
||||
|
||||
Verify that the drivers are installed by running the following command, which should print details about your GPU:
|
||||
|
||||
```bash
|
||||
nvidia-smi
|
||||
```
|
||||
|
||||
## Adding Ollama as a startup service (optional)
|
||||
|
||||
Create a user for Ollama:
|
||||
|
||||
@@ -40,6 +60,7 @@ User=ollama
|
||||
Group=ollama
|
||||
Restart=always
|
||||
RestartSec=3
|
||||
Environment="HOME=/usr/share/ollama"
|
||||
|
||||
[Install]
|
||||
WantedBy=default.target
|
||||
@@ -52,40 +73,7 @@ sudo systemctl daemon-reload
|
||||
sudo systemctl enable ollama
|
||||
```
|
||||
|
||||
### Install CUDA drivers (optional – for Nvidia GPUs)
|
||||
|
||||
[Download and install](https://developer.nvidia.com/cuda-downloads) CUDA.
|
||||
|
||||
Verify that the drivers are installed by running the following command, which should print details about your GPU:
|
||||
|
||||
```bash
|
||||
nvidia-smi
|
||||
```
|
||||
|
||||
### Start Ollama
|
||||
|
||||
Start Ollama using `systemd`:
|
||||
|
||||
```bash
|
||||
sudo systemctl start ollama
|
||||
```
|
||||
|
||||
## Update
|
||||
|
||||
Update ollama by running the install script again:
|
||||
|
||||
```bash
|
||||
curl https://ollama.ai/install.sh | sh
|
||||
```
|
||||
|
||||
Or by downloading the ollama binary:
|
||||
|
||||
```bash
|
||||
sudo curl -L https://ollama.ai/download/ollama-linux-amd64 -o /usr/bin/ollama
|
||||
sudo chmod +x /usr/bin/ollama
|
||||
```
|
||||
|
||||
## Viewing logs
|
||||
### Viewing logs
|
||||
|
||||
To view logs of Ollama running as a startup service, run:
|
||||
|
||||
@@ -93,24 +81,3 @@ To view logs of Ollama running as a startup service, run:
|
||||
journalctl -u ollama
|
||||
```
|
||||
|
||||
## Uninstall
|
||||
|
||||
Remove the ollama service:
|
||||
|
||||
```bash
|
||||
sudo systemctl stop ollama
|
||||
sudo systemctl disable ollama
|
||||
sudo rm /etc/systemd/system/ollama.service
|
||||
```
|
||||
|
||||
Remove the ollama binary from your bin directory (either `/usr/local/bin`, `/usr/bin`, or `/bin`):
|
||||
|
||||
```bash
|
||||
sudo rm $(which ollama)
|
||||
```
|
||||
|
||||
Remove the downloaded models and Ollama service user:
|
||||
```bash
|
||||
sudo rm -r /usr/share/ollama
|
||||
sudo userdel ollama
|
||||
```
|
||||
|
@@ -12,6 +12,7 @@ A model file is the blueprint to create and share models with Ollama.
|
||||
- [FROM (Required)](#from-required)
|
||||
- [Build from llama2](#build-from-llama2)
|
||||
- [Build from a bin file](#build-from-a-bin-file)
|
||||
- [EMBED](#embed)
|
||||
- [PARAMETER](#parameter)
|
||||
- [Valid Parameters and Values](#valid-parameters-and-values)
|
||||
- [TEMPLATE](#template)
|
||||
@@ -90,6 +91,17 @@ FROM ./ollama-model.bin
|
||||
|
||||
This bin file location should be specified as an absolute path or relative to the `Modelfile` location.
|
||||
|
||||
### EMBED
|
||||
|
||||
The `EMBED` instruction is used to add embeddings of files to a model. This is useful for adding custom data that the model can reference when generating an answer. Note that currently only text files are supported, formatted with each line as one embedding.
|
||||
|
||||
```modelfile
|
||||
FROM <model name>:<tag>
|
||||
EMBED <file path>.txt
|
||||
EMBED <different file path>.txt
|
||||
EMBED <path to directory>/*.txt
|
||||
```
|
||||
|
||||
### PARAMETER
|
||||
|
||||
The `PARAMETER` instruction defines a parameter that can be set when the model is run.
|
||||
@@ -112,8 +124,8 @@ PARAMETER <parameter> <parametervalue>
|
||||
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
|
||||
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
|
||||
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
|
||||
| stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |
|
||||
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. | int | seed 42 |
|
||||
| stop | Sets the stop sequences to use. | string | stop "AI assistant:" |
|
||||
| tfs_z | Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1) | float | tfs_z 1 |
|
||||
| num_predict | Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context) | int | num_predict 42 |
|
||||
| top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |
|
||||
@@ -129,11 +141,14 @@ PARAMETER <parameter> <parametervalue>
|
||||
| --------------- | ------------------------------------------------------------------------------------------------------------ |
|
||||
| `{{ .System }}` | The system prompt used to specify custom behavior, this must also be set in the Modelfile as an instruction. |
|
||||
| `{{ .Prompt }}` | The incoming prompt, this is not specified in the model file and will be set based on input. |
|
||||
| `{{ .First }}` | A boolean value used to render specific template information for the first generation of a session. |
|
||||
|
||||
```modelfile
|
||||
TEMPLATE """
|
||||
{{- if .First }}
|
||||
### System:
|
||||
{{ .System }}
|
||||
{{- end }}
|
||||
|
||||
### User:
|
||||
{{ .Prompt }}
|
||||
|
111
docs/quantize.md
Normal file
111
docs/quantize.md
Normal file
@@ -0,0 +1,111 @@
|
||||
# How to Quantize a Model
|
||||
|
||||
Sometimes the model you want to work with is not available at [https://ollama.ai/library](https://ollama.ai/library).
|
||||
|
||||
## Figure out if we can run the model?
|
||||
|
||||
Not all models will work with Ollama. There are a number of factors that go into whether we are able to work with the next cool model. First it has to work with llama.cpp. Then we have to have implemented the features of llama.cpp that it requires. And then, sometimes, even with both of those, the model might not work...
|
||||
|
||||
1. What is the model you want to convert and upload?
|
||||
2. Visit the model's page on HuggingFace.
|
||||
3. Switch to the **Files and versions** tab.
|
||||
4. Click on the **config.json** file. If there is no config.json file, it may not work.
|
||||
5. Take note of the **architecture** list in the json file.
|
||||
6. Does any entry in the list match one of the following architectures?
|
||||
1. LlamaForCausalLM
|
||||
2. MistralForCausalLM
|
||||
3. RWForCausalLM
|
||||
4. FalconForCausalLM
|
||||
5. GPTNeoXForCausalLM
|
||||
6. GPTBigCodeForCausalLM
|
||||
7. If the answer is yes, then there is a good chance the model will run after being converted and quantized.
|
||||
8. An alternative to this process is to visit [https://caniquant.tvl.st](https://caniquant.tvl.st) and enter the org/modelname in the box and submit.
|
||||
|
||||
At this point there are two processes you can use. You can either use a Docker container to convert and quantize, OR you can manually run the scripts. The Docker container is the easiest way to do it, but it requires you to have Docker installed on your machine. If you don't have Docker installed, you can follow the manual process.
|
||||
|
||||
## Convert and Quantize with Docker
|
||||
|
||||
Run `docker run --rm -v /path/to/model/repo:/repo ollama/quantize -q quantlevel /repo`. For instance, if you have downloaded the latest Mistral 7B model, then clone it to your machine. Then change into that directory and you can run:
|
||||
|
||||
```shell
|
||||
docker run --rm -v .:/repo ollama/quantize -q q4_0 /repo
|
||||
```
|
||||
|
||||
You can find the different quantization levels below under **Quantize the Model**.
|
||||
|
||||
This will output two files into the directory. First is a f16.bin file that is the model converted to GGUF. The second file is a q4_0.bin file which is the model quantized to a 4 bit quantization. You should rename it to something more descriptive.
|
||||
|
||||
You can find the repository for the Docker container here: [https://github.com/mxyng/quantize](https://github.com/mxyng/quantize)
|
||||
|
||||
For instance, if you wanted to convert the Mistral 7B model to a Q4 quantized model, then you could go through the following steps:
|
||||
|
||||
1. First verify the model will potentially work.
|
||||
2. Now clone Mistral 7B to your machine. You can find the command to run when you click the three vertical dots button on the model page, then click **Clone Repository**.
|
||||
1. For this repo, the command is:
|
||||
|
||||
```shell
|
||||
git lfs install
|
||||
git clone https://huggingface.co/mistralai/Mistral-7B-v0.1
|
||||
```
|
||||
|
||||
2. Navigate into the new directory and run `docker run --rm -v .:/repo ollama/quantize -q q4_0 /repo`
|
||||
3. Now you can create a modelfile using the q4_0.bin file that was created.
|
||||
|
||||
## Convert and Quantize Manually
|
||||
|
||||
### Clone llama.cpp to your machine
|
||||
|
||||
If we know the model has a chance of working, then we need to convert and quantize. This is a matter of running two separate scripts in the llama.cpp project.
|
||||
|
||||
1. Decide where you want the llama.cpp repository on your machine.
|
||||
2. Navigate to that location and then run:
|
||||
[`git clone https://github.com/ggerganov/llama.cpp.git`](https://github.com/ggerganov/llama.cpp.git)
|
||||
1. If you don't have git installed, download this zip file and unzip it to that location: https://github.com/ggerganov/llama.cpp/archive/refs/heads/master.zip
|
||||
3. Install the Python dependencies: `pip install torch transformers sentencepiece`
|
||||
4. Run 'make' to build the project and the quantize executable.
|
||||
|
||||
### Convert the model to GGUF
|
||||
|
||||
1. Decide on the right convert script to run. What was the model architecture you found in the first section.
|
||||
1. LlamaForCausalLM or MistralForCausalLM:
|
||||
run `python3 convert.py <modelfilename>`
|
||||
No need to specify fp16 or fp32.
|
||||
2. FalconForCausalLM or RWForCausalLM:
|
||||
run `python3 convert-falcon-hf-to-gguf.py <modelfilename> <fpsize>`
|
||||
fpsize depends on the weight size. 1 for fp16, 0 for fp32
|
||||
3. GPTNeoXForCausalLM:
|
||||
run `python3 convert-gptneox-hf-to-gguf.py <modelfilename> <fpsize>`
|
||||
fpsize depends on the weight size. 1 for fp16, 0 for fp32
|
||||
4. GPTBigCodeForCausalLM:
|
||||
run `python3 convert-starcoder-hf-to-gguf.py <modelfilename> <fpsize>`
|
||||
fpsize depends on the weight size. 1 for fp16, 0 for fp32
|
||||
|
||||
### Quantize the model
|
||||
|
||||
If the model converted successfully, there is a good chance it will also quantize successfully. Now you need to decide on the quantization to use. We will always try to create all the quantizations and upload them to the library. You should decide which level is more important to you and quantize accordingly.
|
||||
|
||||
The quantization options are as follows. Note that some architectures such as Falcon do not support K quants.
|
||||
|
||||
- Q4_0
|
||||
- Q4_1
|
||||
- Q5_0
|
||||
- Q5_1
|
||||
- Q2_K
|
||||
- Q3_K
|
||||
- Q3_K_S
|
||||
- Q3_K_M
|
||||
- Q3_K_L
|
||||
- Q4_K
|
||||
- Q4_K_S
|
||||
- Q4_K_M
|
||||
- Q5_K
|
||||
- Q5_K_S
|
||||
- Q5_K_M
|
||||
- Q6_K
|
||||
- Q8_0
|
||||
|
||||
Run the following command `quantize <converted model from above> <output file> <quantization type>`
|
||||
|
||||
## Now Create the Model
|
||||
|
||||
Now you can create the Ollama model. Refer to the [modelfile](./modelfile.md) doc for more information on doing that.
|
@@ -4,6 +4,5 @@ Here is a list of ways you can use Ollama with other tools to build interesting
|
||||
|
||||
- [Using LangChain with Ollama in JavaScript](./tutorials/langchainjs.md)
|
||||
- [Using LangChain with Ollama in Python](./tutorials/langchainpy.md)
|
||||
- [Running Ollama on NVIDIA Jetson Devices](./tutorials/nvidia-jetson.md)
|
||||
|
||||
Also be sure to check out the [examples](../examples) directory for more ways to use Ollama.
|
||||
Also be sure to check out the [examples](../examples) directory for more ways to use Ollama.
|
@@ -23,17 +23,13 @@ const answer = await ollama.call(`why is the sky blue?`);
|
||||
console.log(answer);
|
||||
```
|
||||
|
||||
That will get us the same thing as if we ran `ollama run llama2 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's install **Cheerio** and build that part of the app.
|
||||
|
||||
```bash
|
||||
npm install cheerio
|
||||
```
|
||||
That will get us the same thing as if we ran `ollama run llama2 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's build that part of the app.
|
||||
|
||||
```javascript
|
||||
import { CheerioWebBaseLoader } from "langchain/document_loaders/web/cheerio";
|
||||
|
||||
const loader = new CheerioWebBaseLoader("https://en.wikipedia.org/wiki/2023_Hawaii_wildfires");
|
||||
const data = await loader.load();
|
||||
const data = loader.load();
|
||||
```
|
||||
|
||||
That will load the document. Although this page is smaller than the Odyssey, it is certainly bigger than the context size for most LLMs. So we are going to need to split into smaller pieces, and then select just the pieces relevant to our question. This is a great use for a vector datastore. In this example, we will use the **MemoryVectorStore** that is part of **LangChain**. But there is one more thing we need to get the content into the datastore. We have to run an embeddings process that converts the tokens in the text into a series of vectors. And for that, we are going to use **Tensorflow**. There is a lot of stuff going on in this one. First, install the **Tensorflow** components that we need.
|
||||
|
@@ -1,38 +0,0 @@
|
||||
# Running Ollama on NVIDIA Jetson Devices
|
||||
|
||||
With some minor configuration, Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/). The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack).
|
||||
|
||||
NVIDIA Jetson devices are Linux-based embedded AI computers that are purpose-built for AI applications.
|
||||
|
||||
Jetsons have an integrated GPU that is wired directly to the memory controller of the machine. For this reason, the `nvidia-smi` command is unrecognized, and Ollama proceeds to operate in "CPU only"
|
||||
mode. This can be verified by using a monitoring tool like jtop.
|
||||
|
||||
In order to address this, we simply pass the path to the Jetson's pre-installed CUDA libraries into `ollama serve` (while in a tmux session). We then hardcode the num_gpu parameters into a cloned
|
||||
version of our target model.
|
||||
|
||||
Prerequisites:
|
||||
|
||||
- curl
|
||||
- tmux
|
||||
|
||||
Here are the steps:
|
||||
|
||||
- Install Ollama via standard Linux command (ignore the 404 error): `curl https://ollama.ai/install.sh | sh`
|
||||
- Stop the Ollama service: `sudo systemctl stop ollama`
|
||||
- Start Ollama serve in a tmux session called ollama_jetson and reference the CUDA libraries path: `tmux has-session -t ollama_jetson 2>/dev/null || tmux new-session -d -s ollama_jetson
|
||||
'LD_LIBRARY_PATH=/usr/local/cuda/lib64 ollama serve'`
|
||||
- Pull the model you want to use (e.g. mistral): `ollama pull mistral`
|
||||
- Create a new Modelfile specifically for enabling GPU support on the Jetson: `touch ModelfileMistralJetson`
|
||||
- In the ModelfileMistralJetson file, specify the FROM model and the num_gpu PARAMETER as shown below:
|
||||
|
||||
```
|
||||
FROM mistral
|
||||
PARAMETER num_gpu 999
|
||||
```
|
||||
|
||||
- Create a new model from your Modelfile: `ollama create mistral-jetson -f ./ModelfileMistralJetson`
|
||||
- Run the new model: `ollama run mistral-jetson`
|
||||
|
||||
If you run a monitoring tool like jtop you should now see that Ollama is using the Jetson's integrated GPU.
|
||||
|
||||
And that's it!
|
@@ -1,10 +0,0 @@
|
||||
# Bash Shell examples
|
||||
|
||||
When calling `ollama`, you can pass it a file to run all the prompts in the file, one after the other:
|
||||
|
||||
`ollama run llama2 < sourcequestions.txt`
|
||||
|
||||
This concept is used in the following example.
|
||||
|
||||
## Compare Models
|
||||
`comparemodels.sh` is a script that runs all the questions in `sourcequestions.txt` using any 4 models you choose that you have already pulled from the Ollama library or have created locally.
|
@@ -1,64 +0,0 @@
|
||||
#! /usr/bin/env bash
|
||||
# Compare multiple models by running them with the same questions
|
||||
|
||||
NUMBEROFCHOICES=4
|
||||
SELECTIONS=()
|
||||
declare -a SUMS=()
|
||||
|
||||
# Get the list of models
|
||||
CHOICES=$(ollama list | awk '{print $1}')
|
||||
|
||||
# Select which models to run as a comparison
|
||||
echo "Select $NUMBEROFCHOICES models to compare:"
|
||||
select ITEM in $CHOICES; do
|
||||
if [[ -n $ITEM ]]; then
|
||||
echo "You have selected $ITEM"
|
||||
SELECTIONS+=("$ITEM")
|
||||
((COUNT++))
|
||||
if [[ $COUNT -eq $NUMBEROFCHOICES ]]; then
|
||||
break
|
||||
fi
|
||||
else
|
||||
echo "Invalid selection"
|
||||
fi
|
||||
done
|
||||
|
||||
# Loop through each of the selected models
|
||||
for ITEM in "${SELECTIONS[@]}"; do
|
||||
echo "--------------------------------------------------------------"
|
||||
echo "Loading the model $ITEM into memory"
|
||||
ollama run "$ITEM" ""
|
||||
echo "--------------------------------------------------------------"
|
||||
echo "Running the questions through the model $ITEM"
|
||||
COMMAND_OUTPUT=$(ollama run "$ITEM" --verbose < sourcequestions.txt 2>&1| tee /dev/stderr)
|
||||
|
||||
# eval duration is sometimes listed in seconds and sometimes in milliseconds.
|
||||
# Add up the values for each model
|
||||
SUM=$(echo "$COMMAND_OUTPUT" | awk '
|
||||
/eval duration:/ {
|
||||
value = $3
|
||||
if (index(value, "ms") > 0) {
|
||||
gsub("ms", "", value)
|
||||
value /= 1000
|
||||
} else {
|
||||
gsub("s", "", value)
|
||||
}
|
||||
sum += value
|
||||
}
|
||||
END { print sum }')
|
||||
|
||||
|
||||
SUMS+=("All questions for $ITEM completed in $SUM seconds")
|
||||
done
|
||||
|
||||
echo ""
|
||||
echo "--------------------------------------------------------------"
|
||||
echo -e "Sums of eval durations for each run:"
|
||||
for val in "${SUMS[@]}"; do
|
||||
echo "$val"
|
||||
done
|
||||
|
||||
echo "--------------------------------------------------------------"
|
||||
echo "Comparison complete. Now you can decide"
|
||||
echo "which model is best."
|
||||
echo "--------------------------------------------------------------"
|
@@ -1,7 +0,0 @@
|
||||
Why is the sky blue
|
||||
What is a black hole
|
||||
Explain the big bang theory like I am 5?
|
||||
What is the quickest way to win a game of Monopoly with 3 others?
|
||||
Why does a vacuum bottle keep my coffee hot and my milkshake cold?
|
||||
What is the difference between a meteor, a meteorite, and a meteoroid?
|
||||
Create an array with 5 items and print to the console. Do this in Python, C#, Typescript, and Rust.
|
@@ -3,10 +3,10 @@ package main
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"io"
|
||||
"log"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -16,7 +16,7 @@ func main() {
|
||||
if err != nil {
|
||||
fmt.Print(err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
responseData, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
|
@@ -1,36 +0,0 @@
|
||||
# Deploy Ollama to Kubernetes
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Ollama: https://ollama.ai/download
|
||||
- Kubernetes cluster. This example will use Google Kubernetes Engine.
|
||||
|
||||
## Steps
|
||||
|
||||
1. Create the Ollama namespace, daemon set, and service
|
||||
|
||||
```bash
|
||||
kubectl apply -f cpu.yaml
|
||||
```
|
||||
|
||||
1. Port forward the Ollama service to connect and use it locally
|
||||
|
||||
```bash
|
||||
kubectl -n ollama port-forward service/ollama 11434:80
|
||||
```
|
||||
|
||||
1. Pull and run a model, for example `orca-mini:3b`
|
||||
|
||||
```bash
|
||||
ollama run orca-mini:3b
|
||||
```
|
||||
|
||||
## (Optional) Hardware Acceleration
|
||||
|
||||
Hardware acceleration in Kubernetes requires NVIDIA's [`k8s-device-plugin`](https://github.com/NVIDIA/k8s-device-plugin). Follow the link for more details.
|
||||
|
||||
Once configured, create a GPU enabled Ollama deployment.
|
||||
|
||||
```bash
|
||||
kubectl apply -f gpu.yaml
|
||||
```
|
@@ -1,42 +0,0 @@
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Namespace
|
||||
metadata:
|
||||
name: ollama
|
||||
---
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: ollama
|
||||
namespace: ollama
|
||||
spec:
|
||||
selector:
|
||||
matchLabels:
|
||||
name: ollama
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
name: ollama
|
||||
spec:
|
||||
containers:
|
||||
- name: ollama
|
||||
image: ollama/ollama:latest
|
||||
ports:
|
||||
- name: http
|
||||
containerPort: 11434
|
||||
protocol: TCP
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: ollama
|
||||
namespace: ollama
|
||||
spec:
|
||||
type: ClusterIP
|
||||
selector:
|
||||
name: ollama
|
||||
ports:
|
||||
- port: 80
|
||||
name: http
|
||||
targetPort: http
|
||||
protocol: TCP
|
@@ -1,56 +0,0 @@
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Namespace
|
||||
metadata:
|
||||
name: ollama
|
||||
---
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: ollama
|
||||
namespace: ollama
|
||||
spec:
|
||||
strategy:
|
||||
type: Recreate
|
||||
selector:
|
||||
matchLabels:
|
||||
name: ollama
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
name: ollama
|
||||
spec:
|
||||
containers:
|
||||
- name: ollama
|
||||
image: ollama/ollama:latest
|
||||
env:
|
||||
- name: PATH
|
||||
value: /usr/local/nvidia/bin:/usr/local/nvidia/lib64:/usr/bin:/usr/sbin:/bin:/sbin
|
||||
- name: LD_LIBRARY_PATH
|
||||
value: /usr/local/nvidia/lib64
|
||||
ports:
|
||||
- name: http
|
||||
containerPort: 11434
|
||||
protocol: TCP
|
||||
resources:
|
||||
limits:
|
||||
nvidia.com/gpu: 1
|
||||
tolerations:
|
||||
- key: nvidia.com/gpu
|
||||
operator: Exists
|
||||
effect: NoSchedule
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: ollama
|
||||
namespace: ollama
|
||||
spec:
|
||||
type: ClusterIP
|
||||
selector:
|
||||
name: ollama
|
||||
ports:
|
||||
- port: 80
|
||||
name: http
|
||||
targetPort: http
|
||||
protocol: TCP
|
@@ -6,6 +6,7 @@ PERSIST_DIRECTORY = os.environ.get('PERSIST_DIRECTORY', 'db')
|
||||
|
||||
# Define the Chroma settings
|
||||
CHROMA_SETTINGS = Settings(
|
||||
chroma_db_impl='duckdb+parquet',
|
||||
persist_directory=PERSIST_DIRECTORY,
|
||||
anonymized_telemetry=False
|
||||
)
|
||||
|
@@ -150,7 +150,7 @@ def main():
|
||||
print("Creating new vectorstore")
|
||||
texts = process_documents()
|
||||
print(f"Creating embeddings. May take some minutes...")
|
||||
db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory)
|
||||
db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory, client_settings=CHROMA_SETTINGS)
|
||||
db.persist()
|
||||
db = None
|
||||
|
||||
|
@@ -4,7 +4,6 @@ from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
from langchain.vectorstores import Chroma
|
||||
from langchain.llms import Ollama
|
||||
import chromadb
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
@@ -23,9 +22,7 @@ def main():
|
||||
# Parse the command line arguments
|
||||
args = parse_arguments()
|
||||
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
|
||||
|
||||
db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
|
||||
|
||||
db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
|
||||
retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
|
||||
# activate/deactivate the streaming StdOut callback for LLMs
|
||||
callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
|
||||
|
File diff suppressed because it is too large
Load Diff
@@ -3,8 +3,10 @@
|
||||
|
||||
FROM orca
|
||||
TEMPLATE """
|
||||
{{- if .First }}
|
||||
### System:
|
||||
{{ .System }}
|
||||
{{- end }}
|
||||
### User:
|
||||
I hate it when my phone dies
|
||||
### Response:
|
||||
|
@@ -3,8 +3,10 @@
|
||||
This is a simple sentiments analyzer using the Orca model. When you pull Orca from the registry, it has a Template already defined that looks like this:
|
||||
|
||||
```Modelfile
|
||||
{{- if .First }}
|
||||
### System:
|
||||
{{ .System }}
|
||||
{{- end }}
|
||||
|
||||
### User:
|
||||
{{ .Prompt }}
|
||||
|
@@ -1,22 +0,0 @@
|
||||
# News Summarizer
|
||||
|
||||
This example goes through a series of steps:
|
||||
|
||||
1. You choose a topic area (e.g., "news", "NVidia", "music", etc.).
|
||||
2. Gets the most recent articles on that topic from various sources.
|
||||
3. Uses Ollama to summarize each article.
|
||||
4. Creates chunks of sentences from each article.
|
||||
5. Uses Sentence Transformers to generate embeddings for each of those chunks.
|
||||
6. You enter a question regarding the summaries shown.
|
||||
7. Uses Sentence Transformers to generate an embedding for that question.
|
||||
8. Uses the embedded question to find the most similar chunks.
|
||||
9. Feeds all that to Ollama to generate a good answer to your question based on these news articles.
|
||||
|
||||
This example lets you pick from a few different topic areas, then summarize the most recent x articles for that topic. It then creates chunks of sentences from each article and then generates embeddings for each of those chunks.
|
||||
|
||||
You can run the example like this:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
python summ.py
|
||||
```
|
@@ -1,9 +0,0 @@
|
||||
beautifulsoup4==4.12.2
|
||||
feedparser==6.0.10
|
||||
mattsollamatools==0.0.8
|
||||
newspaper3k==0.2.8
|
||||
nltk==3.8.1
|
||||
numpy==1.24.3
|
||||
Requests==2.31.0
|
||||
scikit_learn==1.3.0
|
||||
sentence_transformers==2.2.2
|
@@ -1,86 +0,0 @@
|
||||
import curses
|
||||
import json
|
||||
from utils import get_url_for_topic, topic_urls, menu, getUrls, get_summary, getArticleText, knn_search
|
||||
import requests
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from mattsollamatools import chunker
|
||||
|
||||
if __name__ == "__main__":
|
||||
chosen_topic = curses.wrapper(menu)
|
||||
print("Here is your news summary:\n")
|
||||
urls = getUrls(chosen_topic, n=5)
|
||||
model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||
allEmbeddings = []
|
||||
|
||||
for url in urls:
|
||||
article={}
|
||||
article['embeddings'] = []
|
||||
article['url'] = url
|
||||
text = getArticleText(url)
|
||||
summary = get_summary(text)
|
||||
chunks = chunker(text) # Use the chunk_text function from web_utils
|
||||
embeddings = model.encode(chunks)
|
||||
for (chunk, embedding) in zip(chunks, embeddings):
|
||||
item = {}
|
||||
item['source'] = chunk
|
||||
item['embedding'] = embedding.tolist() # Convert NumPy array to list
|
||||
item['sourcelength'] = len(chunk)
|
||||
article['embeddings'].append(item)
|
||||
|
||||
allEmbeddings.append(article)
|
||||
|
||||
print(f"{summary}\n")
|
||||
|
||||
|
||||
while True:
|
||||
context = []
|
||||
# Input a question from the user
|
||||
question = input("Enter your question about the news, or type quit: ")
|
||||
|
||||
if question.lower() == 'quit':
|
||||
break
|
||||
|
||||
# Embed the user's question
|
||||
question_embedding = model.encode([question])
|
||||
|
||||
# Perform KNN search to find the best matches (indices and source text)
|
||||
best_matches = knn_search(question_embedding, allEmbeddings, k=10)
|
||||
|
||||
|
||||
sourcetext=""
|
||||
for i, (index, source_text) in enumerate(best_matches, start=1):
|
||||
sourcetext += f"{i}. Index: {index}, Source Text: {source_text}"
|
||||
|
||||
systemPrompt = f"Only use the following information to answer the question. Do not use anything else: {sourcetext}"
|
||||
|
||||
url = "http://localhost:11434/api/generate"
|
||||
|
||||
payload = {
|
||||
"model": "mistral-openorca",
|
||||
"prompt": question,
|
||||
"system": systemPrompt,
|
||||
"stream": False,
|
||||
"context": context
|
||||
}
|
||||
|
||||
# Convert the payload to a JSON string
|
||||
payload_json = json.dumps(payload)
|
||||
|
||||
# Set the headers to specify JSON content
|
||||
headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# Send the POST request
|
||||
response = requests.post(url, data=payload_json, headers=headers)
|
||||
|
||||
# Check the response
|
||||
if response.status_code == 200:
|
||||
output = json.loads(response.text)
|
||||
context = output['context']
|
||||
print(output['response']+ "\n")
|
||||
|
||||
|
||||
else:
|
||||
print(f"Request failed with status code {response.status_code}")
|
||||
|
@@ -1,108 +0,0 @@
|
||||
import curses
|
||||
import feedparser
|
||||
import requests
|
||||
import unicodedata
|
||||
import json
|
||||
from newspaper import Article
|
||||
from bs4 import BeautifulSoup
|
||||
from nltk.tokenize import sent_tokenize, word_tokenize
|
||||
import numpy as np
|
||||
from sklearn.neighbors import NearestNeighbors
|
||||
from mattsollamatools import chunker
|
||||
|
||||
# Create a dictionary to store topics and their URLs
|
||||
topic_urls = {
|
||||
"Mac": "https://9to5mac.com/guides/mac/feed",
|
||||
"News": "http://www.npr.org/rss/rss.php?id=1001",
|
||||
"Nvidia": "https://nvidianews.nvidia.com/releases.xml",
|
||||
"Raspberry Pi": "https://www.raspberrypi.com/news/feed/",
|
||||
"Music": "https://www.billboard.com/c/music/music-news/feed/"
|
||||
}
|
||||
|
||||
# Use curses to create a menu of topics
|
||||
def menu(stdscr):
|
||||
chosen_topic = get_url_for_topic(stdscr)
|
||||
url = topic_urls[chosen_topic] if chosen_topic in topic_urls else "Topic not found"
|
||||
|
||||
stdscr.addstr(len(topic_urls) + 3, 0, f"Selected URL for {chosen_topic}: {url}")
|
||||
stdscr.refresh()
|
||||
|
||||
return chosen_topic
|
||||
|
||||
# You have chosen a topic. Now return the url for that topic
|
||||
def get_url_for_topic(stdscr):
|
||||
curses.curs_set(0) # Hide the cursor
|
||||
stdscr.clear()
|
||||
|
||||
stdscr.addstr(0, 0, "Choose a topic using the arrow keys (Press Enter to select):")
|
||||
|
||||
# Create a list of topics
|
||||
topics = list(topic_urls.keys())
|
||||
current_topic = 0
|
||||
|
||||
while True:
|
||||
for i, topic in enumerate(topics):
|
||||
if i == current_topic:
|
||||
stdscr.addstr(i + 2, 2, f"> {topic}")
|
||||
else:
|
||||
stdscr.addstr(i + 2, 2, f" {topic}")
|
||||
|
||||
stdscr.refresh()
|
||||
|
||||
key = stdscr.getch()
|
||||
|
||||
if key == curses.KEY_DOWN and current_topic < len(topics) - 1:
|
||||
current_topic += 1
|
||||
elif key == curses.KEY_UP and current_topic > 0:
|
||||
current_topic -= 1
|
||||
elif key == 10: # Enter key
|
||||
return topic_urls[topics[current_topic]]
|
||||
|
||||
# Get the last N URLs from an RSS feed
|
||||
def getUrls(feed_url, n=20):
|
||||
feed = feedparser.parse(feed_url)
|
||||
entries = feed.entries[-n:]
|
||||
urls = [entry.link for entry in entries]
|
||||
return urls
|
||||
|
||||
# Often there are a bunch of ads and menus on pages for a news article. This uses newspaper3k to get just the text of just the article.
|
||||
def getArticleText(url):
|
||||
article = Article(url)
|
||||
article.download()
|
||||
article.parse()
|
||||
return article.text
|
||||
|
||||
def get_summary(text):
|
||||
systemPrompt = "Write a concise summary of the text, return your responses with 5 lines that cover the key points of the text given."
|
||||
prompt = text
|
||||
|
||||
url = "http://localhost:11434/api/generate"
|
||||
|
||||
payload = {
|
||||
"model": "mistral-openorca",
|
||||
"prompt": prompt,
|
||||
"system": systemPrompt,
|
||||
"stream": False
|
||||
}
|
||||
payload_json = json.dumps(payload)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
response = requests.post(url, data=payload_json, headers=headers)
|
||||
|
||||
return json.loads(response.text)["response"]
|
||||
|
||||
# Perform K-nearest neighbors (KNN) search
|
||||
def knn_search(question_embedding, embeddings, k=5):
|
||||
X = np.array([item['embedding'] for article in embeddings for item in article['embeddings']])
|
||||
source_texts = [item['source'] for article in embeddings for item in article['embeddings']]
|
||||
|
||||
# Fit a KNN model on the embeddings
|
||||
knn = NearestNeighbors(n_neighbors=k, metric='cosine')
|
||||
knn.fit(X)
|
||||
|
||||
# Find the indices and distances of the k-nearest neighbors
|
||||
distances, indices = knn.kneighbors(question_embedding, n_neighbors=k)
|
||||
|
||||
# Get the indices and source texts of the best matches
|
||||
best_matches = [(indices[0][i], source_texts[indices[0][i]]) for i in range(k)]
|
||||
|
||||
return best_matches
|
@@ -17,7 +17,7 @@ def generate(prompt, context):
|
||||
for line in r.iter_lines():
|
||||
body = json.loads(line)
|
||||
response_part = body.get('response', '')
|
||||
# the response streams one token at a time, print that as we receive it
|
||||
# the response streams one token at a time, print that as we recieve it
|
||||
print(response_part, end='', flush=True)
|
||||
|
||||
if 'error' in body:
|
||||
@@ -35,4 +35,4 @@ def main():
|
||||
print()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
@@ -2,21 +2,14 @@ package format
|
||||
|
||||
import "fmt"
|
||||
|
||||
const (
|
||||
Byte = 1
|
||||
KiloByte = Byte * 1000
|
||||
MegaByte = KiloByte * 1000
|
||||
GigaByte = MegaByte * 1000
|
||||
)
|
||||
|
||||
func HumanBytes(b int64) string {
|
||||
switch {
|
||||
case b > GigaByte:
|
||||
return fmt.Sprintf("%.1f GB", float64(b)/GigaByte)
|
||||
case b > MegaByte:
|
||||
return fmt.Sprintf("%.1f MB", float64(b)/MegaByte)
|
||||
case b > KiloByte:
|
||||
return fmt.Sprintf("%.1f KB", float64(b)/KiloByte)
|
||||
case b > 1000*1000*1000:
|
||||
return fmt.Sprintf("%d GB", b/1000/1000/1000)
|
||||
case b > 1000*1000:
|
||||
return fmt.Sprintf("%d MB", b/1000/1000)
|
||||
case b > 1000:
|
||||
return fmt.Sprintf("%d KB", b/1000)
|
||||
default:
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
|
@@ -1,25 +0,0 @@
|
||||
package format
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
)
|
||||
|
||||
const (
|
||||
Thousand = 1000
|
||||
Million = Thousand * 1000
|
||||
Billion = Million * 1000
|
||||
)
|
||||
|
||||
func HumanNumber(b uint64) string {
|
||||
switch {
|
||||
case b > Billion:
|
||||
return fmt.Sprintf("%.0fB", math.Round(float64(b)/Billion))
|
||||
case b > Million:
|
||||
return fmt.Sprintf("%.0fM", math.Round(float64(b)/Million))
|
||||
case b > Thousand:
|
||||
return fmt.Sprintf("%.0fK", math.Round(float64(b)/Thousand))
|
||||
default:
|
||||
return fmt.Sprintf("%d", b)
|
||||
}
|
||||
}
|
@@ -29,7 +29,7 @@ func TestHumanTime(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("soon", func(t *testing.T) {
|
||||
v := now.Add(800 * time.Millisecond)
|
||||
v := now.Add(800*time.Millisecond)
|
||||
assertEqual(t, HumanTime(v, ""), "Less than a second from now")
|
||||
})
|
||||
}
|
||||
|
14
go.mod
14
go.mod
@@ -3,11 +3,12 @@ module github.com/jmorganca/ollama
|
||||
go 1.20
|
||||
|
||||
require (
|
||||
github.com/emirpasic/gods v1.18.1
|
||||
github.com/dustin/go-humanize v1.0.1
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
github.com/pdevine/readline v1.5.2
|
||||
github.com/spf13/cobra v1.7.0
|
||||
golang.org/x/sync v0.3.0
|
||||
)
|
||||
@@ -38,12 +39,13 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/crypto v0.14.0
|
||||
golang.org/x/crypto v0.10.0
|
||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
|
||||
golang.org/x/net v0.17.0 // indirect
|
||||
golang.org/x/sys v0.13.0 // indirect
|
||||
golang.org/x/term v0.13.0
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
golang.org/x/net v0.10.0 // indirect
|
||||
golang.org/x/sys v0.11.0 // indirect
|
||||
golang.org/x/term v0.10.0
|
||||
golang.org/x/text v0.10.0 // indirect
|
||||
gonum.org/v1/gonum v0.13.0
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
33
go.sum
33
go.sum
@@ -4,13 +4,17 @@ github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZX
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||
github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM=
|
||||
github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ=
|
||||
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
|
||||
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
||||
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
|
||||
github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
|
||||
@@ -74,6 +78,8 @@ github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0=
|
||||
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhMYhSNPKjeNKa5WY9YCIEBRbNzFFPJbWO6Y=
|
||||
github.com/pdevine/readline v1.5.2 h1:oz6Y5GdTmhPG+08hhxcAvtHitSANWuA2100Sppb38xI=
|
||||
github.com/pdevine/readline v1.5.2/go.mod h1:na/LbuE5PYwxI7GyopWdIs3U8HVe89lYlNTFTXH3wOw=
|
||||
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
||||
@@ -112,32 +118,35 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
|
||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM=
|
||||
golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I=
|
||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ=
|
||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
|
||||
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek=
|
||||
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
|
||||
golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c=
|
||||
golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
|
||||
golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gonum.org/v1/gonum v0.13.0 h1:a0T3bh+7fhRyqeNbiC3qVHYmkiQgit3wnNan/2c0HMM=
|
||||
gonum.org/v1/gonum v0.13.0/go.mod h1:/WPYRckkfWrhWefxyYTfrTtQR0KH4iyHNuzxqXAKyAU=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
||||
|
@@ -175,8 +175,7 @@ const (
|
||||
// Magic constant for `ggla` files (LoRA adapter).
|
||||
FILE_MAGIC_GGLA = 0x67676C61
|
||||
// Magic constant for `gguf` files (versioned, gguf)
|
||||
FILE_MAGIC_GGUF_LE = 0x46554747
|
||||
FILE_MAGIC_GGUF_BE = 0x47475546
|
||||
FILE_MAGIC_GGUF = 0x46554747
|
||||
)
|
||||
|
||||
func DecodeGGML(r io.ReadSeeker) (*GGML, error) {
|
||||
@@ -192,10 +191,8 @@ func DecodeGGML(r io.ReadSeeker) (*GGML, error) {
|
||||
ggml.container = &containerGGJT{}
|
||||
case FILE_MAGIC_GGLA:
|
||||
ggml.container = &containerLORA{}
|
||||
case FILE_MAGIC_GGUF_LE:
|
||||
ggml.container = &containerGGUF{bo: binary.LittleEndian}
|
||||
case FILE_MAGIC_GGUF_BE:
|
||||
ggml.container = &containerGGUF{bo: binary.BigEndian}
|
||||
case FILE_MAGIC_GGUF:
|
||||
ggml.container = &containerGGUF{}
|
||||
default:
|
||||
return nil, errors.New("invalid file magic")
|
||||
}
|
||||
|
126
llm/gguf.go
126
llm/gguf.go
@@ -3,15 +3,12 @@ package llm
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/jmorganca/ollama/format"
|
||||
)
|
||||
|
||||
type containerGGUF struct {
|
||||
bo binary.ByteOrder
|
||||
|
||||
Version uint32
|
||||
|
||||
V1 struct {
|
||||
@@ -23,8 +20,6 @@ type containerGGUF struct {
|
||||
NumTensor uint64
|
||||
NumKV uint64
|
||||
}
|
||||
|
||||
parameters uint64
|
||||
}
|
||||
|
||||
func (c *containerGGUF) Name() string {
|
||||
@@ -32,13 +27,15 @@ func (c *containerGGUF) Name() string {
|
||||
}
|
||||
|
||||
func (c *containerGGUF) Decode(r io.Reader) (model, error) {
|
||||
binary.Read(r, c.bo, &c.Version)
|
||||
binary.Read(r, binary.LittleEndian, &c.Version)
|
||||
|
||||
switch c.Version {
|
||||
case 1:
|
||||
binary.Read(r, c.bo, &c.V1)
|
||||
binary.Read(r, binary.LittleEndian, &c.V1)
|
||||
case 2:
|
||||
binary.Read(r, binary.LittleEndian, &c.V2)
|
||||
default:
|
||||
binary.Read(r, c.bo, &c.V2)
|
||||
return nil, errors.New("invalid version")
|
||||
}
|
||||
|
||||
model := newGGUFModel(c)
|
||||
@@ -79,14 +76,6 @@ func newGGUFModel(container *containerGGUF) *ggufModel {
|
||||
}
|
||||
}
|
||||
|
||||
func (llm *ggufModel) NumTensor() uint64 {
|
||||
if llm.Version == 1 {
|
||||
return uint64(llm.V1.NumTensor)
|
||||
}
|
||||
|
||||
return llm.V2.NumTensor
|
||||
}
|
||||
|
||||
func (llm *ggufModel) NumKV() uint64 {
|
||||
if llm.Version == 1 {
|
||||
return uint64(llm.V1.NumKV)
|
||||
@@ -105,10 +94,6 @@ func (llm *ggufModel) ModelFamily() string {
|
||||
}
|
||||
|
||||
func (llm *ggufModel) ModelType() string {
|
||||
if llm.parameters > 0 {
|
||||
return format.HumanNumber(llm.parameters)
|
||||
}
|
||||
|
||||
switch llm.ModelFamily() {
|
||||
case "llama":
|
||||
if blocks, ok := llm.kv["llama.block_count"].(uint32); ok {
|
||||
@@ -143,9 +128,13 @@ func (llm *ggufModel) FileType() string {
|
||||
}
|
||||
|
||||
func (llm *ggufModel) Decode(r io.Reader) error {
|
||||
// decode key-values
|
||||
read := llm.readString
|
||||
if llm.Version == 1 {
|
||||
read = llm.readStringV1
|
||||
}
|
||||
|
||||
for i := 0; uint64(i) < llm.NumKV(); i++ {
|
||||
k, err := llm.readString(r)
|
||||
k, err := read(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -177,14 +166,24 @@ func (llm *ggufModel) Decode(r io.Reader) error {
|
||||
case ggufTypeBool:
|
||||
v = llm.readBool(r)
|
||||
case ggufTypeString:
|
||||
s, err := llm.readString(r)
|
||||
fn := llm.readString
|
||||
if llm.Version == 1 {
|
||||
fn = llm.readStringV1
|
||||
}
|
||||
|
||||
s, err := fn(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v = s
|
||||
case ggufTypeArray:
|
||||
a, err := llm.readArray(r)
|
||||
fn := llm.readArray
|
||||
if llm.Version == 1 {
|
||||
fn = llm.readArrayV1
|
||||
}
|
||||
|
||||
a, err := fn(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -197,25 +196,6 @@ func (llm *ggufModel) Decode(r io.Reader) error {
|
||||
llm.kv[k] = v
|
||||
}
|
||||
|
||||
// decode tensors
|
||||
for i := 0; uint64(i) < llm.NumTensor(); i++ {
|
||||
if _, err := llm.readString(r); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dimensions := llm.readU32(r)
|
||||
|
||||
var elements uint64 = 1
|
||||
for i := 0; uint32(i) < dimensions; i++ {
|
||||
elements *= llm.readU64(r)
|
||||
}
|
||||
|
||||
llm.readU32(r) // type
|
||||
llm.readU64(r) // offset
|
||||
|
||||
llm.parameters += elements
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -229,75 +209,75 @@ func (llm *ggufModel) NumLayers() int64 {
|
||||
return int64(v)
|
||||
}
|
||||
|
||||
func (llm ggufModel) readU8(r io.Reader) uint8 {
|
||||
func (ggufModel) readU8(r io.Reader) uint8 {
|
||||
var u8 uint8
|
||||
binary.Read(r, llm.bo, &u8)
|
||||
binary.Read(r, binary.LittleEndian, &u8)
|
||||
return u8
|
||||
}
|
||||
|
||||
func (llm ggufModel) readI8(r io.Reader) int8 {
|
||||
func (ggufModel) readI8(r io.Reader) int8 {
|
||||
var i8 int8
|
||||
binary.Read(r, llm.bo, &i8)
|
||||
binary.Read(r, binary.LittleEndian, &i8)
|
||||
return i8
|
||||
}
|
||||
|
||||
func (llm ggufModel) readU16(r io.Reader) uint16 {
|
||||
func (ggufModel) readU16(r io.Reader) uint16 {
|
||||
var u16 uint16
|
||||
binary.Read(r, llm.bo, &u16)
|
||||
binary.Read(r, binary.LittleEndian, &u16)
|
||||
return u16
|
||||
}
|
||||
|
||||
func (llm ggufModel) readI16(r io.Reader) int16 {
|
||||
func (ggufModel) readI16(r io.Reader) int16 {
|
||||
var i16 int16
|
||||
binary.Read(r, llm.bo, &i16)
|
||||
binary.Read(r, binary.LittleEndian, &i16)
|
||||
return i16
|
||||
}
|
||||
|
||||
func (llm ggufModel) readU32(r io.Reader) uint32 {
|
||||
func (ggufModel) readU32(r io.Reader) uint32 {
|
||||
var u32 uint32
|
||||
binary.Read(r, llm.bo, &u32)
|
||||
binary.Read(r, binary.LittleEndian, &u32)
|
||||
return u32
|
||||
}
|
||||
|
||||
func (llm ggufModel) readI32(r io.Reader) int32 {
|
||||
func (ggufModel) readI32(r io.Reader) int32 {
|
||||
var i32 int32
|
||||
binary.Read(r, llm.bo, &i32)
|
||||
binary.Read(r, binary.LittleEndian, &i32)
|
||||
return i32
|
||||
}
|
||||
|
||||
func (llm ggufModel) readU64(r io.Reader) uint64 {
|
||||
func (ggufModel) readU64(r io.Reader) uint64 {
|
||||
var u64 uint64
|
||||
binary.Read(r, llm.bo, &u64)
|
||||
binary.Read(r, binary.LittleEndian, &u64)
|
||||
return u64
|
||||
}
|
||||
|
||||
func (llm ggufModel) readI64(r io.Reader) int64 {
|
||||
func (ggufModel) readI64(r io.Reader) int64 {
|
||||
var i64 int64
|
||||
binary.Read(r, llm.bo, &i64)
|
||||
binary.Read(r, binary.LittleEndian, &i64)
|
||||
return i64
|
||||
}
|
||||
|
||||
func (llm ggufModel) readF32(r io.Reader) float32 {
|
||||
func (ggufModel) readF32(r io.Reader) float32 {
|
||||
var f32 float32
|
||||
binary.Read(r, llm.bo, &f32)
|
||||
binary.Read(r, binary.LittleEndian, &f32)
|
||||
return f32
|
||||
}
|
||||
|
||||
func (llm ggufModel) readF64(r io.Reader) float64 {
|
||||
func (ggufModel) readF64(r io.Reader) float64 {
|
||||
var f64 float64
|
||||
binary.Read(r, llm.bo, &f64)
|
||||
binary.Read(r, binary.LittleEndian, &f64)
|
||||
return f64
|
||||
}
|
||||
|
||||
func (llm ggufModel) readBool(r io.Reader) bool {
|
||||
func (ggufModel) readBool(r io.Reader) bool {
|
||||
var b bool
|
||||
binary.Read(r, llm.bo, &b)
|
||||
binary.Read(r, binary.LittleEndian, &b)
|
||||
return b
|
||||
}
|
||||
|
||||
func (llm ggufModel) readStringV1(r io.Reader) (string, error) {
|
||||
func (ggufModel) readStringV1(r io.Reader) (string, error) {
|
||||
var nameLength uint32
|
||||
binary.Read(r, llm.bo, &nameLength)
|
||||
binary.Read(r, binary.LittleEndian, &nameLength)
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := io.CopyN(&b, r, int64(nameLength)); err != nil {
|
||||
@@ -311,12 +291,8 @@ func (llm ggufModel) readStringV1(r io.Reader) (string, error) {
|
||||
}
|
||||
|
||||
func (llm ggufModel) readString(r io.Reader) (string, error) {
|
||||
if llm.Version == 1 {
|
||||
return llm.readStringV1(r)
|
||||
}
|
||||
|
||||
var nameLength uint64
|
||||
binary.Read(r, llm.bo, &nameLength)
|
||||
binary.Read(r, binary.LittleEndian, &nameLength)
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := io.CopyN(&b, r, int64(nameLength)); err != nil {
|
||||
@@ -364,10 +340,6 @@ func (llm *ggufModel) readArrayV1(r io.Reader) (arr []any, err error) {
|
||||
}
|
||||
|
||||
func (llm *ggufModel) readArray(r io.Reader) (arr []any, err error) {
|
||||
if llm.Version == 1 {
|
||||
return llm.readArrayV1(r)
|
||||
}
|
||||
|
||||
atype := llm.readU32(r)
|
||||
n := llm.readU64(r)
|
||||
|
||||
|
@@ -12,8 +12,7 @@ package llm
|
||||
//go:generate mv ggml/build/cpu/bin/server ggml/build/cpu/bin/ollama-runner
|
||||
|
||||
//go:generate git submodule update --force gguf
|
||||
//go:generate git -C gguf apply ../patches/0001-update-default-log-target.patch
|
||||
//go:generate git -C gguf apply ../patches/0001-metal-handle-ggml_scale-for-n-4-0-close-3754.patch
|
||||
//go:generate git -C gguf apply ../patches/0001-remove-warm-up-logging.patch
|
||||
//go:generate cmake -S gguf -B gguf/build/cpu -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DCMAKE_SYSTEM_PROCESSOR=x86_64 -DCMAKE_OSX_ARCHITECTURES=x86_64 -DCMAKE_OSX_DEPLOYMENT_TARGET=11.0
|
||||
//go:generate cmake --build gguf/build/cpu --target server --config Release
|
||||
//go:generate mv gguf/build/cpu/bin/server gguf/build/cpu/bin/ollama-runner
|
||||
|
@@ -12,8 +12,7 @@ package llm
|
||||
//go:generate mv ggml/build/metal/bin/server ggml/build/metal/bin/ollama-runner
|
||||
|
||||
//go:generate git submodule update --force gguf
|
||||
//go:generate git -C gguf apply ../patches/0001-update-default-log-target.patch
|
||||
//go:generate git -C gguf apply ../patches/0001-metal-handle-ggml_scale-for-n-4-0-close-3754.patch
|
||||
//go:generate git -C gguf apply ../patches/0001-remove-warm-up-logging.patch
|
||||
//go:generate cmake -S gguf -B gguf/build/metal -DLLAMA_METAL=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DCMAKE_SYSTEM_PROCESSOR=arm64 -DCMAKE_OSX_ARCHITECTURES=arm64 -DCMAKE_OSX_DEPLOYMENT_TARGET=11.0
|
||||
//go:generate cmake --build gguf/build/metal --target server --config Release
|
||||
//go:generate mv gguf/build/metal/bin/server gguf/build/metal/bin/ollama-runner
|
||||
|
@@ -13,14 +13,14 @@ package llm
|
||||
|
||||
//go:generate git submodule update --force gguf
|
||||
//go:generate git -C gguf apply ../patches/0001-copy-cuda-runtime-libraries.patch
|
||||
//go:generate git -C gguf apply ../patches/0001-update-default-log-target.patch
|
||||
//go:generate cmake -S gguf -B gguf/build/cpu -DLLAMA_K_QUANTS=on -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off
|
||||
//go:generate git -C gguf apply ../patches/0001-remove-warm-up-logging.patch
|
||||
//go:generate cmake -S gguf -B gguf/build/cpu -DLLAMA_K_QUANTS=on
|
||||
//go:generate cmake --build gguf/build/cpu --target server --config Release
|
||||
//go:generate mv gguf/build/cpu/bin/server gguf/build/cpu/bin/ollama-runner
|
||||
|
||||
//go:generate cmake -S ggml -B ggml/build/cuda -DLLAMA_CUBLAS=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on
|
||||
//go:generate cmake --build ggml/build/cuda --target server --config Release
|
||||
//go:generate mv ggml/build/cuda/bin/server ggml/build/cuda/bin/ollama-runner
|
||||
//go:generate cmake -S gguf -B gguf/build/cuda -DLLAMA_CUBLAS=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off
|
||||
//go:generate cmake -S gguf -B gguf/build/cuda -DLLAMA_CUBLAS=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on
|
||||
//go:generate cmake --build gguf/build/cuda --target server --config Release
|
||||
//go:generate mv gguf/build/cuda/bin/server gguf/build/cuda/bin/ollama-runner
|
||||
|
@@ -10,7 +10,7 @@ package llm
|
||||
//go:generate cmd /c move ggml\build\cpu\bin\Release\server.exe ggml\build\cpu\bin\Release\ollama-runner.exe
|
||||
|
||||
//go:generate git submodule update --force gguf
|
||||
//go:generate git -C gguf apply ../patches/0001-update-default-log-target.patch
|
||||
//go:generate cmake -S gguf -B gguf/build/cpu -DLLAMA_K_QUANTS=on -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off
|
||||
//go:generate git -C gguf apply ../patches/0001-remove-warm-up-logging.patch
|
||||
//go:generate cmake -S gguf -B gguf/build/cpu -DLLAMA_K_QUANTS=on
|
||||
//go:generate cmake --build gguf/build/cpu --target server --config Release
|
||||
//go:generate cmd /c move gguf\build\cpu\bin\Release\server.exe gguf\build\cpu\bin\Release\ollama-runner.exe
|
||||
|
Submodule llm/llama.cpp/gguf updated: 9e70cc0322...bc9d3e3971
@@ -1,91 +0,0 @@
|
||||
From 469c9addef75893e6be12edda852d12e840bf064 Mon Sep 17 00:00:00 2001
|
||||
From: Georgi Gerganov <ggerganov@gmail.com>
|
||||
Date: Tue, 24 Oct 2023 09:46:50 +0300
|
||||
Subject: [PATCH 1/2] metal : handle ggml_scale for n%4 != 0 (close #3754)
|
||||
|
||||
ggml-ci
|
||||
---
|
||||
ggml-metal.m | 18 +++++++++++++-----
|
||||
ggml-metal.metal | 10 +++++++++-
|
||||
2 files changed, 22 insertions(+), 6 deletions(-)
|
||||
|
||||
diff --git a/ggml-metal.m b/ggml-metal.m
|
||||
index c908106..c1901dc 100644
|
||||
--- a/ggml-metal.m
|
||||
+++ b/ggml-metal.m
|
||||
@@ -62,6 +62,7 @@
|
||||
GGML_METAL_DECL_KERNEL(mul);
|
||||
GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
|
||||
GGML_METAL_DECL_KERNEL(scale);
|
||||
+ GGML_METAL_DECL_KERNEL(scale_4);
|
||||
GGML_METAL_DECL_KERNEL(silu);
|
||||
GGML_METAL_DECL_KERNEL(relu);
|
||||
GGML_METAL_DECL_KERNEL(gelu);
|
||||
@@ -249,6 +250,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
|
||||
GGML_METAL_ADD_KERNEL(mul);
|
||||
GGML_METAL_ADD_KERNEL(mul_row);
|
||||
GGML_METAL_ADD_KERNEL(scale);
|
||||
+ GGML_METAL_ADD_KERNEL(scale_4);
|
||||
GGML_METAL_ADD_KERNEL(silu);
|
||||
GGML_METAL_ADD_KERNEL(relu);
|
||||
GGML_METAL_ADD_KERNEL(gelu);
|
||||
@@ -347,6 +349,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(mul);
|
||||
GGML_METAL_DEL_KERNEL(mul_row);
|
||||
GGML_METAL_DEL_KERNEL(scale);
|
||||
+ GGML_METAL_DEL_KERNEL(scale_4);
|
||||
GGML_METAL_DEL_KERNEL(silu);
|
||||
GGML_METAL_DEL_KERNEL(relu);
|
||||
GGML_METAL_DEL_KERNEL(gelu);
|
||||
@@ -923,15 +926,20 @@ void ggml_metal_graph_compute(
|
||||
|
||||
const float scale = *(const float *) src1->data;
|
||||
|
||||
- [encoder setComputePipelineState:ctx->pipeline_scale];
|
||||
+ int64_t n = ggml_nelements(dst);
|
||||
+
|
||||
+ if (n % 4 == 0) {
|
||||
+ n /= 4;
|
||||
+ [encoder setComputePipelineState:ctx->pipeline_scale_4];
|
||||
+ } else {
|
||||
+ [encoder setComputePipelineState:ctx->pipeline_scale];
|
||||
+ }
|
||||
+
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
||||
|
||||
- const int64_t n = ggml_nelements(dst);
|
||||
- GGML_ASSERT(n % 4 == 0);
|
||||
-
|
||||
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(gf->nodes[i])) {
|
||||
diff --git a/ggml-metal.metal b/ggml-metal.metal
|
||||
index 69fc713..f4b4605 100644
|
||||
--- a/ggml-metal.metal
|
||||
+++ b/ggml-metal.metal
|
||||
@@ -125,9 +125,17 @@ kernel void kernel_mul_row(
|
||||
}
|
||||
|
||||
kernel void kernel_scale(
|
||||
+ device const float * src0,
|
||||
+ device float * dst,
|
||||
+ constant float & scale,
|
||||
+ uint tpig[[thread_position_in_grid]]) {
|
||||
+ dst[tpig] = src0[tpig] * scale;
|
||||
+}
|
||||
+
|
||||
+kernel void kernel_scale_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
- constant float & scale,
|
||||
+ constant float & scale,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] * scale;
|
||||
}
|
||||
--
|
||||
2.39.3 (Apple Git-145)
|
||||
|
25
llm/llama.cpp/patches/0001-remove-warm-up-logging.patch
Normal file
25
llm/llama.cpp/patches/0001-remove-warm-up-logging.patch
Normal file
@@ -0,0 +1,25 @@
|
||||
From 07993bdc35345b67b27aa649a7c099ad42d80c4c Mon Sep 17 00:00:00 2001
|
||||
From: Michael Yang <mxyng@pm.me>
|
||||
Date: Thu, 21 Sep 2023 14:43:21 -0700
|
||||
Subject: [PATCH] remove warm up logging
|
||||
|
||||
---
|
||||
common/common.cpp | 2 --
|
||||
1 file changed, 2 deletions(-)
|
||||
|
||||
diff --git a/common/common.cpp b/common/common.cpp
|
||||
index 2597ba0..b56549b 100644
|
||||
--- a/common/common.cpp
|
||||
+++ b/common/common.cpp
|
||||
@@ -780,8 +780,6 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
|
||||
}
|
||||
|
||||
{
|
||||
- LOG("warming up the model with an empty run\n");
|
||||
-
|
||||
const std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
|
||||
llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads);
|
||||
llama_reset_timings(lctx);
|
||||
--
|
||||
2.42.0
|
||||
|
@@ -1,25 +0,0 @@
|
||||
From 6465fec6290f0a7f5d4d0fbe6bcf634e4810dde6 Mon Sep 17 00:00:00 2001
|
||||
From: Michael Yang <mxyng@pm.me>
|
||||
Date: Mon, 23 Oct 2023 10:39:34 -0700
|
||||
Subject: [PATCH] default log stderr
|
||||
|
||||
---
|
||||
common/log.h | 2 +-
|
||||
1 file changed, 1 insertion(+), 1 deletion(-)
|
||||
|
||||
diff --git a/common/log.h b/common/log.h
|
||||
index b8953fd..25522cd 100644
|
||||
--- a/common/log.h
|
||||
+++ b/common/log.h
|
||||
@@ -90,7 +90,7 @@
|
||||
// }
|
||||
//
|
||||
#ifndef LOG_TARGET
|
||||
- #define LOG_TARGET log_handler()
|
||||
+ #define LOG_TARGET nullptr
|
||||
#endif
|
||||
|
||||
#ifndef LOG_TEE_TARGET
|
||||
--
|
||||
2.42.0
|
||||
|
341
llm/llama.go
341
llm/llama.go
@@ -24,81 +24,51 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/format"
|
||||
)
|
||||
|
||||
const jsonGrammar = `
|
||||
root ::= object
|
||||
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||
|
||||
object ::=
|
||||
"{" ws (
|
||||
string ":" ws value
|
||||
("," ws string ":" ws value)*
|
||||
)? "}" ws
|
||||
|
||||
array ::=
|
||||
"[" ws (
|
||||
value
|
||||
("," ws value)*
|
||||
)? "]" ws
|
||||
|
||||
string ::=
|
||||
"\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||
)* "\"" ws
|
||||
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||
|
||||
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||
ws ::= ([ \t\n] ws)?
|
||||
`
|
||||
|
||||
//go:embed llama.cpp/*/build/*/bin/*
|
||||
var llamaCppEmbed embed.FS
|
||||
|
||||
type ModelRunner struct {
|
||||
Path string // path to the model runner executable
|
||||
Accelerated bool
|
||||
Path string // path to the model runner executable
|
||||
}
|
||||
|
||||
func chooseRunners(workDir, runnerType string) []ModelRunner {
|
||||
buildPath := path.Join("llama.cpp", runnerType, "build")
|
||||
var runners []ModelRunner
|
||||
var runners []string
|
||||
|
||||
// set the runners based on the OS
|
||||
// IMPORTANT: the order of the runners in the array is the priority order
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
runners = []ModelRunner{
|
||||
{Path: path.Join(buildPath, "metal", "bin", "ollama-runner")},
|
||||
{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
|
||||
runners = []string{
|
||||
path.Join(buildPath, "metal", "bin", "ollama-runner"),
|
||||
path.Join(buildPath, "cpu", "bin", "ollama-runner"),
|
||||
}
|
||||
case "linux":
|
||||
runners = []ModelRunner{
|
||||
{Path: path.Join(buildPath, "cuda", "bin", "ollama-runner"), Accelerated: true},
|
||||
{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
|
||||
runners = []string{
|
||||
path.Join(buildPath, "cuda", "bin", "ollama-runner"),
|
||||
path.Join(buildPath, "cpu", "bin", "ollama-runner"),
|
||||
}
|
||||
case "windows":
|
||||
// TODO: select windows GPU runner here when available
|
||||
runners = []ModelRunner{
|
||||
{Path: path.Join(buildPath, "cpu", "bin", "Release", "ollama-runner.exe")},
|
||||
runners = []string{
|
||||
path.Join(buildPath, "cpu", "bin", "Release", "ollama-runner.exe"),
|
||||
}
|
||||
default:
|
||||
log.Printf("unknown OS, running on CPU: %s", runtime.GOOS)
|
||||
runners = []ModelRunner{
|
||||
{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
|
||||
runners = []string{
|
||||
path.Join(buildPath, "cpu", "bin", "ollama-runner"),
|
||||
}
|
||||
}
|
||||
|
||||
runnerAvailable := false // if no runner files are found in the embed, this flag will cause a fast fail
|
||||
for _, r := range runners {
|
||||
// find all the files in the runner's bin directory
|
||||
files, err := fs.Glob(llamaCppEmbed, path.Join(path.Dir(r.Path), "*"))
|
||||
files, err := fs.Glob(llamaCppEmbed, path.Join(path.Dir(r), "*"))
|
||||
if err != nil {
|
||||
// this is expected, ollama may be compiled without all runners packed in
|
||||
log.Printf("%s runner not found: %v", r.Path, err)
|
||||
log.Printf("%s runner not found: %v", r, err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -145,10 +115,7 @@ func chooseRunners(workDir, runnerType string) []ModelRunner {
|
||||
localRunnersByPriority := []ModelRunner{}
|
||||
for _, r := range runners {
|
||||
// clean the ModelRunner paths so that they match the OS we are running on
|
||||
localRunnersByPriority = append(localRunnersByPriority, ModelRunner{
|
||||
Path: filepath.Clean(path.Join(workDir, r.Path)),
|
||||
Accelerated: r.Accelerated,
|
||||
})
|
||||
localRunnersByPriority = append(localRunnersByPriority, ModelRunner{Path: filepath.Clean(path.Join(workDir, r))})
|
||||
}
|
||||
|
||||
return localRunnersByPriority
|
||||
@@ -211,12 +178,12 @@ type llamaHyperparameters struct {
|
||||
}
|
||||
|
||||
type Running struct {
|
||||
Port int
|
||||
Cmd *exec.Cmd
|
||||
Cancel context.CancelFunc
|
||||
exitOnce sync.Once
|
||||
exitCh chan error // channel to receive the exit status of the subprocess
|
||||
*StatusWriter // captures error messages from the llama runner process
|
||||
Port int
|
||||
Cmd *exec.Cmd
|
||||
Cancel context.CancelFunc
|
||||
exitOnce sync.Once
|
||||
exitCh chan error // channel to receive the exit status of the subprocess
|
||||
exitErr error // error returned by the subprocess
|
||||
}
|
||||
|
||||
type llama struct {
|
||||
@@ -224,44 +191,31 @@ type llama struct {
|
||||
Running
|
||||
}
|
||||
|
||||
var (
|
||||
errNvidiaSMI = errors.New("nvidia-smi command failed")
|
||||
errAvailableVRAM = errors.New("not enough VRAM available, falling back to CPU only")
|
||||
)
|
||||
var errNoGPU = errors.New("nvidia-smi command failed")
|
||||
|
||||
// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
|
||||
// CheckVRAM returns the available VRAM in MiB on Linux machines with NVIDIA GPUs
|
||||
func CheckVRAM() (int64, error) {
|
||||
cmd := exec.Command("nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits")
|
||||
var stdout bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return 0, errNvidiaSMI
|
||||
return 0, errNoGPU
|
||||
}
|
||||
|
||||
var freeMiB int64
|
||||
var free int64
|
||||
scanner := bufio.NewScanner(&stdout)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.Contains(line, "[Insufficient Permissions]") {
|
||||
return 0, fmt.Errorf("GPU support may not enabled, check you have installed GPU drivers and have the necessary permissions to run nvidia-smi")
|
||||
}
|
||||
|
||||
vram, err := strconv.ParseInt(strings.TrimSpace(line), 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to parse available VRAM: %v", err)
|
||||
}
|
||||
|
||||
freeMiB += vram
|
||||
free += vram
|
||||
}
|
||||
|
||||
freeBytes := freeMiB * 1024 * 1024
|
||||
if freeBytes < 2*format.GigaByte {
|
||||
log.Printf("less than 2 GB VRAM available")
|
||||
return 0, errAvailableVRAM
|
||||
}
|
||||
|
||||
return freeBytes, nil
|
||||
return free, nil
|
||||
}
|
||||
|
||||
func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
|
||||
@@ -269,25 +223,24 @@ func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
|
||||
return opts.NumGPU
|
||||
}
|
||||
if runtime.GOOS == "linux" {
|
||||
freeBytes, err := CheckVRAM()
|
||||
vramMib, err := CheckVRAM()
|
||||
if err != nil {
|
||||
if !errors.Is(err, errNvidiaSMI) {
|
||||
if err.Error() != "nvidia-smi command failed" {
|
||||
log.Print(err.Error())
|
||||
}
|
||||
// nvidia driver not installed or no nvidia GPU found
|
||||
return 0
|
||||
}
|
||||
|
||||
/*
|
||||
Calculate bytes per layer, this will roughly be the size of the model file divided by the number of layers.
|
||||
We can store the model weights and the kv cache in vram,
|
||||
to enable kv chache vram storage add two additional layers to the number of layers retrieved from the model file.
|
||||
*/
|
||||
freeVramBytes := int64(vramMib) * 1024 * 1024 // 1 MiB = 1024^2 bytes
|
||||
|
||||
// Calculate bytes per layer
|
||||
// TODO: this is a rough heuristic, better would be to calculate this based on number of layers and context size
|
||||
bytesPerLayer := fileSizeBytes / numLayer
|
||||
|
||||
// 75% of the absolute max number of layers we can fit in available VRAM, off-loading too many layers to the GPU can cause OOM errors
|
||||
layers := int(freeBytes/bytesPerLayer) * 3 / 4
|
||||
log.Printf("%d MB VRAM available, loading up to %d GPU layers", freeBytes/(1024*1024), layers)
|
||||
// max number of layers we can fit in VRAM, subtract 5% to prevent consuming all available VRAM and running out of memory
|
||||
layers := int(freeVramBytes/bytesPerLayer) * 95 / 100
|
||||
log.Printf("%d MiB VRAM available, loading up to %d GPU layers", vramMib, layers)
|
||||
|
||||
return layers
|
||||
}
|
||||
@@ -297,8 +250,7 @@ func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
|
||||
|
||||
// StatusWriter is a writer that captures error messages from the llama runner process
|
||||
type StatusWriter struct {
|
||||
ErrCh chan error
|
||||
LastErrMsg string
|
||||
ErrCh chan error
|
||||
}
|
||||
|
||||
func NewStatusWriter() *StatusWriter {
|
||||
@@ -308,18 +260,10 @@ func NewStatusWriter() *StatusWriter {
|
||||
}
|
||||
|
||||
func (w *StatusWriter) Write(b []byte) (int, error) {
|
||||
var errMsg string
|
||||
if _, after, ok := bytes.Cut(b, []byte("error:")); ok {
|
||||
errMsg = string(bytes.TrimSpace(after))
|
||||
} else if _, after, ok := bytes.Cut(b, []byte("CUDA error")); ok {
|
||||
errMsg = string(bytes.TrimSpace(after))
|
||||
err := fmt.Errorf("llama runner: %s", after)
|
||||
w.ErrCh <- err
|
||||
}
|
||||
|
||||
if errMsg != "" {
|
||||
w.LastErrMsg = errMsg
|
||||
w.ErrCh <- fmt.Errorf("llama runner: %s", errMsg)
|
||||
}
|
||||
|
||||
return os.Stderr.Write(b)
|
||||
}
|
||||
|
||||
@@ -333,23 +277,16 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
|
||||
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
|
||||
}
|
||||
|
||||
numGPU := NumGPU(numLayers, fileInfo.Size(), opts)
|
||||
params := []string{
|
||||
"--model", model,
|
||||
"--ctx-size", fmt.Sprintf("%d", opts.NumCtx),
|
||||
"--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase),
|
||||
"--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale),
|
||||
"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
|
||||
"--n-gpu-layers", fmt.Sprintf("%d", numGPU),
|
||||
"--n-gpu-layers", fmt.Sprintf("%d", NumGPU(numLayers, fileInfo.Size(), opts)),
|
||||
"--embedding",
|
||||
}
|
||||
|
||||
if opts.RopeFrequencyBase > 0 {
|
||||
params = append(params, "--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase))
|
||||
}
|
||||
|
||||
if opts.RopeFrequencyScale > 0 {
|
||||
params = append(params, "--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale))
|
||||
}
|
||||
|
||||
if opts.NumGQA > 0 {
|
||||
params = append(params, "--gqa", fmt.Sprintf("%d", opts.NumGQA))
|
||||
}
|
||||
@@ -380,11 +317,6 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
|
||||
|
||||
// start the llama.cpp server with a retry in case the port is already in use
|
||||
for _, runner := range runners {
|
||||
if runner.Accelerated && numGPU == 0 {
|
||||
log.Printf("skipping accelerated runner because num_gpu=0")
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := os.Stat(runner.Path); err != nil {
|
||||
log.Printf("llama runner not found: %v", err)
|
||||
continue
|
||||
@@ -397,15 +329,7 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
|
||||
runner.Path,
|
||||
append(params, "--port", strconv.Itoa(port))...,
|
||||
)
|
||||
|
||||
var libraryPaths []string
|
||||
if libraryPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
|
||||
libraryPaths = append(libraryPaths, libraryPath)
|
||||
}
|
||||
|
||||
libraryPaths = append(libraryPaths, filepath.Dir(runner.Path))
|
||||
|
||||
cmd.Env = append(os.Environ(), fmt.Sprintf("LD_LIBRARY_PATH=%s", strings.Join(libraryPaths, ":")))
|
||||
cmd.Env = append(os.Environ(), fmt.Sprintf("LD_LIBRARY_PATH=%s", filepath.Dir(runner.Path)))
|
||||
cmd.Stdout = os.Stderr
|
||||
statusWriter := NewStatusWriter()
|
||||
cmd.Stderr = statusWriter
|
||||
@@ -421,13 +345,7 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
|
||||
// monitor the llama runner process and signal when it exits
|
||||
go func() {
|
||||
err := llm.Cmd.Wait()
|
||||
// default to printing the exit message of the command process, it will probably just say 'exit staus 1'
|
||||
errMsg := err.Error()
|
||||
// try to set a better error message if llama runner logs captured an error
|
||||
if statusWriter.LastErrMsg != "" {
|
||||
errMsg = statusWriter.LastErrMsg
|
||||
}
|
||||
log.Println(errMsg)
|
||||
llm.exitErr = err
|
||||
// llm.Cmd.Wait() can only be called once, use this exit channel to signal that the process has exited
|
||||
llm.exitOnce.Do(func() {
|
||||
close(llm.exitCh)
|
||||
@@ -497,9 +415,10 @@ func (llm *llama) Close() {
|
||||
|
||||
// wait for the command to exit to prevent race conditions with the next run
|
||||
<-llm.exitCh
|
||||
err := llm.exitErr
|
||||
|
||||
if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
|
||||
log.Printf("llama runner stopped with error: %v", llm.StatusWriter.LastErrMsg)
|
||||
if err != nil {
|
||||
log.Printf("llama runner stopped with error: %v", err)
|
||||
} else {
|
||||
log.Print("llama runner stopped successfully")
|
||||
}
|
||||
@@ -509,72 +428,111 @@ func (llm *llama) SetOptions(opts api.Options) {
|
||||
llm.Options = opts
|
||||
}
|
||||
|
||||
type prediction struct {
|
||||
type GenerationSettings struct {
|
||||
FrequencyPenalty float64 `json:"frequency_penalty"`
|
||||
IgnoreEOS bool `json:"ignore_eos"`
|
||||
LogitBias []interface{} `json:"logit_bias"`
|
||||
Mirostat int `json:"mirostat"`
|
||||
MirostatEta float64 `json:"mirostat_eta"`
|
||||
MirostatTau float64 `json:"mirostat_tau"`
|
||||
Model string `json:"model"`
|
||||
NCtx int `json:"n_ctx"`
|
||||
NKeep int `json:"n_keep"`
|
||||
NPredict int `json:"n_predict"`
|
||||
NProbs int `json:"n_probs"`
|
||||
PenalizeNl bool `json:"penalize_nl"`
|
||||
PresencePenalty float64 `json:"presence_penalty"`
|
||||
RepeatLastN int `json:"repeat_last_n"`
|
||||
RepeatPenalty float64 `json:"repeat_penalty"`
|
||||
Seed uint32 `json:"seed"`
|
||||
Stop []string `json:"stop"`
|
||||
Stream bool `json:"stream"`
|
||||
Temp float64 `json:"temp"`
|
||||
TfsZ float64 `json:"tfs_z"`
|
||||
TopK int `json:"top_k"`
|
||||
TopP float64 `json:"top_p"`
|
||||
TypicalP float64 `json:"typical_p"`
|
||||
}
|
||||
|
||||
type Timings struct {
|
||||
PredictedN int `json:"predicted_n"`
|
||||
PredictedMS float64 `json:"predicted_ms"`
|
||||
PromptN int `json:"prompt_n"`
|
||||
PromptMS float64 `json:"prompt_ms"`
|
||||
}
|
||||
|
||||
type Prediction struct {
|
||||
Content string `json:"content"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Stop bool `json:"stop"`
|
||||
|
||||
Timings struct {
|
||||
PredictedN int `json:"predicted_n"`
|
||||
PredictedMS float64 `json:"predicted_ms"`
|
||||
PromptN int `json:"prompt_n"`
|
||||
PromptMS float64 `json:"prompt_ms"`
|
||||
}
|
||||
Timings `json:"timings"`
|
||||
}
|
||||
|
||||
const maxBufferSize = 512 * format.KiloByte
|
||||
type PredictRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Stream bool `json:"stream"`
|
||||
NPredict int `json:"n_predict"`
|
||||
NKeep int `json:"n_keep"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
TopK int `json:"top_k"`
|
||||
TopP float32 `json:"top_p"`
|
||||
TfsZ float32 `json:"tfs_z"`
|
||||
TypicalP float32 `json:"typical_p"`
|
||||
RepeatLastN int `json:"repeat_last_n"`
|
||||
RepeatPenalty float32 `json:"repeat_penalty"`
|
||||
PresencePenalty float32 `json:"presence_penalty"`
|
||||
FrequencyPenalty float32 `json:"frequency_penalty"`
|
||||
Mirostat int `json:"mirostat"`
|
||||
MirostatTau float32 `json:"mirostat_tau"`
|
||||
MirostatEta float32 `json:"mirostat_eta"`
|
||||
PenalizeNl bool `json:"penalize_nl"`
|
||||
Seed int `json:"seed"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
}
|
||||
|
||||
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error {
|
||||
const maxBufferSize = 512 * 1000 // 512KB
|
||||
|
||||
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
|
||||
prevConvo, err := llm.Decode(ctx, prevContext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove leading spaces from prevConvo if present
|
||||
prevConvo = strings.TrimPrefix(prevConvo, " ")
|
||||
|
||||
var nextContext strings.Builder
|
||||
nextContext.WriteString(prevConvo)
|
||||
nextContext.WriteString(prompt)
|
||||
|
||||
request := map[string]any{
|
||||
"prompt": nextContext.String(),
|
||||
"stream": true,
|
||||
"n_predict": llm.NumPredict,
|
||||
"n_keep": llm.NumKeep,
|
||||
"temperature": llm.Temperature,
|
||||
"top_k": llm.TopK,
|
||||
"top_p": llm.TopP,
|
||||
"tfs_z": llm.TFSZ,
|
||||
"typical_p": llm.TypicalP,
|
||||
"repeat_last_n": llm.RepeatLastN,
|
||||
"repeat_penalty": llm.RepeatPenalty,
|
||||
"presence_penalty": llm.PresencePenalty,
|
||||
"frequency_penalty": llm.FrequencyPenalty,
|
||||
"mirostat": llm.Mirostat,
|
||||
"mirostat_tau": llm.MirostatTau,
|
||||
"mirostat_eta": llm.MirostatEta,
|
||||
"penalize_nl": llm.PenalizeNewline,
|
||||
"seed": llm.Seed,
|
||||
"stop": llm.Stop,
|
||||
}
|
||||
|
||||
if format == "json" {
|
||||
request["grammar"] = jsonGrammar
|
||||
}
|
||||
|
||||
// Handling JSON marshaling with special characters unescaped.
|
||||
buffer := &bytes.Buffer{}
|
||||
enc := json.NewEncoder(buffer)
|
||||
enc.SetEscapeHTML(false)
|
||||
|
||||
if err := enc.Encode(request); err != nil {
|
||||
return fmt.Errorf("failed to marshal data: %v", err)
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
|
||||
predReq := PredictRequest{
|
||||
Prompt: nextContext.String(),
|
||||
Stream: true,
|
||||
NPredict: llm.NumPredict,
|
||||
NKeep: llm.NumKeep,
|
||||
Temperature: llm.Temperature,
|
||||
TopK: llm.TopK,
|
||||
TopP: llm.TopP,
|
||||
TfsZ: llm.TFSZ,
|
||||
TypicalP: llm.TypicalP,
|
||||
RepeatLastN: llm.RepeatLastN,
|
||||
RepeatPenalty: llm.RepeatPenalty,
|
||||
PresencePenalty: llm.PresencePenalty,
|
||||
FrequencyPenalty: llm.FrequencyPenalty,
|
||||
Mirostat: llm.Mirostat,
|
||||
MirostatTau: llm.MirostatTau,
|
||||
MirostatEta: llm.MirostatEta,
|
||||
PenalizeNl: llm.PenalizeNewline,
|
||||
Seed: llm.Seed,
|
||||
Stop: llm.Stop,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(predReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshaling data: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating POST request: %v", err)
|
||||
}
|
||||
@@ -605,14 +563,16 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
|
||||
// This handles the request cancellation
|
||||
return ctx.Err()
|
||||
default:
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
line := scanner.Text()
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
|
||||
var p prediction
|
||||
if err := json.Unmarshal(evt, &p); err != nil {
|
||||
// Read data from the server-side event stream
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
evt := line[6:]
|
||||
var p Prediction
|
||||
if err := json.Unmarshal([]byte(evt), &p); err != nil {
|
||||
return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
|
||||
}
|
||||
|
||||
@@ -630,10 +590,10 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
|
||||
fn(api.GenerateResponse{
|
||||
Done: true,
|
||||
Context: embd,
|
||||
PromptEvalCount: p.Timings.PromptN,
|
||||
PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
|
||||
EvalCount: p.Timings.PredictedN,
|
||||
EvalDuration: parseDurationMs(p.Timings.PredictedMS),
|
||||
PromptEvalCount: p.PromptN,
|
||||
PromptEvalDuration: parseDurationMs(p.PromptMS),
|
||||
EvalCount: p.PredictedN,
|
||||
EvalDuration: parseDurationMs(p.PredictedMS),
|
||||
})
|
||||
|
||||
return nil
|
||||
@@ -643,14 +603,6 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if strings.Contains(err.Error(), "unexpected EOF") {
|
||||
// this means the llama runner subprocess crashed
|
||||
llm.Close()
|
||||
if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
|
||||
return fmt.Errorf("llama runner exited: %v", llm.StatusWriter.LastErrMsg)
|
||||
}
|
||||
return fmt.Errorf("llama runner exited, you may not have enough available memory to run this model")
|
||||
}
|
||||
return fmt.Errorf("error reading llm response: %v", err)
|
||||
}
|
||||
|
||||
@@ -747,6 +699,9 @@ func (llm *llama) Decode(ctx context.Context, tokens []int) (string, error) {
|
||||
return "", fmt.Errorf("unmarshal encode response: %w", err)
|
||||
}
|
||||
|
||||
// decoded content contains a leading whitespace
|
||||
decoded.Content, _ = strings.CutPrefix(decoded.Content, "")
|
||||
|
||||
return decoded.Content, nil
|
||||
}
|
||||
|
||||
|
61
llm/llm.go
61
llm/llm.go
@@ -10,11 +10,10 @@ import (
|
||||
"github.com/pbnjay/memory"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/format"
|
||||
)
|
||||
|
||||
type LLM interface {
|
||||
Predict(context.Context, []int, string, string, func(api.GenerateResponse)) error
|
||||
Predict(context.Context, []int, string, func(api.GenerateResponse)) error
|
||||
Embedding(context.Context, string) ([]float64, error)
|
||||
Encode(context.Context, string) ([]int, error)
|
||||
Decode(context.Context, []int) (string, error)
|
||||
@@ -56,39 +55,45 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error
|
||||
opts.NumGPU = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var requiredMemory int64
|
||||
var f16Multiplier int64 = 2
|
||||
|
||||
switch ggml.ModelType() {
|
||||
case "3B", "7B":
|
||||
requiredMemory = 8 * format.GigaByte
|
||||
case "13B":
|
||||
requiredMemory = 16 * format.GigaByte
|
||||
case "30B", "34B", "40B":
|
||||
requiredMemory = 32 * format.GigaByte
|
||||
case "65B", "70B":
|
||||
requiredMemory = 64 * format.GigaByte
|
||||
case "180B":
|
||||
requiredMemory = 128 * format.GigaByte
|
||||
f16Multiplier = 4
|
||||
totalResidentMemory := memory.TotalMemory()
|
||||
switch ggml.ModelType() {
|
||||
case "3B", "7B":
|
||||
if ggml.FileType() == "F16" && totalResidentMemory < 16*1000*1000 {
|
||||
return nil, fmt.Errorf("F16 model requires at least 16 GB of memory")
|
||||
} else if totalResidentMemory < 8*1000*1000 {
|
||||
return nil, fmt.Errorf("model requires at least 8 GB of memory")
|
||||
}
|
||||
|
||||
systemMemory := int64(memory.TotalMemory())
|
||||
|
||||
if ggml.FileType() == "F16" && requiredMemory*f16Multiplier > systemMemory {
|
||||
return nil, fmt.Errorf("F16 model requires at least %s of total memory", format.HumanBytes(requiredMemory))
|
||||
} else if requiredMemory > systemMemory {
|
||||
return nil, fmt.Errorf("model requires at least %s of total memory", format.HumanBytes(requiredMemory))
|
||||
case "13B":
|
||||
if ggml.FileType() == "F16" && totalResidentMemory < 32*1000*1000 {
|
||||
return nil, fmt.Errorf("F16 model requires at least 32 GB of memory")
|
||||
} else if totalResidentMemory < 16*1000*1000 {
|
||||
return nil, fmt.Errorf("model requires at least 16 GB of memory")
|
||||
}
|
||||
case "30B", "34B", "40B":
|
||||
if ggml.FileType() == "F16" && totalResidentMemory < 64*1000*1000 {
|
||||
return nil, fmt.Errorf("F16 model requires at least 64 GB of memory")
|
||||
} else if totalResidentMemory < 32*1000*1000 {
|
||||
return nil, fmt.Errorf("model requires at least 32 GB of memory")
|
||||
}
|
||||
case "65B", "70B":
|
||||
if ggml.FileType() == "F16" && totalResidentMemory < 128*1000*1000 {
|
||||
return nil, fmt.Errorf("F16 model requires at least 128 GB of memory")
|
||||
} else if totalResidentMemory < 64*1000*1000 {
|
||||
return nil, fmt.Errorf("model requires at least 64 GB of memory")
|
||||
}
|
||||
case "180B":
|
||||
if ggml.FileType() == "F16" && totalResidentMemory < 512*1000*1000 {
|
||||
return nil, fmt.Errorf("F16 model requires at least 512GB of memory")
|
||||
} else if totalResidentMemory < 128*1000*1000 {
|
||||
return nil, fmt.Errorf("model requires at least 128GB of memory")
|
||||
}
|
||||
}
|
||||
|
||||
switch ggml.Name() {
|
||||
case "gguf":
|
||||
// TODO: gguf will load these options automatically from the model binary
|
||||
opts.NumGQA = 0
|
||||
opts.RopeFrequencyBase = 0.0
|
||||
opts.RopeFrequencyScale = 0.0
|
||||
opts.NumGQA = 0 // TODO: remove this when llama.cpp runners differ enough to need separate newLlama functions
|
||||
return newLlama(model, adapters, chooseRunners(workDir, "gguf"), ggml.NumLayers(), opts)
|
||||
case "ggml", "ggmf", "ggjt", "ggla":
|
||||
return newLlama(model, adapters, chooseRunners(workDir, "ggml"), ggml.NumLayers(), opts)
|
||||
|
@@ -40,7 +40,7 @@ func Parse(reader io.Reader) ([]Command, error) {
|
||||
command.Args = string(fields[1])
|
||||
// copy command for validation
|
||||
modelCommand = command
|
||||
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "ADAPTER":
|
||||
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "EMBED", "ADAPTER":
|
||||
command.Name = string(bytes.ToLower(fields[0]))
|
||||
command.Args = string(fields[1])
|
||||
case "PARAMETER":
|
||||
@@ -51,8 +51,6 @@ func Parse(reader io.Reader) ([]Command, error) {
|
||||
|
||||
command.Name = string(fields[0])
|
||||
command.Args = string(fields[1])
|
||||
case "EMBED":
|
||||
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
|
||||
default:
|
||||
if !bytes.HasPrefix(fields[0], []byte("#")) {
|
||||
// log a warning for unknown commands
|
||||
|
@@ -291,7 +291,7 @@ func OptionShowDescriptionAtLineEnd() Option {
|
||||
}
|
||||
}
|
||||
|
||||
var defaultTheme = Theme{Saucer: "█", SaucerPadding: " ", BarStart: "▕", BarEnd: "▏"}
|
||||
var defaultTheme = Theme{Saucer: "█", SaucerPadding: " ", BarStart: "|", BarEnd: "|"}
|
||||
|
||||
// NewOptions constructs a new instance of ProgressBar, with any options you specify
|
||||
func NewOptions(max int, options ...Option) *ProgressBar {
|
||||
|
@@ -1,372 +0,0 @@
|
||||
package readline
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/emirpasic/gods/lists/arraylist"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
type Buffer struct {
|
||||
Pos int
|
||||
Buf *arraylist.List
|
||||
Prompt *Prompt
|
||||
LineWidth int
|
||||
Width int
|
||||
Height int
|
||||
}
|
||||
|
||||
func NewBuffer(prompt *Prompt) (*Buffer, error) {
|
||||
fd := int(os.Stdout.Fd())
|
||||
width, height, err := term.GetSize(fd)
|
||||
if err != nil {
|
||||
fmt.Println("Error getting size:", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lwidth := width - len(prompt.Prompt)
|
||||
if prompt.UseAlt {
|
||||
lwidth = width - len(prompt.AltPrompt)
|
||||
}
|
||||
|
||||
b := &Buffer{
|
||||
Pos: 0,
|
||||
Buf: arraylist.New(),
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
LineWidth: lwidth,
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (b *Buffer) MoveLeft() {
|
||||
if b.Pos > 0 {
|
||||
if b.Pos%b.LineWidth == 0 {
|
||||
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
|
||||
} else {
|
||||
fmt.Print(CursorLeft)
|
||||
}
|
||||
b.Pos -= 1
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) MoveLeftWord() {
|
||||
if b.Pos > 0 {
|
||||
var foundNonspace bool
|
||||
for {
|
||||
v, _ := b.Buf.Get(b.Pos - 1)
|
||||
if v == ' ' {
|
||||
if foundNonspace {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
foundNonspace = true
|
||||
}
|
||||
b.MoveLeft()
|
||||
|
||||
if b.Pos == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) MoveRight() {
|
||||
if b.Pos < b.Size() {
|
||||
b.Pos += 1
|
||||
if b.Pos%b.LineWidth == 0 {
|
||||
fmt.Printf(CursorDown + CursorBOL + cursorRightN(b.PromptSize()))
|
||||
} else {
|
||||
fmt.Print(CursorRight)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) MoveRightWord() {
|
||||
if b.Pos < b.Size() {
|
||||
for {
|
||||
b.MoveRight()
|
||||
v, _ := b.Buf.Get(b.Pos)
|
||||
if v == ' ' {
|
||||
break
|
||||
}
|
||||
|
||||
if b.Pos == b.Size() {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) MoveToStart() {
|
||||
if b.Pos > 0 {
|
||||
currLine := b.Pos / b.LineWidth
|
||||
if currLine > 0 {
|
||||
for cnt := 0; cnt < currLine; cnt++ {
|
||||
fmt.Print(CursorUp)
|
||||
}
|
||||
}
|
||||
fmt.Printf(CursorBOL + cursorRightN(b.PromptSize()))
|
||||
b.Pos = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) MoveToEnd() {
|
||||
if b.Pos < b.Size() {
|
||||
currLine := b.Pos / b.LineWidth
|
||||
totalLines := b.Size() / b.LineWidth
|
||||
if currLine < totalLines {
|
||||
for cnt := 0; cnt < totalLines-currLine; cnt++ {
|
||||
fmt.Print(CursorDown)
|
||||
}
|
||||
remainder := b.Size() % b.LineWidth
|
||||
fmt.Printf(CursorBOL + cursorRightN(b.PromptSize()+remainder))
|
||||
} else {
|
||||
fmt.Print(cursorRightN(b.Size() - b.Pos))
|
||||
}
|
||||
|
||||
b.Pos = b.Size()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) Size() int {
|
||||
return b.Buf.Size()
|
||||
}
|
||||
|
||||
func min(n, m int) int {
|
||||
if n > m {
|
||||
return m
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (b *Buffer) PromptSize() int {
|
||||
if b.Prompt.UseAlt {
|
||||
return len(b.Prompt.AltPrompt)
|
||||
}
|
||||
return len(b.Prompt.Prompt)
|
||||
}
|
||||
|
||||
func (b *Buffer) Add(r rune) {
|
||||
if b.Pos == b.Buf.Size() {
|
||||
fmt.Printf("%c", r)
|
||||
b.Buf.Add(r)
|
||||
b.Pos += 1
|
||||
if b.Pos > 0 && b.Pos%b.LineWidth == 0 {
|
||||
fmt.Printf("\n%s", b.Prompt.AltPrompt)
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("%c", r)
|
||||
b.Buf.Insert(b.Pos, r)
|
||||
b.Pos += 1
|
||||
if b.Pos > 0 && b.Pos%b.LineWidth == 0 {
|
||||
fmt.Printf("\n%s", b.Prompt.AltPrompt)
|
||||
}
|
||||
b.drawRemaining()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) drawRemaining() {
|
||||
var place int
|
||||
remainingText := b.StringN(b.Pos)
|
||||
if b.Pos > 0 {
|
||||
place = b.Pos % b.LineWidth
|
||||
}
|
||||
fmt.Print(CursorHide)
|
||||
|
||||
// render the rest of the current line
|
||||
currLine := remainingText[:min(b.LineWidth-place, len(remainingText))]
|
||||
if len(currLine) > 0 {
|
||||
fmt.Printf(ClearToEOL + currLine)
|
||||
fmt.Print(cursorLeftN(len(currLine)))
|
||||
} else {
|
||||
fmt.Print(ClearToEOL)
|
||||
}
|
||||
|
||||
// render the other lines
|
||||
if len(remainingText) > len(currLine) {
|
||||
remaining := []rune(remainingText[len(currLine):])
|
||||
var totalLines int
|
||||
for i, c := range remaining {
|
||||
if i%b.LineWidth == 0 {
|
||||
fmt.Printf("\n%s", b.Prompt.AltPrompt)
|
||||
totalLines += 1
|
||||
}
|
||||
fmt.Printf("%c", c)
|
||||
}
|
||||
fmt.Print(ClearToEOL)
|
||||
fmt.Print(cursorUpN(totalLines))
|
||||
fmt.Printf(CursorBOL + cursorRightN(b.Width-len(currLine)))
|
||||
}
|
||||
|
||||
fmt.Print(CursorShow)
|
||||
}
|
||||
|
||||
func (b *Buffer) Remove() {
|
||||
if b.Buf.Size() > 0 && b.Pos > 0 {
|
||||
if b.Pos%b.LineWidth == 0 {
|
||||
// if the user backspaces over the word boundary, do this magic to clear the line
|
||||
// and move to the end of the previous line
|
||||
fmt.Printf(CursorBOL + ClearToEOL)
|
||||
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width) + " " + CursorLeft)
|
||||
} else {
|
||||
fmt.Printf(CursorLeft + " " + CursorLeft)
|
||||
}
|
||||
|
||||
var eraseExtraLine bool
|
||||
if (b.Size()-1)%b.LineWidth == 0 {
|
||||
eraseExtraLine = true
|
||||
}
|
||||
|
||||
b.Pos -= 1
|
||||
b.Buf.Remove(b.Pos)
|
||||
|
||||
if b.Pos < b.Size() {
|
||||
b.drawRemaining()
|
||||
// this erases a line which is left over when backspacing in the middle of a line and there
|
||||
// are trailing characters which go over the line width boundary
|
||||
if eraseExtraLine {
|
||||
remainingLines := (b.Size() - b.Pos) / b.LineWidth
|
||||
fmt.Printf(cursorDownN(remainingLines+1) + CursorBOL + ClearToEOL)
|
||||
place := b.Pos % b.LineWidth
|
||||
fmt.Printf(cursorUpN(remainingLines+1) + cursorRightN(place+len(b.Prompt.Prompt)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) Delete() {
|
||||
if b.Size() > 0 && b.Pos < b.Size() {
|
||||
b.Buf.Remove(b.Pos)
|
||||
b.drawRemaining()
|
||||
if b.Size()%b.LineWidth == 0 {
|
||||
if b.Pos != b.Size() {
|
||||
remainingLines := (b.Size() - b.Pos) / b.LineWidth
|
||||
fmt.Printf(cursorDownN(remainingLines) + CursorBOL + ClearToEOL)
|
||||
place := b.Pos % b.LineWidth
|
||||
fmt.Printf(cursorUpN(remainingLines) + cursorRightN(place+len(b.Prompt.Prompt)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) DeleteBefore() {
|
||||
if b.Pos > 0 {
|
||||
for cnt := b.Pos - 1; cnt >= 0; cnt-- {
|
||||
b.Remove()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) DeleteRemaining() {
|
||||
if b.Size() > 0 && b.Pos < b.Size() {
|
||||
charsToDel := b.Size() - b.Pos
|
||||
for cnt := 0; cnt < charsToDel; cnt++ {
|
||||
b.Delete()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) DeleteWord() {
|
||||
if b.Buf.Size() > 0 && b.Pos > 0 {
|
||||
var foundNonspace bool
|
||||
for {
|
||||
v, _ := b.Buf.Get(b.Pos - 1)
|
||||
if v == ' ' {
|
||||
if !foundNonspace {
|
||||
b.Remove()
|
||||
} else {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
foundNonspace = true
|
||||
b.Remove()
|
||||
}
|
||||
|
||||
if b.Pos == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) ClearScreen() {
|
||||
fmt.Printf(ClearScreen + CursorReset + b.Prompt.Prompt)
|
||||
if b.IsEmpty() {
|
||||
ph := b.Prompt.Placeholder
|
||||
fmt.Printf(ColorGrey + ph + cursorLeftN(len(ph)) + ColorDefault)
|
||||
} else {
|
||||
currPos := b.Pos
|
||||
b.Pos = 0
|
||||
b.drawRemaining()
|
||||
fmt.Printf(CursorReset + cursorRightN(len(b.Prompt.Prompt)))
|
||||
if currPos > 0 {
|
||||
targetLine := currPos / b.LineWidth
|
||||
if targetLine > 0 {
|
||||
for cnt := 0; cnt < targetLine; cnt++ {
|
||||
fmt.Print(CursorDown)
|
||||
}
|
||||
}
|
||||
remainder := currPos % b.LineWidth
|
||||
if remainder > 0 {
|
||||
fmt.Print(cursorRightN(remainder))
|
||||
}
|
||||
if currPos%b.LineWidth == 0 {
|
||||
fmt.Printf(CursorBOL + b.Prompt.AltPrompt)
|
||||
}
|
||||
}
|
||||
b.Pos = currPos
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) IsEmpty() bool {
|
||||
return b.Buf.Empty()
|
||||
}
|
||||
|
||||
func (b *Buffer) Replace(r []rune) {
|
||||
b.Pos = 0
|
||||
b.Buf.Clear()
|
||||
fmt.Printf(ClearLine + CursorBOL + b.Prompt.Prompt)
|
||||
for _, c := range r {
|
||||
b.Add(c)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) String() string {
|
||||
return b.StringN(0)
|
||||
}
|
||||
|
||||
func (b *Buffer) StringN(n int) string {
|
||||
return b.StringNM(n, 0)
|
||||
}
|
||||
|
||||
func (b *Buffer) StringNM(n, m int) string {
|
||||
var s string
|
||||
if m == 0 {
|
||||
m = b.Size()
|
||||
}
|
||||
for cnt := n; cnt < m; cnt++ {
|
||||
c, _ := b.Buf.Get(cnt)
|
||||
s += string(c.(rune))
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func cursorLeftN(n int) string {
|
||||
return fmt.Sprintf(CursorLeftN, n)
|
||||
}
|
||||
|
||||
func cursorRightN(n int) string {
|
||||
return fmt.Sprintf(CursorRightN, n)
|
||||
}
|
||||
|
||||
func cursorUpN(n int) string {
|
||||
return fmt.Sprintf(CursorUpN, n)
|
||||
}
|
||||
|
||||
func cursorDownN(n int) string {
|
||||
return fmt.Sprintf(CursorDownN, n)
|
||||
}
|
@@ -1,17 +0,0 @@
|
||||
package readline
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInterrupt = errors.New("Interrupt")
|
||||
)
|
||||
|
||||
type InterruptError struct {
|
||||
Line []rune
|
||||
}
|
||||
|
||||
func (*InterruptError) Error() string {
|
||||
return "Interrupted"
|
||||
}
|
@@ -1,152 +0,0 @@
|
||||
package readline
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/emirpasic/gods/lists/arraylist"
|
||||
)
|
||||
|
||||
type History struct {
|
||||
Buf *arraylist.List
|
||||
Autosave bool
|
||||
Pos int
|
||||
Limit int
|
||||
Filename string
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
func NewHistory() (*History, error) {
|
||||
h := &History{
|
||||
Buf: arraylist.New(),
|
||||
Limit: 100, //resizeme
|
||||
Autosave: true,
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
err := h.Init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func (h *History) Init() error {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
path := filepath.Join(home, ".ollama", "history")
|
||||
h.Filename = path
|
||||
|
||||
//todo check if the file exists
|
||||
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDONLY, 0600)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
r := bufio.NewReader(f)
|
||||
for {
|
||||
line, err := r.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
h.Add([]rune(line))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *History) Add(l []rune) {
|
||||
h.Buf.Add(l)
|
||||
h.Compact()
|
||||
h.Pos = h.Size()
|
||||
if h.Autosave {
|
||||
h.Save()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *History) Compact() {
|
||||
s := h.Buf.Size()
|
||||
if s > h.Limit {
|
||||
for cnt := 0; cnt < s-h.Limit; cnt++ {
|
||||
h.Buf.Remove(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *History) Clear() {
|
||||
h.Buf.Clear()
|
||||
}
|
||||
|
||||
func (h *History) Prev() []rune {
|
||||
var line []rune
|
||||
if h.Pos > 0 {
|
||||
h.Pos -= 1
|
||||
}
|
||||
v, _ := h.Buf.Get(h.Pos)
|
||||
line, _ = v.([]rune)
|
||||
return line
|
||||
}
|
||||
|
||||
func (h *History) Next() []rune {
|
||||
var line []rune
|
||||
if h.Pos < h.Buf.Size() {
|
||||
h.Pos += 1
|
||||
v, _ := h.Buf.Get(h.Pos)
|
||||
line, _ = v.([]rune)
|
||||
}
|
||||
return line
|
||||
}
|
||||
|
||||
func (h *History) Size() int {
|
||||
return h.Buf.Size()
|
||||
}
|
||||
|
||||
func (h *History) Save() error {
|
||||
if !h.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
tmpFile := h.Filename + ".tmp"
|
||||
|
||||
f, err := os.OpenFile(tmpFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
buf := bufio.NewWriter(f)
|
||||
for cnt := 0; cnt < h.Size(); cnt++ {
|
||||
v, _ := h.Buf.Get(cnt)
|
||||
line, _ := v.([]rune)
|
||||
buf.WriteString(string(line) + "\n")
|
||||
}
|
||||
buf.Flush()
|
||||
f.Close()
|
||||
|
||||
if err = os.Rename(tmpFile, h.Filename); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@@ -1,254 +0,0 @@
|
||||
package readline
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
type Prompt struct {
|
||||
Prompt string
|
||||
AltPrompt string
|
||||
Placeholder string
|
||||
AltPlaceholder string
|
||||
UseAlt bool
|
||||
}
|
||||
|
||||
type Terminal struct {
|
||||
outchan chan rune
|
||||
}
|
||||
|
||||
type Instance struct {
|
||||
Prompt *Prompt
|
||||
Terminal *Terminal
|
||||
History *History
|
||||
}
|
||||
|
||||
func New(prompt Prompt) (*Instance, error) {
|
||||
term, err := NewTerminal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
history, err := NewHistory()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Instance{
|
||||
Prompt: &prompt,
|
||||
Terminal: term,
|
||||
History: history,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (i *Instance) Readline() (string, error) {
|
||||
prompt := i.Prompt.Prompt
|
||||
if i.Prompt.UseAlt {
|
||||
prompt = i.Prompt.AltPrompt
|
||||
}
|
||||
fmt.Print(prompt)
|
||||
|
||||
fd := int(syscall.Stdin)
|
||||
termios, err := SetRawMode(fd)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer UnsetRawMode(fd, termios)
|
||||
|
||||
buf, _ := NewBuffer(i.Prompt)
|
||||
|
||||
var esc bool
|
||||
var escex bool
|
||||
var metaDel bool
|
||||
var pasteMode PasteMode
|
||||
|
||||
var currentLineBuf []rune
|
||||
|
||||
for {
|
||||
if buf.IsEmpty() {
|
||||
ph := i.Prompt.Placeholder
|
||||
if i.Prompt.UseAlt {
|
||||
ph = i.Prompt.AltPlaceholder
|
||||
}
|
||||
fmt.Printf(ColorGrey + ph + fmt.Sprintf(CursorLeftN, len(ph)) + ColorDefault)
|
||||
}
|
||||
|
||||
r, err := i.Terminal.Read()
|
||||
|
||||
if buf.IsEmpty() {
|
||||
fmt.Print(ClearToEOL)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", io.EOF
|
||||
}
|
||||
|
||||
if escex {
|
||||
escex = false
|
||||
|
||||
switch r {
|
||||
case KeyUp:
|
||||
if i.History.Pos > 0 {
|
||||
if i.History.Pos == i.History.Size() {
|
||||
currentLineBuf = []rune(buf.String())
|
||||
}
|
||||
buf.Replace(i.History.Prev())
|
||||
}
|
||||
case KeyDown:
|
||||
if i.History.Pos < i.History.Size() {
|
||||
buf.Replace(i.History.Next())
|
||||
if i.History.Pos == i.History.Size() {
|
||||
buf.Replace(currentLineBuf)
|
||||
}
|
||||
}
|
||||
case KeyLeft:
|
||||
buf.MoveLeft()
|
||||
case KeyRight:
|
||||
buf.MoveRight()
|
||||
case CharBracketedPaste:
|
||||
var code string
|
||||
for cnt := 0; cnt < 3; cnt++ {
|
||||
r, err = i.Terminal.Read()
|
||||
if err != nil {
|
||||
return "", io.EOF
|
||||
}
|
||||
|
||||
code += string(r)
|
||||
}
|
||||
if code == CharBracketedPasteStart {
|
||||
pasteMode = PasteModeStart
|
||||
} else if code == CharBracketedPasteEnd {
|
||||
pasteMode = PasteModeEnd
|
||||
}
|
||||
case KeyDel:
|
||||
if buf.Size() > 0 {
|
||||
buf.Delete()
|
||||
}
|
||||
metaDel = true
|
||||
case MetaStart:
|
||||
buf.MoveToStart()
|
||||
case MetaEnd:
|
||||
buf.MoveToEnd()
|
||||
default:
|
||||
// skip any keys we don't know about
|
||||
continue
|
||||
}
|
||||
continue
|
||||
} else if esc {
|
||||
esc = false
|
||||
|
||||
switch r {
|
||||
case 'b':
|
||||
buf.MoveLeftWord()
|
||||
case 'f':
|
||||
buf.MoveRightWord()
|
||||
case CharEscapeEx:
|
||||
escex = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch r {
|
||||
case CharNull:
|
||||
continue
|
||||
case CharEsc:
|
||||
esc = true
|
||||
case CharInterrupt:
|
||||
return "", ErrInterrupt
|
||||
case CharLineStart:
|
||||
buf.MoveToStart()
|
||||
case CharLineEnd:
|
||||
buf.MoveToEnd()
|
||||
case CharBackward:
|
||||
buf.MoveLeft()
|
||||
case CharForward:
|
||||
buf.MoveRight()
|
||||
case CharBackspace, CharCtrlH:
|
||||
buf.Remove()
|
||||
case CharTab:
|
||||
// todo: convert back to real tabs
|
||||
for cnt := 0; cnt < 8; cnt++ {
|
||||
buf.Add(' ')
|
||||
}
|
||||
case CharDelete:
|
||||
if buf.Size() > 0 {
|
||||
buf.Delete()
|
||||
} else {
|
||||
return "", io.EOF
|
||||
}
|
||||
case CharKill:
|
||||
buf.DeleteRemaining()
|
||||
case CharCtrlU:
|
||||
buf.DeleteBefore()
|
||||
case CharCtrlL:
|
||||
buf.ClearScreen()
|
||||
case CharCtrlW:
|
||||
buf.DeleteWord()
|
||||
case CharEnter:
|
||||
output := buf.String()
|
||||
if output != "" {
|
||||
i.History.Add([]rune(output))
|
||||
}
|
||||
buf.MoveToEnd()
|
||||
fmt.Println()
|
||||
switch pasteMode {
|
||||
case PasteModeStart:
|
||||
output = `"""` + output
|
||||
case PasteModeEnd:
|
||||
output = output + `"""`
|
||||
}
|
||||
return output, nil
|
||||
default:
|
||||
if metaDel {
|
||||
metaDel = false
|
||||
continue
|
||||
}
|
||||
if r >= CharSpace || r == CharEnter {
|
||||
buf.Add(r)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Instance) HistoryEnable() {
|
||||
i.History.Enabled = true
|
||||
}
|
||||
|
||||
func (i *Instance) HistoryDisable() {
|
||||
i.History.Enabled = false
|
||||
}
|
||||
|
||||
func NewTerminal() (*Terminal, error) {
|
||||
t := &Terminal{
|
||||
outchan: make(chan rune),
|
||||
}
|
||||
|
||||
go t.ioloop()
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *Terminal) ioloop() {
|
||||
buf := bufio.NewReader(os.Stdin)
|
||||
|
||||
for {
|
||||
r, _, err := buf.ReadRune()
|
||||
if err != nil {
|
||||
close(t.outchan)
|
||||
break
|
||||
}
|
||||
t.outchan <- r
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Terminal) Read() (rune, error) {
|
||||
r, ok := <-t.outchan
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
@@ -1,36 +0,0 @@
|
||||
//go:build aix || darwin || dragonfly || freebsd || (linux && !appengine) || netbsd || openbsd || os400 || solaris
|
||||
|
||||
package readline
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
type Termios syscall.Termios
|
||||
|
||||
func SetRawMode(fd int) (*Termios, error) {
|
||||
termios, err := getTermios(fd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newTermios := *termios
|
||||
newTermios.Iflag &^= syscall.IGNBRK | syscall.BRKINT | syscall.PARMRK | syscall.ISTRIP | syscall.INLCR | syscall.IGNCR | syscall.ICRNL | syscall.IXON
|
||||
newTermios.Lflag &^= syscall.ECHO | syscall.ECHONL | syscall.ICANON | syscall.ISIG | syscall.IEXTEN
|
||||
newTermios.Cflag &^= syscall.CSIZE | syscall.PARENB
|
||||
newTermios.Cflag |= syscall.CS8
|
||||
newTermios.Cc[syscall.VMIN] = 1
|
||||
newTermios.Cc[syscall.VTIME] = 0
|
||||
|
||||
return termios, setTermios(fd, &newTermios)
|
||||
}
|
||||
|
||||
func UnsetRawMode(fd int, termios *Termios) error {
|
||||
return setTermios(fd, termios)
|
||||
}
|
||||
|
||||
// IsTerminal returns true if the given file descriptor is a terminal.
|
||||
func IsTerminal(fd int) bool {
|
||||
_, err := getTermios(fd)
|
||||
return err == nil
|
||||
}
|
@@ -1,25 +0,0 @@
|
||||
//go:build darwin || freebsd || netbsd || openbsd
|
||||
|
||||
package readline
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func getTermios(fd int) (*Termios, error) {
|
||||
termios := new(Termios)
|
||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), syscall.TIOCGETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
|
||||
if err != 0 {
|
||||
return nil, err
|
||||
}
|
||||
return termios, nil
|
||||
}
|
||||
|
||||
func setTermios(fd int, termios *Termios) error {
|
||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), syscall.TIOCSETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
|
||||
if err != 0 {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -1,28 +0,0 @@
|
||||
//go:build linux || solaris
|
||||
|
||||
package readline
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const tcgets = 0x5401
|
||||
const tcsets = 0x5402
|
||||
|
||||
func getTermios(fd int) (*Termios, error) {
|
||||
termios := new(Termios)
|
||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), tcgets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
|
||||
if err != 0 {
|
||||
return nil, err
|
||||
}
|
||||
return termios, nil
|
||||
}
|
||||
|
||||
func setTermios(fd int, termios *Termios) error {
|
||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), tcsets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
|
||||
if err != 0 {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -1,62 +0,0 @@
|
||||
package readline
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
enableLineInput = 2
|
||||
enableWindowInput = 8
|
||||
enableMouseInput = 16
|
||||
enableInsertMode = 32
|
||||
enableQuickEditMode = 64
|
||||
enableExtendedFlags = 128
|
||||
enableProcessedOutput = 1
|
||||
enableWrapAtEolOutput = 2
|
||||
enableAutoPosition = 256 // Cursor position is not affected by writing data to the console.
|
||||
enableEchoInput = 4 // Characters are written to the console as they're read.
|
||||
enableProcessedInput = 1 // Enables input processing (like recognizing Ctrl+C).
|
||||
)
|
||||
|
||||
var kernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
|
||||
var (
|
||||
procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
|
||||
procSetConsoleMode = kernel32.NewProc("SetConsoleMode")
|
||||
)
|
||||
|
||||
type State struct {
|
||||
mode uint32
|
||||
}
|
||||
|
||||
// IsTerminal checks if the given file descriptor is associated with a terminal
|
||||
func IsTerminal(fd int) bool {
|
||||
var st uint32
|
||||
r, _, e := syscall.SyscallN(procGetConsoleMode.Addr(), uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
|
||||
// if the call succeeds and doesn't produce an error, it's a terminal
|
||||
return r != 0 && e == 0
|
||||
}
|
||||
|
||||
func SetRawMode(fd int) (*State, error) {
|
||||
var st uint32
|
||||
// retrieve the current mode of the terminal
|
||||
_, _, e := syscall.SyscallN(procGetConsoleMode.Addr(), uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
|
||||
if e != 0 {
|
||||
return nil, error(e)
|
||||
}
|
||||
// modify the mode to set it to raw
|
||||
raw := st &^ (enableEchoInput | enableProcessedInput | enableLineInput | enableProcessedOutput)
|
||||
// apply the new mode to the terminal
|
||||
_, _, e = syscall.SyscallN(procSetConsoleMode.Addr(), uintptr(fd), uintptr(raw), 0)
|
||||
if e != 0 {
|
||||
return nil, error(e)
|
||||
}
|
||||
// return the original state so that it can be restored later
|
||||
return &State{st}, nil
|
||||
}
|
||||
|
||||
func UnsetRawMode(fd int, state *State) error {
|
||||
_, _, err := syscall.SyscallN(procSetConsoleMode.Addr(), uintptr(fd), uintptr(state.mode), 0)
|
||||
return err
|
||||
}
|
@@ -1,86 +0,0 @@
|
||||
package readline
|
||||
|
||||
const (
|
||||
CharNull = 0
|
||||
CharLineStart = 1
|
||||
CharBackward = 2
|
||||
CharInterrupt = 3
|
||||
CharDelete = 4
|
||||
CharLineEnd = 5
|
||||
CharForward = 6
|
||||
CharBell = 7
|
||||
CharCtrlH = 8
|
||||
CharTab = 9
|
||||
CharCtrlJ = 10
|
||||
CharKill = 11
|
||||
CharCtrlL = 12
|
||||
CharEnter = 13
|
||||
CharNext = 14
|
||||
CharPrev = 16
|
||||
CharBckSearch = 18
|
||||
CharFwdSearch = 19
|
||||
CharTranspose = 20
|
||||
CharCtrlU = 21
|
||||
CharCtrlW = 23
|
||||
CharCtrlY = 25
|
||||
CharCtrlZ = 26
|
||||
CharEsc = 27
|
||||
CharSpace = 32
|
||||
CharEscapeEx = 91
|
||||
CharBackspace = 127
|
||||
)
|
||||
|
||||
const (
|
||||
KeyDel = 51
|
||||
KeyUp = 65
|
||||
KeyDown = 66
|
||||
KeyRight = 67
|
||||
KeyLeft = 68
|
||||
MetaEnd = 70
|
||||
MetaStart = 72
|
||||
)
|
||||
|
||||
const (
|
||||
CursorUp = "\033[1A"
|
||||
CursorDown = "\033[1B"
|
||||
CursorRight = "\033[1C"
|
||||
CursorLeft = "\033[1D"
|
||||
|
||||
CursorSave = "\033[s"
|
||||
CursorRestore = "\033[u"
|
||||
|
||||
CursorUpN = "\033[%dA"
|
||||
CursorDownN = "\033[%dB"
|
||||
CursorRightN = "\033[%dC"
|
||||
CursorLeftN = "\033[%dD"
|
||||
|
||||
CursorEOL = "\033[E"
|
||||
CursorBOL = "\033[1G"
|
||||
CursorHide = "\033[?25l"
|
||||
CursorShow = "\033[?25h"
|
||||
|
||||
ClearToEOL = "\033[K"
|
||||
ClearLine = "\033[2K"
|
||||
ClearScreen = "\033[2J"
|
||||
CursorReset = "\033[0;0f"
|
||||
|
||||
ColorGrey = "\033[38;5;245m"
|
||||
ColorDefault = "\033[0m"
|
||||
|
||||
StartBracketedPaste = "\033[?2004h"
|
||||
EndBracketedPaste = "\033[?2004l"
|
||||
)
|
||||
|
||||
const (
|
||||
CharBracketedPaste = 50
|
||||
CharBracketedPasteStart = "00~"
|
||||
CharBracketedPasteEnd = "01~"
|
||||
)
|
||||
|
||||
type PasteMode int
|
||||
|
||||
const (
|
||||
PastModeOff = iota
|
||||
PasteModeStart
|
||||
PasteModeEnd
|
||||
)
|
@@ -26,8 +26,7 @@ require() {
|
||||
|
||||
[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
|
||||
|
||||
ARCH=$(uname -m)
|
||||
case "$ARCH" in
|
||||
case "$(uname -m)" in
|
||||
x86_64) ARCH="amd64" ;;
|
||||
aarch64|arm64) ARCH="arm64" ;;
|
||||
*) error "Unsupported architecture: $ARCH" ;;
|
||||
@@ -63,10 +62,7 @@ status "Installing ollama to $BINDIR..."
|
||||
$SUDO install -o0 -g0 -m755 -d $BINDIR
|
||||
$SUDO install -o0 -g0 -m755 $TEMP_DIR/ollama $BINDIR/ollama
|
||||
|
||||
install_success() {
|
||||
status 'The Ollama API is now available at 0.0.0.0:11434.'
|
||||
status 'Install complete. Run "ollama" from the command line.'
|
||||
}
|
||||
install_success() { status 'Install complete. Run "ollama" from the command line.'; }
|
||||
trap install_success EXIT
|
||||
|
||||
# Everything from this point onwards is optional.
|
||||
@@ -77,9 +73,6 @@ configure_systemd() {
|
||||
$SUDO useradd -r -s /bin/false -m -d /usr/share/ollama ollama
|
||||
fi
|
||||
|
||||
status "Adding current user to ollama group..."
|
||||
$SUDO usermod -a -G ollama $(whoami)
|
||||
|
||||
status "Creating ollama systemd service..."
|
||||
cat <<EOF | $SUDO tee /etc/systemd/system/ollama.service >/dev/null
|
||||
[Unit]
|
||||
@@ -92,6 +85,7 @@ User=ollama
|
||||
Group=ollama
|
||||
Restart=always
|
||||
RestartSec=3
|
||||
Environment="HOME=/usr/share/ollama"
|
||||
Environment="PATH=$PATH"
|
||||
|
||||
[Install]
|
||||
@@ -133,7 +127,6 @@ if check_gpu nvidia-smi; then
|
||||
fi
|
||||
|
||||
if ! check_gpu lspci && ! check_gpu lshw; then
|
||||
install_success
|
||||
warning "No NVIDIA GPU detected. Ollama will run in CPU-only mode."
|
||||
exit 0
|
||||
fi
|
||||
@@ -180,7 +173,7 @@ install_cuda_driver_apt() {
|
||||
case $1 in
|
||||
debian)
|
||||
status 'Enabling contrib sources...'
|
||||
$SUDO sed 's/main/contrib/' < /etc/apt/sources.list | $SUDO tee /etc/apt/sources.list.d/contrib.list > /dev/null
|
||||
$SUDO sed 's/main/contrib/' < /etc/apt/sources.list | sudo tee /etc/apt/sources.list.d/contrib.list > /dev/null
|
||||
;;
|
||||
esac
|
||||
|
||||
|
@@ -1,15 +0,0 @@
|
||||
#!/bin/sh
|
||||
|
||||
set -eu
|
||||
|
||||
export VERSION=${VERSION:-0.0.0}
|
||||
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version=$VERSION\" \"-X=github.com/jmorganca/ollama/server.mode=release\"'"
|
||||
|
||||
docker buildx build \
|
||||
--push \
|
||||
--platform=linux/arm64,linux/amd64 \
|
||||
--build-arg=VERSION \
|
||||
--build-arg=GOFLAGS \
|
||||
-f Dockerfile \
|
||||
-t ollama/ollama -t ollama/ollama:$VERSION \
|
||||
.
|
@@ -91,7 +91,7 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
|
||||
}
|
||||
|
||||
s := SignatureData{
|
||||
Method: http.MethodGet,
|
||||
Method: "GET",
|
||||
Path: redirectURL.String(),
|
||||
Data: nil,
|
||||
}
|
||||
@@ -103,10 +103,9 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Authorization", sig)
|
||||
resp, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
|
||||
resp, err := makeRequest(ctx, "GET", redirectURL, headers, nil, nil)
|
||||
if err != nil {
|
||||
log.Printf("couldn't get token: %q", err)
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
|
@@ -15,7 +15,6 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
@@ -89,12 +88,17 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
|
||||
}
|
||||
|
||||
if len(b.Parts) == 0 {
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
|
||||
resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
||||
|
||||
var size = b.Total / numDownloadParts
|
||||
@@ -129,6 +133,7 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis
|
||||
|
||||
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
|
||||
defer blobDownloadManager.Delete(b.Digest)
|
||||
|
||||
ctx, b.CancelFunc = context.WithCancel(ctx)
|
||||
|
||||
file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644)
|
||||
@@ -149,26 +154,21 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
|
||||
|
||||
i := i
|
||||
g.Go(func() error {
|
||||
var err error
|
||||
for try := 0; try < maxRetries; try++ {
|
||||
w := io.NewOffsetWriter(file, part.StartsAt())
|
||||
err = b.downloadChunk(inner, requestURL, w, part, opts)
|
||||
err := b.downloadChunk(inner, requestURL, w, part, opts)
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
|
||||
// return immediately if the context is canceled or the device is out of space
|
||||
case errors.Is(err, context.Canceled):
|
||||
return err
|
||||
case err != nil:
|
||||
log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], i, try, err)
|
||||
continue
|
||||
default:
|
||||
if try > 0 {
|
||||
log.Printf("%s part %d completed after %d retries", b.Digest[7:19], i, try)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
|
||||
return errors.New("max retries exceeded")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -198,14 +198,14 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
|
||||
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
|
||||
headers := make(http.Header)
|
||||
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
|
||||
resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
n, err := io.Copy(w, io.TeeReader(resp.Body, b))
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
// rollback progress
|
||||
b.Completed.Add(-n)
|
||||
return err
|
||||
@@ -216,7 +216,7 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
|
||||
return err
|
||||
}
|
||||
|
||||
// return nil or context.Canceled or UnexpectedEOF (resumable)
|
||||
// return nil or context.Canceled
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -306,8 +306,6 @@ type downloadOpts struct {
|
||||
|
||||
const maxRetries = 3
|
||||
|
||||
var errMaxRetriesExceeded = errors.New("max retries exceeded")
|
||||
|
||||
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
|
||||
func downloadBlob(ctx context.Context, opts downloadOpts) error {
|
||||
fp, err := GetBlobsPath(opts.digest)
|
||||
|
382
server/images.go
382
server/images.go
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
@@ -25,6 +26,7 @@ import (
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/llm"
|
||||
"github.com/jmorganca/ollama/parser"
|
||||
"github.com/jmorganca/ollama/vector"
|
||||
"github.com/jmorganca/ollama/version"
|
||||
)
|
||||
|
||||
@@ -45,10 +47,12 @@ type Model struct {
|
||||
System string
|
||||
License []string
|
||||
Digest string
|
||||
ConfigDigest string
|
||||
Options map[string]interface{}
|
||||
Embeddings []vector.Embedding
|
||||
}
|
||||
|
||||
func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
|
||||
func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
|
||||
t := m.Template
|
||||
if request.Template != "" {
|
||||
t = request.Template
|
||||
@@ -60,12 +64,20 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
|
||||
}
|
||||
|
||||
var vars struct {
|
||||
First bool
|
||||
System string
|
||||
Prompt string
|
||||
Embed string
|
||||
|
||||
// deprecated: versions <= 0.0.7 used this to omit the system prompt
|
||||
Context []int
|
||||
}
|
||||
|
||||
vars.First = len(request.Context) == 0
|
||||
vars.System = m.System
|
||||
vars.Prompt = request.Prompt
|
||||
vars.Context = request.Context
|
||||
vars.Embed = embedding
|
||||
|
||||
if request.System != "" {
|
||||
vars.System = request.System
|
||||
@@ -125,7 +137,7 @@ func (m *ManifestV2) GetTotalSize() (total int64) {
|
||||
}
|
||||
|
||||
func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
|
||||
fp, err := mp.GetManifestPath()
|
||||
fp, err := mp.GetManifestPath(false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -159,11 +171,12 @@ func GetModel(name string) (*Model, error) {
|
||||
}
|
||||
|
||||
model := &Model{
|
||||
Name: mp.GetFullTagname(),
|
||||
ShortName: mp.GetShortTagname(),
|
||||
Digest: digest,
|
||||
Template: "{{ .Prompt }}",
|
||||
License: []string{},
|
||||
Name: mp.GetFullTagname(),
|
||||
ShortName: mp.GetShortTagname(),
|
||||
Digest: digest,
|
||||
ConfigDigest: manifest.Config.Digest,
|
||||
Template: "{{ .Prompt }}",
|
||||
License: []string{},
|
||||
}
|
||||
|
||||
for _, layer := range manifest.Layers {
|
||||
@@ -177,9 +190,15 @@ func GetModel(name string) (*Model, error) {
|
||||
model.ModelPath = filename
|
||||
model.OriginalModel = layer.From
|
||||
case "application/vnd.ollama.image.embed":
|
||||
// Deprecated in versions > 0.1.2
|
||||
// TODO: remove this warning in a future version
|
||||
log.Print("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open file: %s", filename)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "application/vnd.ollama.image.adapter":
|
||||
model.AdapterPaths = append(model.AdapterPaths, filename)
|
||||
case "application/vnd.ollama.image.template":
|
||||
@@ -246,7 +265,7 @@ func filenameWithPath(path, f string) (string, error) {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
|
||||
func CreateModel(ctx context.Context, workDir, name string, path string, fn func(resp api.ProgressResponse)) error {
|
||||
mp := ParseModelPath(name)
|
||||
|
||||
var manifest *ManifestV2
|
||||
@@ -291,11 +310,13 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
||||
var layers []*LayerReader
|
||||
params := make(map[string][]string)
|
||||
var sourceParams map[string]any
|
||||
embed := EmbeddingParams{fn: fn}
|
||||
for _, c := range commands {
|
||||
log.Printf("[%s] - %s\n", c.Name, c.Args)
|
||||
switch c.Name {
|
||||
case "model":
|
||||
fn(api.ProgressResponse{Status: "looking for model"})
|
||||
embed.model = c.Args
|
||||
|
||||
mp := ParseModelPath(c.Args)
|
||||
mf, _, err := GetManifest(mp)
|
||||
@@ -319,6 +340,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
embed.model = modelFile
|
||||
// create a model from this specified file
|
||||
fn(api.ProgressResponse{Status: "creating model layer"})
|
||||
file, err := os.Open(modelFile)
|
||||
@@ -395,10 +417,16 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newLayer.From = mp.GetShortTagname()
|
||||
newLayer.From = mp.GetNamespaceRepository()
|
||||
layers = append(layers, newLayer)
|
||||
}
|
||||
}
|
||||
case "embed":
|
||||
embedFilePath, err := filenameWithPath(path, c.Args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
embed.files = append(embed.files, embedFilePath)
|
||||
case "adapter":
|
||||
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
|
||||
|
||||
@@ -489,8 +517,18 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
||||
}
|
||||
l.MediaType = "application/vnd.ollama.image.params"
|
||||
layers = append(layers, l)
|
||||
|
||||
// apply these parameters to the embedding options, in case embeddings need to be generated using this model
|
||||
embed.opts = formattedParams
|
||||
}
|
||||
|
||||
// generate the embedding layers
|
||||
embeddingLayers, err := embeddingLayers(workDir, embed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
layers = append(layers, embeddingLayers...)
|
||||
|
||||
digests, err := getLayerDigests(layers)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -534,6 +572,146 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
||||
return nil
|
||||
}
|
||||
|
||||
type EmbeddingParams struct {
|
||||
model string
|
||||
opts map[string]interface{}
|
||||
files []string // paths to files to embed
|
||||
fn func(resp api.ProgressResponse)
|
||||
}
|
||||
|
||||
// embeddingLayers loads the associated LLM and generates the embeddings to be stored from an input file
|
||||
func embeddingLayers(workDir string, e EmbeddingParams) ([]*LayerReader, error) {
|
||||
layers := []*LayerReader{}
|
||||
if len(e.files) > 0 {
|
||||
// check if the model is a file path or a model name
|
||||
model, err := GetModel(e.model)
|
||||
if err != nil {
|
||||
if !strings.Contains(err.Error(), "couldn't open file") {
|
||||
return nil, fmt.Errorf("unexpected error opening model to generate embeddings: %v", err)
|
||||
}
|
||||
// the model may be a file path, create a model from this file
|
||||
model = &Model{ModelPath: e.model}
|
||||
}
|
||||
|
||||
if err := load(context.Background(), workDir, model, e.opts, defaultSessionDuration); err != nil {
|
||||
return nil, fmt.Errorf("load model to generate embeddings: %v", err)
|
||||
}
|
||||
|
||||
// this will be used to check if we already have embeddings for a file
|
||||
modelInfo, err := os.Stat(model.ModelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get model file info: %v", err)
|
||||
}
|
||||
|
||||
addedFiles := make(map[string]bool) // keep track of files that have already been added
|
||||
for _, filePattern := range e.files {
|
||||
matchingFiles, err := filepath.Glob(filePattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
|
||||
}
|
||||
|
||||
for _, filePath := range matchingFiles {
|
||||
if addedFiles[filePath] {
|
||||
continue
|
||||
}
|
||||
addedFiles[filePath] = true
|
||||
// check if we already have embeddings for this file path
|
||||
layerIdentifier := fmt.Sprintf("%s:%s:%s:%d", filePath, e.model, modelInfo.ModTime().Format("2006-01-02 15:04:05"), modelInfo.Size())
|
||||
digest, _ := GetSHA256Digest(strings.NewReader(layerIdentifier))
|
||||
existing, err := existingFileEmbeddings(digest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check existing embeddings for file %s: %v", filePath, err)
|
||||
}
|
||||
|
||||
// TODO: check file type
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not open embed file: %w", err)
|
||||
}
|
||||
scanner := bufio.NewScanner(f)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
data := []string{}
|
||||
for scanner.Scan() {
|
||||
data = append(data, scanner.Text())
|
||||
}
|
||||
f.Close()
|
||||
|
||||
// the digest of the file is set here so that the client knows a new operation is in progress
|
||||
fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
|
||||
|
||||
embeddings := []vector.Embedding{}
|
||||
for i, d := range data {
|
||||
if strings.TrimSpace(d) == "" {
|
||||
continue
|
||||
}
|
||||
e.fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("creating embeddings for file %s", filePath),
|
||||
Digest: fileDigest,
|
||||
Total: int64(len(data) - 1),
|
||||
Completed: int64(i),
|
||||
})
|
||||
if len(existing[d]) > 0 {
|
||||
// already have an embedding for this line
|
||||
embeddings = append(embeddings, vector.Embedding{Data: d, Vector: existing[d]})
|
||||
continue
|
||||
}
|
||||
embed, err := loaded.llm.Embedding(context.Background(), d)
|
||||
if err != nil {
|
||||
log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
|
||||
continue
|
||||
}
|
||||
embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
|
||||
}
|
||||
|
||||
b, err := json.Marshal(embeddings)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode embeddings: %w", err)
|
||||
}
|
||||
r := bytes.NewReader(b)
|
||||
|
||||
layer := &LayerReader{
|
||||
Layer: Layer{
|
||||
MediaType: "application/vnd.ollama.image.embed",
|
||||
Digest: digest,
|
||||
Size: r.Size(),
|
||||
},
|
||||
Reader: r,
|
||||
}
|
||||
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
}
|
||||
}
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
// existingFileEmbeddings checks if we already have embeddings for a file and loads them into a look-up map
|
||||
func existingFileEmbeddings(digest string) (map[string][]float64, error) {
|
||||
path, err := GetBlobsPath(digest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embeddings blobs path: %w", err)
|
||||
}
|
||||
existingFileEmbeddings := make(map[string][]float64)
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
// already have some embeddings for this file, load embeddings previously generated
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open existing embedding file: %s", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
existing := []vector.Embedding{}
|
||||
if err = json.NewDecoder(file).Decode(&existing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, e := range existing {
|
||||
existingFileEmbeddings[e.Data] = e.Vector
|
||||
}
|
||||
}
|
||||
return existingFileEmbeddings, nil
|
||||
}
|
||||
|
||||
func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
|
||||
return slices.DeleteFunc(layers, func(layer *LayerReader) bool {
|
||||
return layer.MediaType == mediaType
|
||||
@@ -549,7 +727,8 @@ func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force
|
||||
}
|
||||
|
||||
_, err = os.Stat(fp)
|
||||
if os.IsNotExist(err) || force {
|
||||
// note: embed layers are always written since their digest doesnt indicate anything about the contents
|
||||
if os.IsNotExist(err) || force || layer.MediaType == "application/vnd.ollama.image.embed" {
|
||||
fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})
|
||||
|
||||
out, err := os.Create(fp)
|
||||
@@ -589,13 +768,10 @@ func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
|
||||
return err
|
||||
}
|
||||
|
||||
fp, err := mp.GetManifestPath()
|
||||
fp, err := mp.GetManifestPath(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(fp, manifestJSON, 0o644)
|
||||
}
|
||||
|
||||
@@ -707,19 +883,16 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
|
||||
|
||||
func CopyModel(src, dest string) error {
|
||||
srcModelPath := ParseModelPath(src)
|
||||
srcPath, err := srcModelPath.GetManifestPath()
|
||||
srcPath, err := srcModelPath.GetManifestPath(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
destModelPath := ParseModelPath(dest)
|
||||
destPath, err := destModelPath.GetManifestPath()
|
||||
destPath, err := destModelPath.GetManifestPath(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// copy the file
|
||||
input, err := os.ReadFile(srcPath)
|
||||
@@ -882,7 +1055,7 @@ func DeleteModel(name string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
fp, err := mp.GetManifestPath()
|
||||
fp, err := mp.GetManifestPath(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -896,27 +1069,51 @@ func DeleteModel(name string) error {
|
||||
}
|
||||
|
||||
func ShowModelfile(model *Model) (string, error) {
|
||||
var mt struct {
|
||||
type modelTemplate struct {
|
||||
*Model
|
||||
From string
|
||||
Parameters map[string][]any
|
||||
From string
|
||||
Params string
|
||||
}
|
||||
|
||||
mt.Parameters = make(map[string][]any)
|
||||
var params []string
|
||||
for k, v := range model.Options {
|
||||
if s, ok := v.([]any); ok {
|
||||
mt.Parameters[k] = s
|
||||
continue
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
params = append(params, fmt.Sprintf("PARAMETER %s %s", k, val))
|
||||
case int:
|
||||
params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.Itoa(val)))
|
||||
case float64:
|
||||
params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.FormatFloat(val, 'f', 0, 64)))
|
||||
case bool:
|
||||
params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.FormatBool(val)))
|
||||
case []interface{}:
|
||||
for _, nv := range val {
|
||||
switch nval := nv.(type) {
|
||||
case string:
|
||||
params = append(params, fmt.Sprintf("PARAMETER %s %s", k, nval))
|
||||
case int:
|
||||
params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.Itoa(nval)))
|
||||
case float64:
|
||||
params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.FormatFloat(nval, 'f', 0, 64)))
|
||||
case bool:
|
||||
params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.FormatBool(nval)))
|
||||
default:
|
||||
log.Printf("unknown type: %s", reflect.TypeOf(nv).String())
|
||||
}
|
||||
}
|
||||
default:
|
||||
log.Printf("unknown type: %s", reflect.TypeOf(v).String())
|
||||
}
|
||||
|
||||
mt.Parameters[k] = []any{v}
|
||||
}
|
||||
|
||||
mt.Model = model
|
||||
mt.From = model.ModelPath
|
||||
mt := modelTemplate{
|
||||
Model: model,
|
||||
From: model.OriginalModel,
|
||||
Params: strings.Join(params, "\n"),
|
||||
}
|
||||
|
||||
if model.OriginalModel != "" {
|
||||
mt.From = model.OriginalModel
|
||||
if mt.From == "" {
|
||||
mt.From = model.ModelPath
|
||||
}
|
||||
|
||||
modelFile := `# Modelfile generated by "ollama show"
|
||||
@@ -925,20 +1122,12 @@ func ShowModelfile(model *Model) (string, error) {
|
||||
|
||||
FROM {{ .From }}
|
||||
TEMPLATE """{{ .Template }}"""
|
||||
|
||||
{{- if .System }}
|
||||
SYSTEM """{{ .System }}"""
|
||||
{{- end }}
|
||||
|
||||
{{- range $adapter := .AdapterPaths }}
|
||||
ADAPTER {{ $adapter }}
|
||||
{{- end }}
|
||||
|
||||
{{- range $k, $v := .Parameters }}
|
||||
{{- range $parameter := $v }}
|
||||
PARAMETER {{ $k }} {{ printf "%#v" $parameter }}
|
||||
{{- end }}
|
||||
{{- end }}`
|
||||
{{ .Params }}
|
||||
`
|
||||
for _, l := range mt.Model.AdapterPaths {
|
||||
modelFile += fmt.Sprintf("ADAPTER %s\n", l)
|
||||
}
|
||||
|
||||
tmpl, err := template.New("").Parse(modelFile)
|
||||
if err != nil {
|
||||
@@ -975,7 +1164,46 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
layers = append(layers, &manifest.Config)
|
||||
|
||||
for _, layer := range layers {
|
||||
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
|
||||
exists, err := checkBlobExistence(ctx, mp, layer.Digest, regOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if exists {
|
||||
fn(api.ProgressResponse{
|
||||
Status: "using existing layer",
|
||||
Digest: layer.Digest,
|
||||
Total: layer.Size,
|
||||
Completed: layer.Size,
|
||||
})
|
||||
log.Printf("Layer %s already exists", layer.Digest)
|
||||
continue
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{
|
||||
Status: "starting upload",
|
||||
Digest: layer.Digest,
|
||||
Total: layer.Size,
|
||||
})
|
||||
|
||||
location, chunkSize, err := startUpload(ctx, mp, layer, regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't start upload: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if strings.HasPrefix(filepath.Base(location.Path), "sha256:") {
|
||||
layer.Digest = filepath.Base(location.Path)
|
||||
fn(api.ProgressResponse{
|
||||
Status: "using existing layer",
|
||||
Digest: layer.Digest,
|
||||
Total: layer.Size,
|
||||
Completed: layer.Size,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if err := uploadBlob(ctx, location, layer, chunkSize, regOpts, fn); err != nil {
|
||||
log.Printf("error uploading blob: %v", err)
|
||||
return err
|
||||
}
|
||||
@@ -992,7 +1220,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
|
||||
resp, err := makeRequestWithRetry(ctx, "PUT", requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1082,13 +1310,10 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
return err
|
||||
}
|
||||
|
||||
fp, err := mp.GetManifestPath()
|
||||
fp, err := mp.GetManifestPath(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = os.WriteFile(fp, manifestJSON, 0o644)
|
||||
if err != nil {
|
||||
@@ -1114,12 +1339,22 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptio
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts)
|
||||
resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't get manifest: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, fmt.Errorf("model not found")
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var m *ManifestV2
|
||||
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
|
||||
return nil, err
|
||||
@@ -1163,7 +1398,24 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
|
||||
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
|
||||
}
|
||||
|
||||
// Function to check if a blob already exists in the Docker registry
|
||||
func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", digest)
|
||||
|
||||
resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't check for blob: %v", err)
|
||||
return false, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check for success: If the blob exists, the Docker registry will respond with a 200 OK
|
||||
return resp.StatusCode < http.StatusBadRequest, nil
|
||||
}
|
||||
|
||||
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
|
||||
var status string
|
||||
for try := 0; try < maxRetries; try++ {
|
||||
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
||||
if err != nil {
|
||||
@@ -1171,6 +1423,8 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
return nil, err
|
||||
}
|
||||
|
||||
status = resp.Status
|
||||
|
||||
switch {
|
||||
case resp.StatusCode == http.StatusUnauthorized:
|
||||
auth := resp.Header.Get("www-authenticate")
|
||||
@@ -1182,25 +1436,21 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
|
||||
regOpts.Token = token
|
||||
if body != nil {
|
||||
body.Seek(0, io.SeekStart)
|
||||
if _, err := body.Seek(0, io.SeekStart); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
continue
|
||||
case resp.StatusCode == http.StatusNotFound:
|
||||
return nil, os.ErrNotExist
|
||||
case resp.StatusCode >= http.StatusBadRequest:
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("%d: %s", resp.StatusCode, body)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
|
||||
default:
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errMaxRetriesExceeded
|
||||
return nil, fmt.Errorf("max retry exceeded: %v", status)
|
||||
}
|
||||
|
||||
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
|
||||
|
@@ -12,7 +12,7 @@ func TestModelPrompt(t *testing.T) {
|
||||
Template: "a{{ .Prompt }}b",
|
||||
Prompt: "<h1>",
|
||||
}
|
||||
s, err := m.Prompt(req)
|
||||
s, err := m.Prompt(req, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@@ -85,27 +85,20 @@ func (mp ModelPath) GetShortTagname() string {
|
||||
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
|
||||
}
|
||||
|
||||
// modelsDir returns the value of the OLLAMA_MODELS environment variable or the user's home directory if OLLAMA_MODELS is not set.
|
||||
// The models directory is where Ollama stores its model files and manifests.
|
||||
func modelsDir() (string, error) {
|
||||
if models, exists := os.LookupEnv("OLLAMA_MODELS"); exists {
|
||||
return models, nil
|
||||
}
|
||||
func (mp ModelPath) GetManifestPath(createDir bool) (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "models"), nil
|
||||
}
|
||||
|
||||
// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
|
||||
func (mp ModelPath) GetManifestPath() (string, error) {
|
||||
dir, err := modelsDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
path := filepath.Join(home, ".ollama", "models", "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
|
||||
if createDir {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return filepath.Join(dir, "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func (mp ModelPath) BaseURL() *url.URL {
|
||||
@@ -116,12 +109,12 @@ func (mp ModelPath) BaseURL() *url.URL {
|
||||
}
|
||||
|
||||
func GetManifestPath() (string, error) {
|
||||
dir, err := modelsDir()
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
path := filepath.Join(dir, "manifests")
|
||||
path := filepath.Join(home, ".ollama", "models", "manifests")
|
||||
if err := os.MkdirAll(path, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -130,7 +123,7 @@ func GetManifestPath() (string, error) {
|
||||
}
|
||||
|
||||
func GetBlobsPath(digest string) (string, error) {
|
||||
dir, err := modelsDir()
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -139,7 +132,7 @@ func GetBlobsPath(digest string) (string, error) {
|
||||
digest = strings.ReplaceAll(digest, ":", "-")
|
||||
}
|
||||
|
||||
path := filepath.Join(dir, "blobs", digest)
|
||||
path := filepath.Join(home, ".ollama", "models", "blobs", digest)
|
||||
dirPath := filepath.Dir(path)
|
||||
if digest == "" {
|
||||
dirPath = path
|
||||
|
292
server/routes.go
292
server/routes.go
@@ -23,10 +23,11 @@ import (
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gonum.org/v1/gonum/mat"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/llm"
|
||||
"github.com/jmorganca/ollama/version"
|
||||
"github.com/jmorganca/ollama/vector"
|
||||
)
|
||||
|
||||
var mode string = gin.DebugMode
|
||||
@@ -46,13 +47,14 @@ func init() {
|
||||
var loaded struct {
|
||||
mu sync.Mutex
|
||||
|
||||
runner llm.LLM
|
||||
llm llm.LLM
|
||||
Embeddings []vector.Embedding
|
||||
|
||||
expireAt time.Time
|
||||
expireTimer *time.Timer
|
||||
|
||||
*Model
|
||||
*api.Options
|
||||
digest string
|
||||
options api.Options
|
||||
}
|
||||
|
||||
var defaultSessionDuration = 5 * time.Minute
|
||||
@@ -70,51 +72,65 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
|
||||
}
|
||||
|
||||
// check if the loaded model is still running in a subprocess, in case something unexpected happened
|
||||
if loaded.runner != nil {
|
||||
if err := loaded.runner.Ping(ctx); err != nil {
|
||||
if loaded.llm != nil {
|
||||
if err := loaded.llm.Ping(ctx); err != nil {
|
||||
log.Print("loaded llm process not responding, closing now")
|
||||
// the subprocess is no longer running, so close it
|
||||
loaded.runner.Close()
|
||||
loaded.runner = nil
|
||||
loaded.Model = nil
|
||||
loaded.Options = nil
|
||||
loaded.llm.Close()
|
||||
loaded.llm = nil
|
||||
loaded.digest = ""
|
||||
}
|
||||
}
|
||||
|
||||
needLoad := loaded.runner == nil || // is there a model loaded?
|
||||
loaded.ModelPath != model.ModelPath || // has the base model changed?
|
||||
!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
|
||||
!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) // have the runner options changed?
|
||||
|
||||
if needLoad {
|
||||
if loaded.runner != nil {
|
||||
if model.Digest != loaded.digest || !reflect.DeepEqual(loaded.options, opts) {
|
||||
if loaded.llm != nil {
|
||||
log.Println("changing loaded model")
|
||||
loaded.runner.Close()
|
||||
loaded.runner = nil
|
||||
loaded.Model = nil
|
||||
loaded.Options = nil
|
||||
loaded.llm.Close()
|
||||
loaded.llm = nil
|
||||
loaded.digest = ""
|
||||
}
|
||||
|
||||
llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts)
|
||||
if err != nil {
|
||||
// some older models are not compatible with newer versions of llama.cpp
|
||||
// show a generalized compatibility error until there is a better way to
|
||||
// check for model compatibility
|
||||
if strings.Contains(err.Error(), "failed to load model") {
|
||||
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
|
||||
}
|
||||
if model.Embeddings != nil && len(model.Embeddings) > 0 {
|
||||
opts.EmbeddingOnly = true // this is requried to generate embeddings, completions will still work
|
||||
loaded.Embeddings = model.Embeddings
|
||||
}
|
||||
|
||||
llmModel, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
loaded.Model = model
|
||||
loaded.runner = llmRunner
|
||||
loaded.Options = &opts
|
||||
}
|
||||
// set cache values before modifying opts
|
||||
loaded.llm = llmModel
|
||||
loaded.digest = model.Digest
|
||||
loaded.options = opts
|
||||
|
||||
// update options for the loaded llm
|
||||
// TODO(mxyng): this isn't thread safe, but it should be fine for now
|
||||
loaded.runner.SetOptions(opts)
|
||||
if opts.NumKeep < 0 {
|
||||
promptWithSystem, err := model.Prompt(api.GenerateRequest{}, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
promptNoSystem, err := model.Prompt(api.GenerateRequest{Context: []int{0}}, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tokensWithSystem, err := llmModel.Encode(ctx, promptWithSystem)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tokensNoSystem, err := llmModel.Encode(ctx, promptNoSystem)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem)
|
||||
|
||||
llmModel.SetOptions(opts)
|
||||
}
|
||||
}
|
||||
|
||||
loaded.expireAt = time.Now().Add(sessionDuration)
|
||||
|
||||
@@ -127,13 +143,13 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
|
||||
return
|
||||
}
|
||||
|
||||
if loaded.runner != nil {
|
||||
loaded.runner.Close()
|
||||
if loaded.llm == nil {
|
||||
return
|
||||
}
|
||||
|
||||
loaded.runner = nil
|
||||
loaded.Model = nil
|
||||
loaded.Options = nil
|
||||
loaded.llm.Close()
|
||||
loaded.llm = nil
|
||||
loaded.digest = ""
|
||||
})
|
||||
}
|
||||
|
||||
@@ -148,26 +164,8 @@ func GenerateHandler(c *gin.Context) {
|
||||
checkpointStart := time.Now()
|
||||
|
||||
var req api.GenerateRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
case err != nil:
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// validate the request
|
||||
switch {
|
||||
case req.Model == "":
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
case len(req.Format) > 0 && req.Format != "json":
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
|
||||
return
|
||||
case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -197,24 +195,30 @@ func GenerateHandler(c *gin.Context) {
|
||||
|
||||
checkpointLoaded := time.Now()
|
||||
|
||||
prompt := req.Prompt
|
||||
if !req.Raw {
|
||||
prompt, err = model.Prompt(req)
|
||||
embedding := ""
|
||||
if model.Embeddings != nil && len(model.Embeddings) > 0 {
|
||||
promptEmbed, err := loaded.llm.Embedding(c.Request.Context(), req.Prompt)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
// TODO: set embed_top from specified parameters in modelfile
|
||||
embed_top := 3
|
||||
topK := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
|
||||
for _, e := range topK {
|
||||
embedding = fmt.Sprintf("%s %s", embedding, e.Embedding.Data)
|
||||
}
|
||||
}
|
||||
|
||||
prompt, err := model.Prompt(req, embedding)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
// an empty request loads the model
|
||||
if req.Prompt == "" && req.Template == "" && req.System == "" {
|
||||
ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}
|
||||
return
|
||||
}
|
||||
|
||||
fn := func(r api.GenerateResponse) {
|
||||
loaded.expireAt = time.Now().Add(sessionDuration)
|
||||
loaded.expireTimer.Reset(sessionDuration)
|
||||
@@ -226,16 +230,16 @@ func GenerateHandler(c *gin.Context) {
|
||||
r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
}
|
||||
|
||||
if req.Raw {
|
||||
// in raw mode the client must manage history on their own
|
||||
r.Context = nil
|
||||
}
|
||||
|
||||
ch <- r
|
||||
}
|
||||
|
||||
if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, req.Format, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
// an empty request loads the model
|
||||
if req.Prompt == "" && req.Template == "" && req.System == "" {
|
||||
ch <- api.GenerateResponse{Model: req.Model, Done: true}
|
||||
} else {
|
||||
if err := loaded.llm.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -264,18 +268,8 @@ func EmbeddingHandler(c *gin.Context) {
|
||||
defer loaded.mu.Unlock()
|
||||
|
||||
var req api.EmbeddingRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
case err != nil:
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -291,12 +285,12 @@ func EmbeddingHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !loaded.Options.EmbeddingOnly {
|
||||
if !loaded.options.EmbeddingOnly {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "embedding option must be set to true"})
|
||||
return
|
||||
}
|
||||
|
||||
embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt)
|
||||
embedding, err := loaded.llm.Embedding(c.Request.Context(), req.Prompt)
|
||||
if err != nil {
|
||||
log.Printf("embedding generation failed: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||
@@ -311,18 +305,8 @@ func EmbeddingHandler(c *gin.Context) {
|
||||
|
||||
func PullModelHandler(c *gin.Context) {
|
||||
var req api.PullRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
case err != nil:
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -355,18 +339,8 @@ func PullModelHandler(c *gin.Context) {
|
||||
|
||||
func PushModelHandler(c *gin.Context) {
|
||||
var req api.PushRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
case err != nil:
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -381,9 +355,7 @@ func PushModelHandler(c *gin.Context) {
|
||||
Insecure: req.Insecure,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
ctx := context.Background()
|
||||
if err := PushModel(ctx, req.Name, regOpts, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
@@ -399,20 +371,12 @@ func PushModelHandler(c *gin.Context) {
|
||||
|
||||
func CreateModelHandler(c *gin.Context) {
|
||||
var req api.CreateRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
case err != nil:
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" || req.Path == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name and path are required"})
|
||||
return
|
||||
}
|
||||
workDir := c.GetString("workDir")
|
||||
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
@@ -424,7 +388,7 @@ func CreateModelHandler(c *gin.Context) {
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
if err := CreateModel(ctx, req.Name, req.Path, fn); err != nil {
|
||||
if err := CreateModel(ctx, workDir, req.Name, req.Path, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
@@ -439,18 +403,8 @@ func CreateModelHandler(c *gin.Context) {
|
||||
|
||||
func DeleteModelHandler(c *gin.Context) {
|
||||
var req api.DeleteRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
case err != nil:
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -479,18 +433,8 @@ func DeleteModelHandler(c *gin.Context) {
|
||||
|
||||
func ShowModelHandler(c *gin.Context) {
|
||||
var req api.ShowRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
case err != nil:
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -559,7 +503,7 @@ func GetModelInfo(name string) (*api.ShowResponse, error) {
|
||||
}
|
||||
|
||||
func ListModelsHandler(c *gin.Context) {
|
||||
models := make([]api.ModelResponse, 0)
|
||||
var models []api.ModelResponse
|
||||
fp, err := GetManifestPath()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -600,18 +544,8 @@ func ListModelsHandler(c *gin.Context) {
|
||||
|
||||
func CopyModelHandler(c *gin.Context) {
|
||||
var req api.CopyRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
case err != nil:
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Source == "" || req.Destination == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "source add destination are required"})
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -632,22 +566,6 @@ var defaultAllowOrigins = []string{
|
||||
}
|
||||
|
||||
func Serve(ln net.Listener, allowOrigins []string) error {
|
||||
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
|
||||
// clean up unused layers and manifests
|
||||
if err := PruneLayers(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manifestsPath, err := GetManifestPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := PruneDirectory(manifestsPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
config := cors.DefaultConfig()
|
||||
config.AllowWildcard = true
|
||||
|
||||
@@ -693,7 +611,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
|
||||
r.Handle(method, "/api/tags", ListModelsHandler)
|
||||
}
|
||||
|
||||
log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
|
||||
log.Printf("Listening on %s", ln.Addr())
|
||||
s := &http.Server{
|
||||
Handler: r,
|
||||
}
|
||||
@@ -703,8 +621,8 @@ func Serve(ln net.Listener, allowOrigins []string) error {
|
||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-signals
|
||||
if loaded.runner != nil {
|
||||
loaded.runner.Close()
|
||||
if loaded.llm != nil {
|
||||
loaded.llm.Close()
|
||||
}
|
||||
os.RemoveAll(workDir)
|
||||
os.Exit(0)
|
||||
@@ -713,7 +631,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
|
||||
if runtime.GOOS == "linux" {
|
||||
// check compatibility to log warnings
|
||||
if _, err := llm.CheckVRAM(); err != nil {
|
||||
log.Printf("Warning: GPU support may not be enabled, check you have installed GPU drivers: %v", err)
|
||||
log.Printf("Warning: GPU support may not enabled, check you have installed install GPU drivers: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
431
server/upload.go
431
server/upload.go
@@ -2,369 +2,218 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/format"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
var blobUploadManager sync.Map
|
||||
|
||||
type blobUpload struct {
|
||||
*Layer
|
||||
|
||||
Total int64
|
||||
Completed atomic.Int64
|
||||
|
||||
Parts []blobUploadPart
|
||||
|
||||
nextURL chan *url.URL
|
||||
|
||||
context.CancelFunc
|
||||
|
||||
done bool
|
||||
err error
|
||||
references atomic.Int32
|
||||
}
|
||||
|
||||
const (
|
||||
numUploadParts = 64
|
||||
minUploadPartSize int64 = 95 * 1000 * 1000
|
||||
maxUploadPartSize int64 = 1000 * 1000 * 1000
|
||||
redirectChunkSize int64 = 1024 * 1024 * 1024
|
||||
regularChunkSize int64 = 95 * 1024 * 1024
|
||||
)
|
||||
|
||||
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
|
||||
p, err := GetBlobsPath(b.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if b.From != "" {
|
||||
func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, int64, error) {
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
|
||||
if layer.From != "" {
|
||||
values := requestURL.Query()
|
||||
values.Add("mount", b.Digest)
|
||||
values.Add("from", b.From)
|
||||
values.Add("mount", layer.Digest)
|
||||
values.Add("from", layer.From)
|
||||
requestURL.RawQuery = values.Encode()
|
||||
}
|
||||
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodPost, requestURL, nil, nil, opts)
|
||||
resp, err := makeRequestWithRetry(ctx, "POST", requestURL, nil, nil, regOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
log.Printf("couldn't start upload: %v", err)
|
||||
return nil, 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
location := resp.Header.Get("Docker-Upload-Location")
|
||||
chunkSize := redirectChunkSize
|
||||
if location == "" {
|
||||
location = resp.Header.Get("Location")
|
||||
chunkSize = regularChunkSize
|
||||
}
|
||||
|
||||
fi, err := os.Stat(p)
|
||||
locationURL, err := url.Parse(location)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
b.Total = fi.Size()
|
||||
|
||||
var size = b.Total / numUploadParts
|
||||
switch {
|
||||
case size < minUploadPartSize:
|
||||
size = minUploadPartSize
|
||||
case size > maxUploadPartSize:
|
||||
size = maxUploadPartSize
|
||||
}
|
||||
|
||||
var offset int64
|
||||
for offset < fi.Size() {
|
||||
if offset+size > fi.Size() {
|
||||
size = fi.Size() - offset
|
||||
}
|
||||
|
||||
// set part.N to the current number of parts
|
||||
b.Parts = append(b.Parts, blobUploadPart{blobUpload: b, N: len(b.Parts), Offset: offset, Size: size})
|
||||
offset += size
|
||||
}
|
||||
|
||||
log.Printf("uploading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size))
|
||||
|
||||
requestURL, err = url.Parse(location)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.nextURL = make(chan *url.URL, 1)
|
||||
b.nextURL <- requestURL
|
||||
return nil
|
||||
return locationURL, chunkSize, nil
|
||||
}
|
||||
|
||||
// Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded
|
||||
// in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error.
|
||||
func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
|
||||
defer blobUploadManager.Delete(b.Digest)
|
||||
ctx, b.CancelFunc = context.WithCancel(ctx)
|
||||
func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSize int64, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
// TODO allow resumability
|
||||
// TODO allow canceling uploads via DELETE
|
||||
|
||||
p, err := GetBlobsPath(b.Digest)
|
||||
fp, err := GetBlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
b.err = err
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
f, err := os.Open(p)
|
||||
f, err := os.Open(fp)
|
||||
if err != nil {
|
||||
b.err = err
|
||||
return
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
g, inner := errgroup.WithContext(ctx)
|
||||
g.SetLimit(numUploadParts)
|
||||
for i := range b.Parts {
|
||||
part := &b.Parts[i]
|
||||
select {
|
||||
case <-inner.Done():
|
||||
case requestURL := <-b.nextURL:
|
||||
g.Go(func() error {
|
||||
var err error
|
||||
for try := 0; try < maxRetries; try++ {
|
||||
part.ReadSeeker = io.NewSectionReader(f, part.Offset, part.Size)
|
||||
err = b.uploadChunk(inner, http.MethodPatch, requestURL, part, opts)
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled):
|
||||
return err
|
||||
case errors.Is(err, errMaxRetriesExceeded):
|
||||
return err
|
||||
case err != nil:
|
||||
log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], part.N, try, err)
|
||||
continue
|
||||
}
|
||||
pw := ProgressWriter{
|
||||
status: fmt.Sprintf("uploading %s", layer.Digest),
|
||||
digest: layer.Digest,
|
||||
total: layer.Size,
|
||||
fn: fn,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
for offset := int64(0); offset < layer.Size; {
|
||||
chunk := layer.Size - offset
|
||||
if chunk > chunkSize {
|
||||
chunk = chunkSize
|
||||
}
|
||||
|
||||
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
|
||||
resp, err := uploadBlobChunk(ctx, http.MethodPatch, requestURL, f, offset, chunk, regOpts, &pw)
|
||||
if err != nil {
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("error uploading chunk: %v", err),
|
||||
Digest: layer.Digest,
|
||||
Total: layer.Size,
|
||||
Completed: offset,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
offset += chunk
|
||||
location := resp.Header.Get("Docker-Upload-Location")
|
||||
if location == "" {
|
||||
location = resp.Header.Get("Location")
|
||||
}
|
||||
|
||||
requestURL, err = url.Parse(location)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
b.err = err
|
||||
return
|
||||
}
|
||||
|
||||
requestURL := <-b.nextURL
|
||||
|
||||
var sb strings.Builder
|
||||
for _, part := range b.Parts {
|
||||
sb.Write(part.Sum(nil))
|
||||
}
|
||||
|
||||
md5sum := md5.Sum([]byte(sb.String()))
|
||||
|
||||
values := requestURL.Query()
|
||||
values.Add("digest", b.Digest)
|
||||
values.Add("etag", fmt.Sprintf("%x-%d", md5sum, len(b.Parts)))
|
||||
values.Add("digest", layer.Digest)
|
||||
requestURL.RawQuery = values.Encode()
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "application/octet-stream")
|
||||
headers.Set("Content-Length", "0")
|
||||
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
|
||||
if err != nil {
|
||||
b.err = err
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
b.done = true
|
||||
}
|
||||
|
||||
func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error {
|
||||
part.Reset()
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "application/octet-stream")
|
||||
headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
|
||||
headers.Set("X-Redirect-Uploads", "1")
|
||||
|
||||
if method == http.MethodPatch {
|
||||
headers.Set("Content-Range", fmt.Sprintf("%d-%d", part.Offset, part.Offset+part.Size-1))
|
||||
}
|
||||
|
||||
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(part.ReadSeeker, io.MultiWriter(part, part.Hash)), opts)
|
||||
// finish the upload
|
||||
resp, err := makeRequest(ctx, "PUT", requestURL, headers, nil, regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't finish upload: %v", err)
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
location := resp.Header.Get("Docker-Upload-Location")
|
||||
if location == "" {
|
||||
location = resp.Header.Get("Location")
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
nextURL, err := url.Parse(location)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case resp.StatusCode == http.StatusTemporaryRedirect:
|
||||
b.nextURL <- nextURL
|
||||
|
||||
redirectURL, err := resp.Location()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for try := 0; try < maxRetries; try++ {
|
||||
err = b.uploadChunk(ctx, http.MethodPut, redirectURL, part, nil)
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled):
|
||||
return err
|
||||
case errors.Is(err, errMaxRetriesExceeded):
|
||||
return err
|
||||
case err != nil:
|
||||
log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], part.N, try, err)
|
||||
continue
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
|
||||
|
||||
case resp.StatusCode == http.StatusUnauthorized:
|
||||
auth := resp.Header.Get("www-authenticate")
|
||||
authRedir := ParseAuthRedirectString(auth)
|
||||
token, err := getAuthToken(ctx, authRedir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
opts.Token = token
|
||||
fallthrough
|
||||
case resp.StatusCode >= http.StatusBadRequest:
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return fmt.Errorf("http status %d %s: %s", resp.StatusCode, resp.Status, body)
|
||||
}
|
||||
|
||||
if method == http.MethodPatch {
|
||||
b.nextURL <- nextURL
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *blobUpload) acquire() {
|
||||
b.references.Add(1)
|
||||
}
|
||||
func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r io.ReaderAt, offset, limit int64, opts *RegistryOptions, pw *ProgressWriter) (*http.Response, error) {
|
||||
sectionReader := io.NewSectionReader(r, offset, limit)
|
||||
|
||||
func (b *blobUpload) release() {
|
||||
if b.references.Add(-1) == 0 {
|
||||
b.CancelFunc()
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "application/octet-stream")
|
||||
headers.Set("Content-Length", strconv.Itoa(int(limit)))
|
||||
headers.Set("X-Redirect-Uploads", "1")
|
||||
|
||||
if method == http.MethodPatch {
|
||||
headers.Set("Content-Range", fmt.Sprintf("%d-%d", offset, offset+sectionReader.Size()-1))
|
||||
}
|
||||
}
|
||||
|
||||
func (b *blobUpload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
|
||||
b.acquire()
|
||||
defer b.release()
|
||||
for try := 0; try < maxRetries; try++ {
|
||||
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sectionReader, pw), opts)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
ticker := time.NewTicker(60 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
switch {
|
||||
case resp.StatusCode == http.StatusTemporaryRedirect:
|
||||
location, err := resp.Location()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pw.completed = offset
|
||||
if _, err := uploadBlobChunk(ctx, http.MethodPut, location, r, offset, limit, nil, pw); err != nil {
|
||||
// retry
|
||||
log.Printf("retrying redirected upload: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
case resp.StatusCode == http.StatusUnauthorized:
|
||||
auth := resp.Header.Get("www-authenticate")
|
||||
authRedir := ParseAuthRedirectString(auth)
|
||||
token, err := getAuthToken(ctx, authRedir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts.Token = token
|
||||
|
||||
pw.completed = offset
|
||||
sectionReader = io.NewSectionReader(r, offset, limit)
|
||||
continue
|
||||
case resp.StatusCode >= http.StatusBadRequest:
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("uploading %s", b.Digest),
|
||||
Digest: b.Digest,
|
||||
Total: b.Total,
|
||||
Completed: b.Completed.Load(),
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("max retries exceeded")
|
||||
}
|
||||
|
||||
type ProgressWriter struct {
|
||||
status string
|
||||
digest string
|
||||
bucket int64
|
||||
completed int64
|
||||
total int64
|
||||
fn func(api.ProgressResponse)
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (pw *ProgressWriter) Write(b []byte) (int, error) {
|
||||
pw.mu.Lock()
|
||||
defer pw.mu.Unlock()
|
||||
|
||||
n := len(b)
|
||||
pw.bucket += int64(n)
|
||||
|
||||
// throttle status updates to not spam the client
|
||||
if pw.bucket >= 1024*1024 || pw.completed+pw.bucket >= pw.total {
|
||||
pw.completed += pw.bucket
|
||||
pw.fn(api.ProgressResponse{
|
||||
Status: pw.status,
|
||||
Digest: pw.digest,
|
||||
Total: pw.total,
|
||||
Completed: pw.completed,
|
||||
})
|
||||
|
||||
if b.done || b.err != nil {
|
||||
return b.err
|
||||
}
|
||||
pw.bucket = 0
|
||||
}
|
||||
}
|
||||
|
||||
type blobUploadPart struct {
|
||||
// N is the part number
|
||||
N int
|
||||
Offset int64
|
||||
Size int64
|
||||
hash.Hash
|
||||
|
||||
written int64
|
||||
|
||||
io.ReadSeeker
|
||||
*blobUpload
|
||||
}
|
||||
|
||||
func (p *blobUploadPart) Write(b []byte) (n int, err error) {
|
||||
n = len(b)
|
||||
p.written += int64(n)
|
||||
p.Completed.Add(int64(n))
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (p *blobUploadPart) Reset() {
|
||||
p.Seek(0, io.SeekStart)
|
||||
p.Completed.Add(-int64(p.written))
|
||||
p.written = 0
|
||||
p.Hash = md5.New()
|
||||
}
|
||||
|
||||
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
|
||||
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
case err != nil:
|
||||
return err
|
||||
default:
|
||||
defer resp.Body.Close()
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("uploading %s", layer.Digest),
|
||||
Digest: layer.Digest,
|
||||
Total: layer.Size,
|
||||
Completed: layer.Size,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
|
||||
upload := data.(*blobUpload)
|
||||
if !ok {
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
|
||||
if err := upload.Prepare(ctx, requestURL, opts); err != nil {
|
||||
blobUploadManager.Delete(layer.Digest)
|
||||
return err
|
||||
}
|
||||
|
||||
go upload.Run(context.Background(), opts)
|
||||
}
|
||||
|
||||
return upload.Wait(ctx, fn)
|
||||
}
|
||||
|
69
vector/store.go
Normal file
69
vector/store.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package vector
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"sort"
|
||||
|
||||
"gonum.org/v1/gonum/mat"
|
||||
)
|
||||
|
||||
type Embedding struct {
|
||||
Vector []float64 // the embedding vector
|
||||
Data string // the data represted by the embedding
|
||||
}
|
||||
|
||||
type EmbeddingSimilarity struct {
|
||||
Embedding Embedding // the embedding that was used to calculate the similarity
|
||||
Similarity float64 // the similarity between the embedding and the query
|
||||
}
|
||||
|
||||
type Heap []EmbeddingSimilarity
|
||||
|
||||
func (h Heap) Len() int { return len(h) }
|
||||
func (h Heap) Less(i, j int) bool { return h[i].Similarity < h[j].Similarity }
|
||||
func (h Heap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
||||
func (h *Heap) Push(e any) {
|
||||
*h = append(*h, e.(EmbeddingSimilarity))
|
||||
}
|
||||
|
||||
func (h *Heap) Pop() interface{} {
|
||||
old := *h
|
||||
n := len(old)
|
||||
x := old[n-1]
|
||||
*h = old[0 : n-1]
|
||||
return x
|
||||
}
|
||||
|
||||
// cosineSimilarity is a measure that calculates the cosine of the angle between two vectors.
|
||||
// This value will range from -1 to 1, where 1 means the vectors are identical.
|
||||
func cosineSimilarity(vec1, vec2 *mat.VecDense) float64 {
|
||||
dotProduct := mat.Dot(vec1, vec2)
|
||||
norms := mat.Norm(vec1, 2) * mat.Norm(vec2, 2)
|
||||
|
||||
if norms == 0 {
|
||||
return 0
|
||||
}
|
||||
return dotProduct / norms
|
||||
}
|
||||
|
||||
func TopK(k int, query *mat.VecDense, embeddings []Embedding) []EmbeddingSimilarity {
|
||||
h := &Heap{}
|
||||
heap.Init(h)
|
||||
for _, emb := range embeddings {
|
||||
similarity := cosineSimilarity(query, mat.NewVecDense(len(emb.Vector), emb.Vector))
|
||||
heap.Push(h, EmbeddingSimilarity{Embedding: emb, Similarity: similarity})
|
||||
if h.Len() > k {
|
||||
heap.Pop(h)
|
||||
}
|
||||
}
|
||||
|
||||
topK := make([]EmbeddingSimilarity, 0, h.Len())
|
||||
for h.Len() > 0 {
|
||||
topK = append(topK, heap.Pop(h).(EmbeddingSimilarity))
|
||||
}
|
||||
sort.Slice(topK, func(i, j int) bool {
|
||||
return topK[i].Similarity > topK[j].Similarity
|
||||
})
|
||||
|
||||
return topK
|
||||
}
|
Reference in New Issue
Block a user