Compare commits: mattw/howt...v0.1.8

170 commits
Commit SHA1s:

e21579a0f1, c44b619428, 17678b7225, 6109bebba6, 8ae8c9fa8c, f39daff461, c50b01bc21, b9dc875401,
06589a3b30, 1fd511e661, c01bbe94fd, 1beb5645a9, 6db3691b8f, fe5a872444, d39709260f, 60bb3c03a1,
2e53704685, 527f9a7975, c4cc738cbf, 2c6189f4fe, c05ab9a86e, f42f3d9b27, 341fb7e35f, ec3614812a,
f14969314a, 1fb9288661, 01a03caa20, bf6786bb39, 642128b75a, f21bd6210d, ad88799411, 0818b5e318,
1df6100c77, 5c48fe1fb0, 874bb31986, f7856a57eb, f9a4281124, 8d03bd7b54, 9ec16f0f03, 57a58db1b0,
2d75a4537c, 4748609611, c0dcea1398, 115fc56eb7, 186f685224, 12efcbb057, 4e09aab8b9, 3a1ed9ff70,
6d283882b1, 5c3491f425, e5d1ce4dde, 2665f3c28e, a79f030e75, 9bc5864a03, b88cc0fac9, 5b2cf16397,
910816a532, 28c3f288e2, deeac961bb, 49443e7da5, bb8464c0d2, daa5bb4473, 92119de9d8, 53b0ba8d43,
db342691f9, cecf83141e, a5a2adf1ec, b0c9cd0f3b, 77f61c6301, f3604534e5, 914428351a, 9afea9e3b9,
c039432b5c, c345b4ca7c, 0c7a00a264, 36c160f1c3, b66bcaa582, c9167494cb, 125d0a013a, ba2da6ceaa,
ccff9ca09c, 436a5be49c, cc0bf96398, 386169205c, 0d6342a882, 75bee074b6, 533d76368c, 459f4a7889,
25c63c91d8, cbfff4f868, 7ed5a39bc7, cc1d03f4ec, 846f593dbf, 0a53da03fd, 2ce1793a1d, e1c5be24e7,
2ad8a074ac, 7e547c6833, 689842b9ff, a19d47642e, a7dad24d92, 6b213216d5, fe6f3b48f7, 36c88cb9db,
235e43d7f6, 730996e530, ce6197a8e0, 46b9953f32, 4dcceeffb7, 019e4a4558, 627d04d927, 940e8ebec3,
565648f3f7, 90c49bed57, 3a2477174f, 8c6c2cbc8c, 5dc0cff459, c5c8b4b16a, 8299bf76ed, ee4979e510,
08b0e04f40, b36b0b71f8, 094df37563, f3648fd206, bd93a94abd, f55bdb6f10, 2870a9bfc8, c031c211d1,
68391b0055, b7e137323a, 8fa3f366ad, fddb303f23, ad5ee20c7b, 785b4eb5bf, 16ede1b30b, 17d6bbbb2a,
6481b7f34c, cb4a80b693, 68d7255bd3, 9ef2fce33a, 43eaba3d60, 1af493c5a0, a0c3e989de, 7af0fdce48,
ee94693b1a, 731dbdc1a5, 06bcfbd629, 7d7c2510f8, f9b2f999ac, c416087339, 6002cebd2c, 212bdc541c,
dca6686273, 598621afab, 6479f49c09, 832b4db9d4, c43873f33b, 11d82d7b9b, 36fe2deebf, 4a8931f634,
bd6e38fb1a, 92189a5855, d790bf9916, 35afac099a, 811c3d1900, 3553d10769, 6fe178134d, d890890f66,
89ba19feca, 6f58c77671
Dockerfile

````diff
@@ -5,8 +5,8 @@ ARG GOFLAGS="'-ldflags=-w -s'"
 WORKDIR /go/src/github.com/jmorganca/ollama
 RUN apt-get update && apt-get install -y git build-essential cmake
-ADD https://dl.google.com/go/go1.21.1.linux-$TARGETARCH.tar.gz /tmp/go1.21.1.tar.gz
-RUN mkdir -p /usr/local && tar xz -C /usr/local </tmp/go1.21.1.tar.gz
+ADD https://dl.google.com/go/go1.21.3.linux-$TARGETARCH.tar.gz /tmp/go1.21.3.tar.gz
+RUN mkdir -p /usr/local && tar xz -C /usr/local </tmp/go1.21.3.tar.gz
 
 COPY . .
 ENV GOARCH=$TARGETARCH
````
Dockerfile.build

````diff
@@ -1,6 +1,5 @@
-
 # centos7 amd64 dependencies
-FROM --platform=linux/amd64 nvidia/cuda:11.8.0-devel-centos7 AS base-amd64
+FROM --platform=linux/amd64 nvidia/cuda:11.3.1-devel-centos7 AS base-amd64
 RUN yum install -y https://repo.ius.io/ius-release-el7.rpm centos-release-scl && \
 	yum update -y && \
 	yum install -y devtoolset-10-gcc devtoolset-10-gcc-c++ git236 wget
@@ -8,7 +7,7 @@ RUN wget "https://github.com/Kitware/CMake/releases/download/v3.27.6/cmake-3.27.
 ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
 
 # centos8 arm64 dependencies
-FROM --platform=linux/arm64 nvidia/cuda:11.4.3-devel-centos8 AS base-arm64
+FROM --platform=linux/arm64 nvidia/cuda-arm64:11.3.1-devel-centos8 AS base-arm64
 RUN sed -i -e 's/mirrorlist/#mirrorlist/g' -e 's|#baseurl=http://mirror.centos.org|baseurl=http://vault.centos.org|g' /etc/yum.repos.d/CentOS-*
 RUN yum install -y git cmake
 
@@ -17,8 +16,8 @@ ARG TARGETARCH
 ARG GOFLAGS="'-ldflags -w -s'"
 
 # install go
-ADD https://dl.google.com/go/go1.21.1.linux-$TARGETARCH.tar.gz /tmp/go1.21.1.tar.gz
-RUN mkdir -p /usr/local && tar xz -C /usr/local </tmp/go1.21.1.tar.gz
+ADD https://dl.google.com/go/go1.21.3.linux-$TARGETARCH.tar.gz /tmp/go1.21.3.tar.gz
+RUN mkdir -p /usr/local && tar xz -C /usr/local </tmp/go1.21.3.tar.gz
 
 # build the final binary
 WORKDIR /go/src/github.com/jmorganca/ollama
````
README.md (39 changes)

````diff
@@ -15,6 +15,10 @@ Get up and running with large language models locally.
 
 [Download](https://ollama.ai/download/Ollama-darwin.zip)
 
+### Windows
+
+Coming soon!
+
 ### Linux & WSL2
 
 ```
@@ -23,9 +27,10 @@ curl https://ollama.ai/install.sh | sh
 
 [Manual install instructions](https://github.com/jmorganca/ollama/blob/main/docs/linux.md)
 
-### Windows
+### Docker
 
-coming soon
+The official [Ollama Docker image `ollama/ollama`](https://hub.docker.com/r/ollama/ollama)
+is available on Docker Hub.
 
 ## Quickstart
 
@@ -56,11 +61,11 @@ Here are some example open-source models that can be downloaded:
 
 ## Customize your own model
 
-### Import from GGUF or GGML
+### Import from GGUF
 
-Ollama supports importing GGUF and GGML file formats in the Modelfile. This means if you have a model that is not in the Ollama library, you can create it, iterate on it, and upload it to the Ollama library to share with others when you are ready.
+Ollama supports importing GGUF models in the Modelfile:
 
-1. Create a file named Modelfile, and add a `FROM` instruction with the local filepath to the model you want to import.
+1. Create a file named `Modelfile`, with a `FROM` instruction with the local filepath to the model you want to import.
 
 ```
 FROM ./vicuna-33b.Q4_0.gguf
@@ -69,18 +74,22 @@ Ollama supports importing GGUF and GGML file formats in the Modelfile. This mean
 2. Create the model in Ollama
 
 ```
-ollama create name -f path_to_modelfile
+ollama create example -f Modelfile
 ```
 
 3. Run the model
 
 ```
-ollama run name
+ollama run example
 ```
 
+### Import from PyTorch or Safetensors
+
+See the [guide](docs/import.md) on importing models for more information.
+
 ### Customize a prompt
 
-Models from the Ollama library can be customized with a prompt. The example
+Models from the Ollama library can be customized with a prompt. For example, to customize the `llama2` model:
 
 ```
 ollama pull llama2
@@ -170,8 +179,7 @@ ollama list
 Install `cmake` and `go`:
 
 ```
-brew install cmake
-brew install go
+brew install cmake go
 ```
 
 Then generate dependencies and build:
@@ -195,9 +203,8 @@ Finally, in a separate shell, run a model:
 
 ## REST API
 
-> See the [API documentation](docs/api.md) for all endpoints.
-
-Ollama has an API for running and managing models. For example to generate text from a model:
+Ollama has a REST API for running and managing models.
+For example, to generate text from a model:
 
 ```
 curl -X POST http://localhost:11434/api/generate -d '{
@@ -206,6 +213,8 @@ curl -X POST http://localhost:11434/api/generate -d '{
 }'
 ```
 
+See the [API documentation](./docs/api.md) for all endpoints.
+
 ## Community Integrations
 
 - [LangChain](https://python.langchain.com/docs/integrations/llms/ollama) and [LangChain.js](https://js.langchain.com/docs/modules/model_io/models/llms/integrations/ollama) with [example](https://js.langchain.com/docs/use_cases/question_answering/local_retrieval_qa)
@@ -222,3 +231,7 @@ curl -X POST http://localhost:11434/api/generate -d '{
 - [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
 - [Dumbar](https://github.com/JerrySievert/Dumbar)
 - [Emacs client](https://github.com/zweifisch/ollama)
+- [oterm](https://github.com/ggozad/oterm)
+- [Ellama Emacs client](https://github.com/s-kostyaev/ellama)
+- [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp)
+- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
````
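To make the REST API usage the README now documents concrete, here is a minimal Go sketch that POSTs to the generate endpoint and prints the streamed JSON objects. The endpoint, port, and default streaming behavior come from the diff above; the model name and prompt are illustrative.

```go
package main

import (
	"bufio"
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	// Ask the local Ollama server to generate a completion.
	body := []byte(`{"model": "llama2", "prompt": "Why is the sky blue?"}`)
	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// By default the server responds with a stream of newline-delimited
	// JSON objects; print each one as it arrives.
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		fmt.Println(scanner.Text())
	}
}
```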
api/client.go

````diff
@@ -14,13 +14,10 @@ import (
 	"runtime"
 	"strings"
 
+	"github.com/jmorganca/ollama/format"
 	"github.com/jmorganca/ollama/version"
 )
 
-const DefaultHost = "127.0.0.1:11434"
-
-var envHost = os.Getenv("OLLAMA_HOST")
-
 type Client struct {
 	base *url.URL
 	http http.Client
@@ -43,16 +40,28 @@ func checkError(resp *http.Response, body []byte) error {
 }
 
 func ClientFromEnvironment() (*Client, error) {
+	defaultPort := "11434"
+
 	scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
-	if !ok {
+	switch {
+	case !ok:
 		scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
+	case scheme == "http":
+		defaultPort = "80"
+	case scheme == "https":
+		defaultPort = "443"
 	}
 
+	// trim trailing slashes
+	hostport = strings.TrimRight(hostport, "/")
+
 	host, port, err := net.SplitHostPort(hostport)
 	if err != nil {
-		host, port = "127.0.0.1", "11434"
-		if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
+		host, port = "127.0.0.1", defaultPort
+		if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
 			host = ip.String()
+		} else if hostport != "" {
+			host = hostport
 		}
 	}
 
@@ -63,7 +72,7 @@ func ClientFromEnvironment() (*Client, error) {
 		},
 	}
 
-	mockRequest, err := http.NewRequest("HEAD", client.base.String(), nil)
+	mockRequest, err := http.NewRequest(http.MethodHead, client.base.String(), nil)
 	if err != nil {
 		return nil, err
 	}
@@ -127,7 +136,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 	return nil
 }
 
-const maxBufferSize = 512 * 1000 // 512KB
+const maxBufferSize = 512 * format.KiloByte
 
 func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
 	var buf *bytes.Buffer
````
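The `ClientFromEnvironment` rework above changes how `OLLAMA_HOST` is interpreted: the default port now follows the scheme, and trailing slashes are trimmed. A minimal sketch of the resulting caller-side behavior; the import path and function come from the diff, while the program itself is illustrative (the mapping in the comments mirrors the test table in `api/client_test.go` below).

```go
package main

import (
	"fmt"
	"os"

	"github.com/jmorganca/ollama/api"
)

func main() {
	// With the change above, OLLAMA_HOST resolves as:
	//   "example.com"         -> http://example.com:11434
	//   "http://example.com"  -> http://example.com:80
	//   "https://example.com" -> https://example.com:443
	// and a trailing slash is ignored.
	os.Setenv("OLLAMA_HOST", "https://example.com/")

	client, err := api.ClientFromEnvironment()
	if err != nil {
		panic(err)
	}
	fmt.Printf("created %T\n", client) // base URL resolves to https://example.com:443
}
```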
api/client_test.go (new file, 43 lines)

````diff
@@ -0,0 +1,43 @@
+package api
+
+import "testing"
+
+func TestClientFromEnvironment(t *testing.T) {
+	type testCase struct {
+		value  string
+		expect string
+		err    error
+	}
+
+	testCases := map[string]*testCase{
+		"empty":                      {value: "", expect: "http://127.0.0.1:11434"},
+		"only address":               {value: "1.2.3.4", expect: "http://1.2.3.4:11434"},
+		"only port":                  {value: ":1234", expect: "http://:1234"},
+		"address and port":           {value: "1.2.3.4:1234", expect: "http://1.2.3.4:1234"},
+		"scheme http and address":    {value: "http://1.2.3.4", expect: "http://1.2.3.4:80"},
+		"scheme https and address":   {value: "https://1.2.3.4", expect: "https://1.2.3.4:443"},
+		"scheme, address, and port":  {value: "https://1.2.3.4:1234", expect: "https://1.2.3.4:1234"},
+		"hostname":                   {value: "example.com", expect: "http://example.com:11434"},
+		"hostname and port":          {value: "example.com:1234", expect: "http://example.com:1234"},
+		"scheme http and hostname":   {value: "http://example.com", expect: "http://example.com:80"},
+		"scheme https and hostname":  {value: "https://example.com", expect: "https://example.com:443"},
+		"scheme, hostname, and port": {value: "https://example.com:1234", expect: "https://example.com:1234"},
+		"trailing slash":             {value: "example.com/", expect: "http://example.com:11434"},
+		"trailing slash port":        {value: "example.com:1234/", expect: "http://example.com:1234"},
+	}
+
+	for k, v := range testCases {
+		t.Run(k, func(t *testing.T) {
+			t.Setenv("OLLAMA_HOST", v.value)
+
+			client, err := ClientFromEnvironment()
+			if err != v.err {
+				t.Fatalf("expected %s, got %s", v.err, err)
+			}
+
+			if client.base.String() != v.expect {
+				t.Fatalf("expected %s, got %s", v.expect, client.base.String())
+			}
+		})
+	}
+}
````
api/types.go (72 changes)

````diff
@@ -3,7 +3,6 @@ package api
 import (
 	"encoding/json"
 	"fmt"
-	"log"
 	"math"
 	"os"
 	"reflect"
@@ -162,15 +161,10 @@ func (r *GenerateResponse) Summary() {
 	}
 }
 
-type Options struct {
-	Seed int `json:"seed,omitempty"`
-
-	// Backend options
-	UseNUMA bool `json:"numa,omitempty"`
-
-	// Model options
+// Runner options which must be set when the model is loaded into memory
+type Runner struct {
+	UseNUMA   bool `json:"numa,omitempty"`
 	NumCtx    int  `json:"num_ctx,omitempty"`
-	NumKeep   int  `json:"num_keep,omitempty"`
 	NumBatch  int  `json:"num_batch,omitempty"`
 	NumGQA    int  `json:"num_gqa,omitempty"`
 	NumGPU    int  `json:"num_gpu,omitempty"`
@@ -184,8 +178,15 @@ type Options struct {
 	EmbeddingOnly      bool    `json:"embedding_only,omitempty"`
 	RopeFrequencyBase  float32 `json:"rope_frequency_base,omitempty"`
 	RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
+	NumThread          int     `json:"num_thread,omitempty"`
+}
 
-	// Predict options
+type Options struct {
+	Runner
+
+	// Predict options used at runtime
+	NumKeep     int     `json:"num_keep,omitempty"`
+	Seed        int     `json:"seed,omitempty"`
 	NumPredict  int     `json:"num_predict,omitempty"`
 	TopK        int     `json:"top_k,omitempty"`
 	TopP        float32 `json:"top_p,omitempty"`
@@ -201,8 +202,6 @@ type Options struct {
 	MirostatEta     float32  `json:"mirostat_eta,omitempty"`
 	PenalizeNewline bool     `json:"penalize_newline,omitempty"`
 	Stop            []string `json:"stop,omitempty"`
-
-	NumThread int `json:"num_thread,omitempty"`
 }
 
 var ErrInvalidOpts = fmt.Errorf("invalid options")
@@ -238,44 +237,39 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
 				// when JSON unmarshals numbers, it uses float64, not int
 				field.SetInt(int64(t))
 			default:
-				log.Printf("could not convert model parameter %v of type %T to int, skipped", key, val)
+				return fmt.Errorf("option %q must be of type integer", key)
 			}
 		case reflect.Bool:
 			val, ok := val.(bool)
 			if !ok {
-				log.Printf("could not convert model parameter %v of type %T to bool, skipped", key, val)
-				continue
+				return fmt.Errorf("option %q must be of type boolean", key)
 			}
 			field.SetBool(val)
 		case reflect.Float32:
 			// JSON unmarshals to float64
 			val, ok := val.(float64)
 			if !ok {
-				log.Printf("could not convert model parameter %v of type %T to float32, skipped", key, val)
-				continue
+				return fmt.Errorf("option %q must be of type float32", key)
 			}
 			field.SetFloat(val)
 		case reflect.String:
 			val, ok := val.(string)
 			if !ok {
-				log.Printf("could not convert model parameter %v of type %T to string, skipped", key, val)
-				continue
+				return fmt.Errorf("option %q must be of type string", key)
 			}
 			field.SetString(val)
 		case reflect.Slice:
 			// JSON unmarshals to []interface{}, not []string
 			val, ok := val.([]interface{})
 			if !ok {
-				log.Printf("could not convert model parameter %v of type %T to slice, skipped", key, val)
-				continue
+				return fmt.Errorf("option %q must be of type array", key)
 			}
 			// convert []interface{} to []string
 			slice := make([]string, len(val))
 			for i, item := range val {
 				str, ok := item.(string)
 				if !ok {
-					log.Printf("could not convert model parameter %v of type %T to slice of strings, skipped", key, item)
-					continue
+					return fmt.Errorf("option %q must be of an array of strings", key)
 				}
 				slice[i] = str
 			}
@@ -299,7 +293,7 @@ func DefaultOptions() Options {
 	return Options{
 		// options set on request to runner
 		NumPredict:  -1,
-		NumKeep:     -1,
+		NumKeep:     0,
 		Temperature: 0.8,
 		TopK:        40,
 		TopP:        0.9,
@@ -315,20 +309,22 @@
 		PenalizeNewline: true,
 		Seed:            -1,
 
-		// options set when the model is loaded
-		NumCtx:             2048,
-		RopeFrequencyBase:  10000.0,
-		RopeFrequencyScale: 1.0,
-		NumBatch:           512,
-		NumGPU:             -1, // -1 here indicates that NumGPU should be set dynamically
-		NumGQA:             1,
-		NumThread:          0, // let the runtime decide
-		LowVRAM:            false,
-		F16KV:              true,
-		UseMLock:           false,
-		UseMMap:            true,
-		UseNUMA:            false,
-		EmbeddingOnly:      true,
+		Runner: Runner{
+			// options set when the model is loaded
+			NumCtx:             2048,
+			RopeFrequencyBase:  10000.0,
+			RopeFrequencyScale: 1.0,
+			NumBatch:           512,
+			NumGPU:             -1, // -1 here indicates that NumGPU should be set dynamically
+			NumGQA:             1,
+			NumThread:          0, // let the runtime decide
+			LowVRAM:            false,
+			F16KV:              true,
+			UseMLock:           false,
+			UseMMap:            true,
+			UseNUMA:            false,
+			EmbeddingOnly:      true,
+		},
 	}
 }
````
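To make the `Options`/`Runner` split above concrete, a small sketch. `DefaultOptions`, the embedded `Runner`, and the `FromMap` error behavior come from the diff; that `FromMap` matches keys by JSON tag (e.g. `top_k`) is an assumption inferred from the surrounding reflection code.

```go
package main

import (
	"fmt"

	"github.com/jmorganca/ollama/api"
)

func main() {
	opts := api.DefaultOptions()

	// Runner fields are promoted through embedding, so existing
	// accesses like opts.NumCtx keep compiling after the split.
	opts.NumCtx = 4096

	// FromMap now fails fast instead of logging and skipping bad values.
	// Assumes keys are matched by JSON tag.
	if err := opts.FromMap(map[string]interface{}{"top_k": "forty"}); err != nil {
		fmt.Println(err) // option "top_k" must be of type integer
	}
}
```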
app/forge.config.ts

````diff
@@ -47,16 +47,6 @@ const config: ForgeConfig = {
   },
   rebuildConfig: {},
   makers: [new MakerSquirrel({}), new MakerZIP({}, ['darwin'])],
-  publishers: [
-    new PublisherGithub({
-      repository: {
-        name: 'ollama',
-        owner: 'jmorganca',
-      },
-      draft: false,
-      prerelease: true,
-    }),
-  ],
   hooks: {
     readPackageJson: async (_, packageJson) => {
       return { ...packageJson, version: process.env.VERSION || packageJson.version }
````
app/package-lock.json (generated, 992 changes): file diff suppressed because it is too large.

app/package.json

````diff
@@ -46,7 +46,7 @@
     "chmodr": "^1.2.0",
     "copy-webpack-plugin": "^11.0.0",
     "css-loader": "^6.8.1",
-    "electron": "25.2.0",
+    "electron": "25.9.2",
     "eslint": "^8.43.0",
     "eslint-plugin-import": "^2.27.5",
     "fork-ts-checker-webpack-plugin": "^7.3.0",
````
app/src/index.ts

````diff
@@ -162,13 +162,56 @@ app.on('before-quit', () => {
   }
 })
 
+const updateURL = `https://ollama.ai/api/update?os=${process.platform}&arch=${
+  process.arch
+}&version=${app.getVersion()}&id=${id()}`
+
+let latest = ''
+async function isNewReleaseAvailable() {
+  try {
+    const response = await fetch(updateURL)
+
+    if (!response.ok) {
+      return false
+    }
+
+    if (response.status === 204) {
+      return false
+    }
+
+    const data = await response.json()
+
+    const url = data?.url
+    if (!url) {
+      return false
+    }
+
+    if (latest === url) {
+      return false
+    }
+
+    latest = url
+
+    return true
+  } catch (error) {
+    logger.error(`update check failed - ${error}`)
+    return false
+  }
+}
+
+async function checkUpdate() {
+  const available = await isNewReleaseAvailable()
+  if (available) {
+    logger.info('checking for update')
+    autoUpdater.checkForUpdates()
+  }
+}
+
 function init() {
   if (app.isPackaged) {
-    autoUpdater.checkForUpdates()
+    checkUpdate()
     setInterval(() => {
-      if (!updateAvailable) {
-        autoUpdater.checkForUpdates()
-      }
+      checkUpdate()
     }, 60 * 60 * 1000)
   }
 
@@ -246,11 +289,7 @@ function id(): string {
   return uuid
 }
 
-autoUpdater.setFeedURL({
-  url: `https://ollama.ai/api/update?os=${process.platform}&arch=${
-    process.arch
-  }&version=${app.getVersion()}&id=${id()}`,
-})
+autoUpdater.setFeedURL({ url: updateURL })
 
 autoUpdater.on('error', e => {
   logger.error(`update check failed - ${e.message}`)
````
cmd/cmd.go (171 changes)

````diff
@@ -11,6 +11,7 @@ import (
 	"io"
 	"log"
 	"net"
+	"net/http"
 	"os"
 	"os/exec"
 	"os/signal"
@@ -22,7 +23,6 @@ import (
 
 	"github.com/dustin/go-humanize"
 	"github.com/olekukonko/tablewriter"
-	"github.com/pdevine/readline"
 	"github.com/spf13/cobra"
 	"golang.org/x/crypto/ssh"
 	"golang.org/x/term"
@@ -30,30 +30,11 @@ import (
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/format"
 	"github.com/jmorganca/ollama/progressbar"
+	"github.com/jmorganca/ollama/readline"
 	"github.com/jmorganca/ollama/server"
 	"github.com/jmorganca/ollama/version"
 )
 
-type Painter struct {
-	IsMultiLine bool
-}
-
-func (p Painter) Paint(line []rune, _ int) []rune {
-	termType := os.Getenv("TERM")
-	if termType == "xterm-256color" && len(line) == 0 {
-		var prompt string
-		if p.IsMultiLine {
-			prompt = "Use \"\"\" to end multi-line input"
-		} else {
-			prompt = "Send a message (/? for help)"
-		}
-		return []rune(fmt.Sprintf("\033[38;5;245m%s\033[%dD\033[0m", prompt, len(prompt)))
-	}
-	// add a space and a backspace to prevent the cursor from walking up the screen
-	line = append(line, []rune(" \b")...)
-	return line
-}
-
 func CreateHandler(cmd *cobra.Command, args []string) error {
 	filename, _ := cmd.Flags().GetString("file")
 	filename, err := filepath.Abs(filename)
@@ -78,18 +59,12 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 				spinner.Stop()
 			}
 			currentDigest = resp.Digest
-			switch {
-			case strings.Contains(resp.Status, "embeddings"):
-				bar = progressbar.Default(resp.Total, resp.Status)
-				bar.Set64(resp.Completed)
-			default:
-				// pulling
-				bar = progressbar.DefaultBytes(
-					resp.Total,
-					resp.Status,
-				)
-				bar.Set64(resp.Completed)
-			}
+			// pulling
+			bar = progressbar.DefaultBytes(
+				resp.Total,
+				resp.Status,
+			)
+			bar.Set64(resp.Completed)
 		} else if resp.Digest == currentDigest && resp.Digest != "" {
 			bar.Set64(resp.Completed)
 		} else {
@@ -124,19 +99,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	models, err := client.List(context.Background())
-	if err != nil {
-		return err
-	}
-
-	canonicalModelPath := server.ParseModelPath(args[0])
-	for _, model := range models.Models {
-		if model.Name == canonicalModelPath.GetShortTagname() {
-			return RunGenerate(cmd, args)
-		}
-	}
-
-	if err := PullHandler(cmd, args); err != nil {
-		return err
+	name := args[0]
+	// check if the model exists on the server
+	_, err = client.Show(context.Background(), &api.ShowRequest{Name: name})
+	var statusError api.StatusError
+	switch {
+	case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
+		if err := PullHandler(cmd, args); err != nil {
+			return err
+		}
+	case err != nil:
+		return err
 	}
 
@@ -514,38 +486,11 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool) error {
 }
 
 func generateInteractive(cmd *cobra.Command, model string) error {
-	home, err := os.UserHomeDir()
-	if err != nil {
-		return err
-	}
-
 	// load the model
 	if err := generate(cmd, model, "", false); err != nil {
 		return err
 	}
 
-	completer := readline.NewPrefixCompleter(
-		readline.PcItem("/help"),
-		readline.PcItem("/list"),
-		readline.PcItem("/set",
-			readline.PcItem("history"),
-			readline.PcItem("nohistory"),
-			readline.PcItem("wordwrap"),
-			readline.PcItem("nowordwrap"),
-			readline.PcItem("verbose"),
-			readline.PcItem("quiet"),
-		),
-		readline.PcItem("/show",
-			readline.PcItem("license"),
-			readline.PcItem("modelfile"),
-			readline.PcItem("parameters"),
-			readline.PcItem("system"),
-			readline.PcItem("template"),
-		),
-		readline.PcItem("/exit"),
-		readline.PcItem("/bye"),
-	)
-
 	usage := func() {
 		fmt.Fprintln(os.Stderr, "Available Commands:")
 		fmt.Fprintln(os.Stderr, "  /set      Set session variables")
@@ -578,20 +523,17 @@ func generateInteractive(cmd *cobra.Command, model string) error {
 		fmt.Fprintln(os.Stderr, "")
 	}
 
-	var painter Painter
-
-	config := readline.Config{
-		Painter:      &painter,
-		Prompt:       ">>> ",
-		HistoryFile:  filepath.Join(home, ".ollama", "history"),
-		AutoComplete: completer,
+	prompt := readline.Prompt{
+		Prompt:         ">>> ",
+		AltPrompt:      "... ",
+		Placeholder:    "Send a message (/? for help)",
+		AltPlaceholder: `Use """ to end multi-line input`,
 	}
 
-	scanner, err := readline.NewEx(&config)
+	scanner, err := readline.New(prompt)
 	if err != nil {
 		return err
 	}
-	defer scanner.Close()
 
 	var wordWrap bool
 	termType := os.Getenv("TERM")
@@ -608,17 +550,20 @@ func generateInteractive(cmd *cobra.Command, model string) error {
 		wordWrap = false
 	}
 
+	fmt.Print(readline.StartBracketedPaste)
+	defer fmt.Printf(readline.EndBracketedPaste)
+
 	var multiLineBuffer string
-	var isMultiLine bool
 
 	for {
 		line, err := scanner.Readline()
 		switch {
 		case errors.Is(err, io.EOF):
+			fmt.Println()
 			return nil
 		case errors.Is(err, readline.ErrInterrupt):
 			if line == "" {
-				fmt.Println("Use Ctrl-D or /bye to exit.")
+				fmt.Println("\nUse Ctrl-D or /bye to exit.")
 			}
 
 			continue
@@ -629,23 +574,19 @@ func generateInteractive(cmd *cobra.Command, model string) error {
 		line = strings.TrimSpace(line)
 
 		switch {
-		case isMultiLine:
+		case scanner.Prompt.UseAlt:
 			if strings.HasSuffix(line, `"""`) {
-				isMultiLine = false
-				painter.IsMultiLine = isMultiLine
+				scanner.Prompt.UseAlt = false
 				multiLineBuffer += strings.TrimSuffix(line, `"""`)
 				line = multiLineBuffer
 				multiLineBuffer = ""
-				scanner.SetPrompt(">>> ")
 			} else {
 				multiLineBuffer += line + " "
 				continue
 			}
 		case strings.HasPrefix(line, `"""`):
-			isMultiLine = true
-			painter.IsMultiLine = isMultiLine
+			scanner.Prompt.UseAlt = true
 			multiLineBuffer = strings.TrimPrefix(line, `"""`) + " "
-			scanner.SetPrompt("... ")
 			continue
 		case strings.HasPrefix(line, "/list"):
 			args := strings.Fields(line)
@@ -672,19 +613,6 @@ func generateInteractive(cmd *cobra.Command, model string) error {
 			case "quiet":
 				cmd.Flags().Set("verbose", "false")
 				fmt.Println("Set 'quiet' mode.")
-			case "mode":
-				if len(args) > 2 {
-					switch args[2] {
-					case "vim":
-						scanner.SetVimMode(true)
-					case "emacs", "default":
-						scanner.SetVimMode(false)
-					default:
-						usage()
-					}
-				} else {
-					usage()
-				}
 			default:
 				fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
 			}
@@ -694,7 +622,12 @@ func generateInteractive(cmd *cobra.Command, model string) error {
 		case strings.HasPrefix(line, "/show"):
 			args := strings.Fields(line)
 			if len(args) > 1 {
-				resp, err := server.GetModelInfo(model)
+				client, err := api.ClientFromEnvironment()
+				if err != nil {
+					fmt.Println("error: couldn't connect to ollama server")
+					return err
+				}
+				resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model})
 				if err != nil {
 					fmt.Println("error: couldn't get model")
 					return err
@@ -796,21 +729,6 @@ func RunServer(cmd *cobra.Command, _ []string) error {
 		origins = strings.Split(o, ",")
 	}
 
-	if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
-		if err := server.PruneLayers(); err != nil {
-			return err
-		}
-
-		manifestsPath, err := server.GetManifestPath()
-		if err != nil {
-			return err
-		}
-
-		if err := server.PruneDirectory(manifestsPath); err != nil {
-			return err
-		}
-	}
-
 	return server.Serve(ln, origins)
 }
 
@@ -933,7 +851,7 @@ func NewCLI() *cobra.Command {
 	createCmd := &cobra.Command{
 		Use:     "create MODEL",
 		Short:   "Create a model from a Modelfile",
-		Args:    cobra.MinimumNArgs(1),
+		Args:    cobra.ExactArgs(1),
 		PreRunE: checkServerHeartbeat,
 		RunE:    CreateHandler,
 	}
@@ -943,7 +861,7 @@ func NewCLI() *cobra.Command {
 	showCmd := &cobra.Command{
 		Use:     "show MODEL",
 		Short:   "Show information for a model",
-		Args:    cobra.MinimumNArgs(1),
+		Args:    cobra.ExactArgs(1),
 		PreRunE: checkServerHeartbeat,
 		RunE:    ShowHandler,
 	}
@@ -970,13 +888,14 @@ func NewCLI() *cobra.Command {
 		Use:     "serve",
 		Aliases: []string{"start"},
 		Short:   "Start ollama",
+		Args:    cobra.ExactArgs(0),
 		RunE:    RunServer,
 	}
 
 	pullCmd := &cobra.Command{
 		Use:     "pull MODEL",
 		Short:   "Pull a model from a registry",
-		Args:    cobra.MinimumNArgs(1),
+		Args:    cobra.ExactArgs(1),
 		PreRunE: checkServerHeartbeat,
 		RunE:    PullHandler,
 	}
@@ -986,7 +905,7 @@ func NewCLI() *cobra.Command {
 	pushCmd := &cobra.Command{
 		Use:     "push MODEL",
 		Short:   "Push a model to a registry",
-		Args:    cobra.MinimumNArgs(1),
+		Args:    cobra.ExactArgs(1),
 		PreRunE: checkServerHeartbeat,
 		RunE:    PushHandler,
 	}
@@ -1002,15 +921,15 @@ func NewCLI() *cobra.Command {
 	}
 
 	copyCmd := &cobra.Command{
-		Use:     "cp",
+		Use:     "cp SOURCE TARGET",
 		Short:   "Copy a model",
-		Args:    cobra.MinimumNArgs(2),
+		Args:    cobra.ExactArgs(2),
 		PreRunE: checkServerHeartbeat,
 		RunE:    CopyHandler,
 	}
 
 	deleteCmd := &cobra.Command{
-		Use:     "rm",
+		Use:     "rm MODEL [MODEL...]",
 		Short:   "Remove a model",
 		Args:    cobra.MinimumNArgs(1),
 		PreRunE: checkServerHeartbeat,
````
docs/api.md (154 changes)

````diff
@@ -45,9 +45,11 @@ Advanced parameters (optional):
 - `system`: system prompt to (overrides what is defined in the `Modelfile`)
 - `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
 - `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
-- `stream`: if `false` the response will be be returned as a single response object, rather than a stream of objects
+- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
 
-### Request
+### Examples
+
+#### Request
 
 ```shell
 curl -X POST http://localhost:11434/api/generate -d '{
@@ -56,9 +58,9 @@ curl -X POST http://localhost:11434/api/generate -d '{
 }'
 ```
 
-### Response
+#### Response
 
-A stream of JSON objects:
+A stream of JSON objects is returned:
 
 ```json
 {
@@ -102,6 +104,38 @@ To calculate how fast the response is generated in tokens per second (token/s),
 }
 ```
 
+#### Request
+
+```shell
+curl -X POST http://localhost:11434/api/generate -d '{
+  "model": "llama2:7b",
+  "prompt": "Why is the sky blue?",
+  "stream": false
+}'
+```
+
+#### Response
+
+If `stream` is set to `false`, the response will be a single JSON object:
+
+```json
+{
+  "model": "llama2:7b",
+  "created_at": "2023-08-04T19:22:45.499127Z",
+  "response": "The sky is blue because it is the color of the sky.",
+  "context": [1, 2, 3],
+  "done": true,
+  "total_duration": 5589157167,
+  "load_duration": 3013701500,
+  "sample_count": 114,
+  "sample_duration": 81442000,
+  "prompt_eval_count": 46,
+  "prompt_eval_duration": 1160282000,
+  "eval_count": 13,
+  "eval_duration": 1325948000
+}
+```
+
 ## Create a Model
 
 ```shell
@@ -114,9 +148,11 @@ Create a model from a [`Modelfile`](./modelfile.md)
 
 - `name`: name of the model to create
 - `path`: path to the Modelfile
-- `stream`: (optional) if `false` the response will be be returned as a single response object, rather than a stream of objects
+- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
 
-### Request
+### Examples
+
+#### Request
 
 ```shell
 curl -X POST http://localhost:11434/api/create -d '{
@@ -125,7 +161,7 @@ curl -X POST http://localhost:11434/api/create -d '{
 }'
 ```
 
-### Response
+#### Response
 
 A stream of JSON objects. When finished, `status` is `success`.
 
@@ -143,13 +179,17 @@ GET /api/tags
 
 List models that are available locally.
 
-### Request
+### Examples
+
+#### Request
 
 ```shell
 curl http://localhost:11434/api/tags
 ```
 
-### Response
+#### Response
+
+A single JSON object will be returned.
 
 ```json
 {
@@ -180,7 +220,9 @@ Show details about a model including modelfile, template, parameters, license, a
 
 - `name`: name of the model to show
 
-### Request
+### Examples
+
+#### Request
 
 ```shell
 curl http://localhost:11434/api/show -d '{
@@ -188,7 +230,7 @@ curl http://localhost:11434/api/show -d '{
 }'
 ```
 
-### Response
+#### Response
 
 ```json
 {
@@ -207,7 +249,9 @@ POST /api/copy
 
 Copy a model. Creates a model with another name from an existing model.
 
-### Request
+### Examples
+
+#### Request
 
 ```shell
 curl http://localhost:11434/api/copy -d '{
@@ -216,6 +260,10 @@ curl http://localhost:11434/api/copy -d '{
 }'
 ```
 
+#### Response
+
+The only response is a 200 OK if successful.
+
 ## Delete a Model
 
 ```shell
@@ -226,9 +274,11 @@ Delete a model and its data.
 
 ### Parameters
 
-- `model`: model name to delete
+- `name`: model name to delete
 
-### Request
+### Examples
+
+#### Request
 
 ```shell
 curl -X DELETE http://localhost:11434/api/delete -d '{
@@ -236,6 +286,10 @@ curl -X DELETE http://localhost:11434/api/delete -d '{
 }'
 ```
 
+#### Response
+
+If successful, the only response is a 200 OK.
+
 ## Pull a Model
 
 ```shell
@@ -248,9 +302,11 @@ Download a model from the ollama library. Cancelled pulls are resumed from where
 
 - `name`: name of the model to pull
 - `insecure`: (optional) allow insecure connections to the library. Only use this if you are pulling from your own library during development.
-- `stream`: (optional) if `false` the response will be be returned as a single response object, rather than a stream of objects
+- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
 
-### Request
+### Examples
+
+#### Request
 
 ```shell
 curl -X POST http://localhost:11434/api/pull -d '{
@@ -258,13 +314,51 @@
 }'
 ```
 
-### Response
+#### Response
 
+If `stream` is not specified, or set to `true`, a stream of JSON objects is returned:
+
+The first object is the manifest:
+
+```json
+{
+  "status": "pulling manifest"
+}
+```
+
+Then there is a series of downloading responses. Until any of the download is completed, the `completed` key may not be included. The number of files to be downloaded depends on the number of layers specified in the manifest.
+
 ```json
 {
   "status": "downloading digestname",
   "digest": "digestname",
-  "total": 2142590208
+  "total": 2142590208,
+  "completed": 241970
+}
+```
+
+After all the files are downloaded, the final responses are:
+
+```json
+{
+    "status": "verifying sha256 digest"
+}
+{
+    "status": "writing manifest"
+}
+{
+    "status": "removing any unused layers"
+}
+{
+    "status": "success"
+}
+```
+
+if `stream` is set to false, then the response is a single JSON object:
+
+```json
+{
+  "status": "success"
 }
 ```
 
@@ -280,9 +374,11 @@ Upload a model to a model library. Requires registering for ollama.ai and adding
 
 - `name`: name of the model to push in the form of `<namespace>/<model>:<tag>`
 - `insecure`: (optional) allow insecure connections to the library. Only use this if you are pushing to your library during development.
-- `stream`: (optional) if `false` the response will be be returned as a single response object, rather than a stream of objects
+- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
 
-### Request
+### Examples
+
+#### Request
 
 ```shell
 curl -X POST http://localhost:11434/api/push -d '{
@@ -290,9 +386,9 @@ curl -X POST http://localhost:11434/api/push -d '{
 }'
 ```
 
-### Response
+#### Response
 
-Streaming response that starts with:
+If `stream` is not specified, or set to `true`, a stream of JSON objects is returned:
 
 ```json
 { "status": "retrieving manifest" }
@@ -325,6 +421,12 @@ Finally, when the upload is complete:
 {"status":"success"}
 ```
 
+If `stream` is set to `false`, then the response is a single JSON object:
+
+```json
+{ "status": "success" }
+```
+
 ## Generate Embeddings
 
 ```shell
@@ -342,7 +444,9 @@ Advanced parameters:
 
 - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
 
-### Request
+### Examples
+
+#### Request
 
 ```shell
 curl -X POST http://localhost:11434/api/embeddings -d '{
@@ -351,11 +455,11 @@ curl -X POST http://localhost:11434/api/embeddings -d '{
 }'
 ```
 
-### Response
+#### Response
 
 ```json
 {
-  "embeddings": [
+  "embedding": [
     0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
     0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
   ]
````
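The non-streaming mode documented above pairs naturally with a single JSON decode. A sketch in Go; the field names follow the documented example response, and the struct is illustrative (it carries only a few of the documented fields).

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// A subset of the documented response object; the full response also
// carries timing and token-count fields.
type generateResponse struct {
	Model    string `json:"model"`
	Response string `json:"response"`
	Context  []int  `json:"context"`
	Done     bool   `json:"done"`
}

func main() {
	body := []byte(`{"model": "llama2:7b", "prompt": "Why is the sky blue?", "stream": false}`)
	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out generateResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Println(out.Response)
}
```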
docs/faq.md (69 changes)

````diff
@@ -1,18 +1,79 @@
 # FAQ
 
-## How can I expose the Ollama server?
+## How can I view the logs?
+
+On macOS:
+
+```
+cat ~/.ollama/logs/server.log
+```
+
+On Linux:
+
+```
+journalctl -u ollama
+```
+
+If you're running `ollama serve` directly, the logs will be printed to the console.
+
+## How can I expose Ollama on my network?
+
+Ollama binds to 127.0.0.1 port 11434 by default. Change the bind address with the `OLLAMA_HOST` environment variable.
+
+On macOS:
 
 ```bash
 OLLAMA_HOST=0.0.0.0:11435 ollama serve
 ```
 
-By default, Ollama allows cross origin requests from `127.0.0.1` and `0.0.0.0`. To support more origins, you can use the `OLLAMA_ORIGINS` environment variable:
+On Linux:
+
+Create a `systemd` drop-in directory and set `Environment=OLLAMA_HOST`
+
+```bash
+mkdir -p /etc/systemd/system/ollama.service.d
+echo "[Service]" >>/etc/systemd/system/ollama.service.d/environment.conf
+```
+
+```bash
+echo "Environment=OLLAMA_HOST=0.0.0.0:11434" >>/etc/systemd/system/ollama.service.d/environment.conf
+```
+
+Reload `systemd` and restart Ollama:
+
+```bash
+systemctl daemon-reload
+systemctl restart ollama
+```
+
+## How can I allow additional web origins to access Ollama?
+
+Ollama allows cross origin requests from `127.0.0.1` and `0.0.0.0` by default. Add additional origins with the `OLLAMA_ORIGINS` environment variable:
+
+On macOS:
 
 ```bash
 OLLAMA_ORIGINS=http://192.168.1.1:*,https://example.com ollama serve
 ```
 
+On Linux:
+
+```bash
+echo "Environment=OLLAMA_ORIGINS=http://129.168.1.1:*,https://example.com" >>/etc/systemd/system/ollama.service.d/environment.conf
+```
+
+Reload `systemd` and restart Ollama:
+
+```bash
+systemctl daemon-reload
+systemctl restart ollama
+```
+
 ## Where are models stored?
 
-* macOS: Raw model data is stored under `~/.ollama/models`.
-* Linux: Raw model data is stored under `/usr/share/ollama/.ollama/models`
+- macOS: Raw model data is stored under `~/.ollama/models`.
+- Linux: Raw model data is stored under `/usr/share/ollama/.ollama/models`
+
+### How can I change where Ollama stores models?
+
+To modify where models are stored, you can use the `OLLAMA_MODELS` environment variable. Note that on Linux this means defining `OLLAMA_MODELS` in a drop-in `/etc/systemd/system/ollama.service.d` service file, reloading systemd, and restarting the ollama service.
````
198
docs/import.md
Normal file
198
docs/import.md
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
# Import a model
|
||||||
|
|
||||||
|
This guide walks through importing a GGUF, PyTorch or Safetensors model.
|
||||||
|
|
||||||
|
## Importing (GGUF)
|
||||||
|
|
||||||
|
### Step 1: Write a `Modelfile`
|
||||||
|
|
||||||
|
Start by creating a `Modelfile`. This file is the blueprint for your model, specifying weights, parameters, prompt templates and more.
|
||||||
|
|
||||||
|
```
|
||||||
|
FROM ./mistral-7b-v0.1.Q4_0.gguf
|
||||||
|
```
|
||||||
|
|
||||||
|
(Optional) many chat models require a prompt template in order to answer correctly. A default prompt template can be specified with the `TEMPLATE` instruction in the `Modelfile`:
|
||||||
|
|
||||||
|
```
|
||||||
|
FROM ./q4_0.bin
|
||||||
|
TEMPLATE "[INST] {{ .Prompt }} [/INST]"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 2: Create the Ollama model
|
||||||
|
|
||||||
|
Finally, create a model from your `Modelfile`:
|
||||||
|
|
||||||
|
```
|
||||||
|
ollama create example -f Modelfile
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 3: Run your model
|
||||||
|
|
||||||
|
Next, test the model with `ollama run`:
|
||||||
|
|
||||||
|
```
|
||||||
|
ollama run example "What is your favourite condiment?"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Importing (PyTorch & Safetensors)
|
||||||
|
|
||||||
|
### Supported models
|
||||||
|
|
||||||
|
Ollama supports a set of model architectures, with support for more coming soon:
|
||||||
|
|
||||||
|
- Llama & Mistral
|
||||||
|
- Falcon & RW
|
||||||
|
- GPT-NeoX
|
||||||
|
- BigCode
|
||||||
|
|
||||||
|
To view a model's architecture, check the `config.json` file in its HuggingFace repo. You should see an entry under `architectures` (e.g. `LlamaForCausalLM`).
|
||||||
|
|
||||||
|
### Step 1: Clone the HuggingFace repository (optional)
|
||||||
|
|
||||||
|
If the model is currently hosted in a HuggingFace repository, first clone that repository to download the raw model.
|
||||||
|
|
||||||
|
```
|
||||||
|
git lfs install
|
||||||
|
git clone https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
|
||||||
|
cd Mistral-7B-Instruct-v0.1
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 2: Convert and quantize to a `.bin` file (optional, for PyTorch and Safetensors)
|
||||||
|
|
||||||
|
If the model is in PyTorch or Safetensors format, a [Docker image](https://hub.docker.com/r/ollama/quantize) with the tooling required to convert and quantize models is available.
|
||||||
|
|
||||||
|
First, Install [Docker](https://www.docker.com/get-started/).
|
||||||
|
|
||||||
|
Next, to convert and quantize your model, run:
|
||||||
|
|
||||||
|
```
|
||||||
|
docker run --rm -v .:/model ollama/quantize -q q4_0 /model
|
||||||
|
```
|
||||||
|
|
||||||
|
This will output two files into the directory:
|
||||||
|
|
||||||
|
- `f16.bin`: the model converted to GGUF
|
||||||
|
- `q4_0.bin` the model quantized to a 4-bit quantization (we will use this file to create the Ollama model)
|
||||||
|
|
||||||
|
### Step 3: Write a `Modelfile`
|
||||||
|
|
||||||
|
Next, create a `Modelfile` for your model:
|
||||||
|
|
||||||
|
```
|
||||||
|
FROM ./q4_0.bin
|
||||||
|
```
|
||||||
|
|
||||||
|
(Optional) many chat models require a prompt template in order to answer correctly. A default prompt template can be specified with the `TEMPLATE` instruction in the `Modelfile`:
|
||||||
|
|
||||||
|
```
|
||||||
|
FROM ./q4_0.bin
|
||||||
|
TEMPLATE "[INST] {{ .Prompt }} [/INST]"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 4: Create the Ollama model
|
||||||
|
|
||||||
|
Finally, create a model from your `Modelfile`:
|
||||||
|
|
||||||
|
```
|
||||||
|
ollama create example -f Modelfile
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 5: Run your model
|
||||||
|
|
||||||
|
Next, test the model with `ollama run`:
|
||||||
|
|
||||||
|
```
|
||||||
|
ollama run example "What is your favourite condiment?"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Publishing your model (optional – early alpha)
|
||||||
|
|
||||||
|
Publishing models is in early alpha. If you'd like to publish your model to share with others, follow these steps:
|
||||||
|
|
||||||
|
1. Create [an account](https://ollama.ai/signup)
|
||||||
|
2. Run `cat ~/.ollama/id_ed25519.pub` to view your Ollama public key. Copy this to the clipboard.
|
||||||
|
3. Add your public key to your [Ollama account](https://ollama.ai/settings/keys)
|
||||||
|
|
||||||
|
Next, copy your model to your username's namespace:
|
||||||
|
|
||||||
|
```
|
||||||
|
ollama cp example <your username>/example
|
||||||
|
```
|
||||||
|
|
||||||
|
Then push the model:
|
||||||
|
|
||||||
|
```
|
||||||
|
ollama push <your username>/example
|
||||||
|
```
|
||||||
|
|
||||||
|
After publishing, your model will be available at `https://ollama.ai/<your username>/example`.
|
||||||
|
|
||||||
|
## Quantization reference
|
||||||
|
|
||||||
|
The quantization options are as follow (from highest highest to lowest levels of quantization). Note: some architectures such as Falcon do not support K quants.
|
||||||
|
|
||||||
|
- `q2_K`
|
||||||
|
- `q3_K`
|
||||||
|
- `q3_K_S`
|
||||||
|
- `q3_K_M`
|
||||||
|
- `q3_K_L`
|
||||||
|
- `q4_0` (recommended)
|
||||||
|
- `q4_1`
|
||||||
|
- `q4_K`
|
||||||
|
- `q4_K_S`
|
||||||
|
- `q4_K_M`
|
||||||
|
- `q5_0`
|
||||||
|
- `q5_1`
|
||||||
|
- `q5_K`
|
||||||
|
- `q5_K_S`
|
||||||
|
- `q5_K_M`
|
||||||
|
- `q6_K`
|
||||||
|
- `q8_0`
|
||||||
|
|
||||||
|
## Manually converting & quantizing models
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
Start by cloning the `llama.cpp` repo to your machine in another directory:
|
||||||
|
|
||||||
|
```
|
||||||
|
git clone https://github.com/ggerganov/llama.cpp.git
|
||||||
|
cd llama.cpp
|
||||||
|
```
|
||||||
|
|
||||||
|
Next, install the Python dependencies:
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
Finally, build the `quantize` tool:
|
||||||
|
|
||||||
|
```
|
||||||
|
make quantize
|
||||||
|
```
|
||||||
|
|
||||||
|
### Convert the model
|
||||||
|
|
||||||
|
Run the correct conversion script for your model architecture:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# LlamaForCausalLM or MistralForCausalLM
|
||||||
|
python convert.py <path to model directory>
|
||||||
|
|
||||||
|
# FalconForCausalLM
|
||||||
|
python convert-falcon-hf-to-gguf.py <path to model directory>
|
||||||
|
|
||||||
|
# GPTNeoXForCausalLM
|
||||||
|
python convert-gptneox-hf-to-gguf.py <path to model directory>
|
||||||
|
|
||||||
|
# GPTBigCodeForCausalLM
|
||||||
|
python convert-starcoder-hf-to-gguf.py <path to model directory>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Quantize the model
|
||||||
|
|
||||||
|
```
|
||||||
|
quantize <path to model dir>/ggml-model-f32.bin <path to model dir>/q4_0.bin q4_0
|
||||||
|
```
|
@@ -1,12 +1,16 @@
|
|||||||
# Installing Ollama on Linux
|
# Ollama on Linux
|
||||||
|
|
||||||
> Note: A one line installer for Ollama is available by running:
|
## Install
|
||||||
|
|
||||||
|
Install Ollama running this one-liner:
|
||||||
>
|
>
|
||||||
> ```bash
|
```bash
|
||||||
> curl https://ollama.ai/install.sh | sh
|
curl https://ollama.ai/install.sh | sh
|
||||||
> ```
|
```
|
||||||
|
|
||||||
## Download the `ollama` binary
|
## Manual install
|
||||||
|
|
||||||
|
### Download the `ollama` binary
|
||||||
|
|
||||||
Ollama is distributed as a self-contained binary. Download it to a directory in your PATH:
|
Ollama is distributed as a self-contained binary. Download it to a directory in your PATH:
|
||||||
|
|
||||||
@@ -15,31 +19,7 @@ sudo curl -L https://ollama.ai/download/ollama-linux-amd64 -o /usr/bin/ollama
|
|||||||
sudo chmod +x /usr/bin/ollama
|
sudo chmod +x /usr/bin/ollama
|
||||||
```
|
```
|
||||||
|
|
||||||
## Start Ollama
|
### Adding Ollama as a startup service (recommended)
|
||||||
|
|
||||||
Start Ollama by running `ollama serve`:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
ollama serve
|
|
||||||
```
|
|
||||||
|
|
||||||
Once Ollama is running, run a model in another terminal session:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
ollama run llama2
|
|
||||||
```
|
|
||||||
|
|
||||||
## Install CUDA drivers (optional – for Nvidia GPUs)
|
|
||||||
|
|
||||||
[Download and install](https://developer.nvidia.com/cuda-downloads) CUDA.
|
|
||||||
|
|
||||||
Verify that the drivers are installed by running the following command, which should print details about your GPU:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
nvidia-smi
|
|
||||||
```
|
|
||||||
|
|
||||||
## Adding Ollama as a startup service (optional)
|
|
||||||
|
|
||||||
Create a user for Ollama:
|
Create a user for Ollama:
|
||||||
|
|
||||||
@@ -60,7 +40,6 @@ User=ollama
|
|||||||
Group=ollama
|
Group=ollama
|
||||||
Restart=always
|
Restart=always
|
||||||
RestartSec=3
|
RestartSec=3
|
||||||
Environment="HOME=/usr/share/ollama"
|
|
||||||
|
|
||||||
[Install]
|
[Install]
|
||||||
WantedBy=default.target
|
WantedBy=default.target
|
||||||
@@ -73,7 +52,40 @@ sudo systemctl daemon-reload
|
|||||||
sudo systemctl enable ollama
|
sudo systemctl enable ollama
|
||||||
```
|
```
|
||||||
|
|
||||||
### Viewing logs
|
### Install CUDA drivers (optional – for Nvidia GPUs)
|
||||||
|
|
||||||
|
[Download and install](https://developer.nvidia.com/cuda-downloads) CUDA.
|
||||||
|
|
||||||
|
Verify that the drivers are installed by running the following command, which should print details about your GPU:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
nvidia-smi
|
||||||
|
```
|
||||||
|
|
||||||
|
### Start Ollama
|
||||||
|
|
||||||
|
Start Ollama using `systemd`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo systemctl start ollama
|
||||||
|
```
|
||||||
|
|
||||||
|
## Update
|
||||||
|
|
||||||
|
Update ollama by running the install script again:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl https://ollama.ai/install.sh | sh
|
||||||
|
```
|
||||||
|
|
||||||
|
Or by downloading the ollama binary:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo curl -L https://ollama.ai/download/ollama-linux-amd64 -o /usr/bin/ollama
|
||||||
|
sudo chmod +x /usr/bin/ollama
|
||||||
|
```
|
||||||
|
|
||||||
|
## Viewing logs
|
||||||
|
|
||||||
To view logs of Ollama running as a startup service, run:
|
To view logs of Ollama running as a startup service, run:
|
||||||
|
|
||||||
@@ -81,3 +93,24 @@ To view logs of Ollama running as a startup service, run:
|
|||||||
journalctl -u ollama
|
journalctl -u ollama
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Uninstall
|
||||||
|
|
||||||
|
Remove the ollama service:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo systemctl stop ollama
|
||||||
|
sudo systemctl disable ollama
|
||||||
|
sudo rm /etc/systemd/system/ollama.service
|
||||||
|
```
|
||||||
|
|
||||||
|
Remove the ollama binary from your bin directory (either `/usr/local/bin`, `/usr/bin`, or `/bin`):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo rm $(which ollama)
|
||||||
|
```
|
||||||
|
|
||||||
|
Remove the downloaded models and Ollama service user:
|
||||||
|
```bash
|
||||||
|
sudo rm -r /usr/share/ollama
|
||||||
|
sudo userdel ollama
|
||||||
|
```
|
||||||
|
@@ -12,7 +12,6 @@ A model file is the blueprint to create and share models with Ollama.
|
|||||||
- [FROM (Required)](#from-required)
|
- [FROM (Required)](#from-required)
|
||||||
- [Build from llama2](#build-from-llama2)
|
- [Build from llama2](#build-from-llama2)
|
||||||
- [Build from a bin file](#build-from-a-bin-file)
|
- [Build from a bin file](#build-from-a-bin-file)
|
||||||
- [EMBED](#embed)
|
|
||||||
- [PARAMETER](#parameter)
|
- [PARAMETER](#parameter)
|
||||||
- [Valid Parameters and Values](#valid-parameters-and-values)
|
- [Valid Parameters and Values](#valid-parameters-and-values)
|
||||||
- [TEMPLATE](#template)
|
- [TEMPLATE](#template)
|
||||||
@@ -91,17 +90,6 @@ FROM ./ollama-model.bin
|
|||||||
|
|
||||||
This bin file location should be specified as an absolute path or relative to the `Modelfile` location.
|
This bin file location should be specified as an absolute path or relative to the `Modelfile` location.
|
||||||
|
|
||||||
### EMBED
|
|
||||||
|
|
||||||
The `EMBED` instruction is used to add embeddings of files to a model. This is useful for adding custom data that the model can reference when generating an answer. Note that currently only text files are supported, formatted with each line as one embedding.
|
|
||||||
|
|
||||||
```modelfile
|
|
||||||
FROM <model name>:<tag>
|
|
||||||
EMBED <file path>.txt
|
|
||||||
EMBED <different file path>.txt
|
|
||||||
EMBED <path to directory>/*.txt
|
|
||||||
```
|
|
||||||
|
|
||||||
### PARAMETER
|
### PARAMETER
|
||||||
|
|
||||||
The `PARAMETER` instruction defines a parameter that can be set when the model is run.
|
The `PARAMETER` instruction defines a parameter that can be set when the model is run.
|
||||||
|
@@ -1,96 +0,0 @@
|
|||||||
# How to Quantize a Model
|
|
||||||
|
|
||||||
Sometimes the model you want to work with is not available at [https://ollama.ai/library](https://ollama.ai/library). If you want to try out that model before we have a chance to quantize it, you can use this process.
|
|
||||||
|
|
||||||
## Figure out if we can run the model?
|
|
||||||
|
|
||||||
Not all models will work with Ollama. There are a number of factors that go into whether we are able to work with the next cool model. First it has to work with llama.cpp. Then we have to have implemented the features of llama.cpp that it requires. And then, sometimes, even with both of those, the model might not work...
|
|
||||||
|
|
||||||
1. What is the model you want to convert and upload?
|
|
||||||
2. Visit the model's page on HuggingFace.
|
|
||||||
3. Switch to the **Files and versions** tab.
|
|
||||||
4. Click on the **config.json** file. If there is no config.json file, it may not work.
|
|
||||||
5. Take note of the **architecture** list in the json file.
|
|
||||||
6. Does any entry in the list match one of the following architectures?
|
|
||||||
1. LlamaForCausalLM
|
|
||||||
2. MistralForCausalLM
|
|
||||||
3. RWForCausalLM
|
|
||||||
4. FalconForCausalLM
|
|
||||||
5. GPTNeoXForCausalLM
|
|
||||||
6. GPTBigCodeForCausalLM
|
|
||||||
7. If the answer is yes, then there is a good chance the model will run after being converted and quantized.
|
|
||||||
8. An alternative to this process is to visit [https://caniquant.tvl.st](https://caniquant.tvl.st) and enter the org/modelname in the box and submit.
|
|
||||||
|
|
||||||
At this point there are two processes you can use. You can either use a Docker container to convert and quantize, OR you can manually run the scripts. The Docker container is the easiest way to do it, but it requires you to have Docker installed on your machine. If you don't have Docker installed, you can follow the manual process.
|
|
||||||
|
|
||||||
## Convert and Quantize with Docker
|
|
||||||
|
|
||||||
Run `docker run --rm -v /path/to/model/repo:/repo ollama/quantize -q quantlevel /repo`. For instance, if you have downloaded the latest Mistral 7B model, then clone it to your machine. Then change into that directory and you can run:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
docker run --rm -v .:/repo ollama/quantize -q q4_0 /repo
|
|
||||||
```
|
|
||||||
|
|
||||||
You can find the different quantization levels below under **Quantize the Model**.
|
|
||||||
|
|
||||||
This will output two files into the directory. First is a f16.bin file that is the model converted to GGUF. The second file is a q4_0.bin file which is the model quantized to a 4 bit quantization. You should rename it to something more descriptive.
|
|
||||||
|
|
||||||
You can find the repository for the Docker container here: [https://github.com/mxyng/quantize](https://github.com/mxyng/quantize)
|
|
||||||
|
|
||||||
## Convert and Quantize Manually
|
|
||||||
|
|
||||||
### Clone llama.cpp to your machine
|
|
||||||
|
|
||||||
If we know the model has a chance of working, then we need to convert and quantize. This is a matter of running two separate scripts in the llama.cpp project.
|
|
||||||
|
|
||||||
1. Decide where you want the llama.cpp repository on your machine.
|
|
||||||
2. Navigate to that location and then run:
|
|
||||||
[`git clone https://github.com/ggerganov/llama.cpp.git`](https://github.com/ggerganov/llama.cpp.git)
|
|
||||||
1. If you don't have git installed, download this zip file and unzip it to that location: https://github.com/ggerganov/llama.cpp/archive/refs/heads/master.zip
|
|
||||||
3. Install the Python dependencies: `pip install torch transformers sentencepiece`
|
|
||||||
|
|
||||||
### Convert the model to GGUF
|
|
||||||
|
|
||||||
1. Decide on the right convert script to run. What was the model architecture you found in the first section.
|
|
||||||
1. LlamaForCausalLM or MistralForCausalLM:
|
|
||||||
run `python3 convert.py <modelfilename>`
|
|
||||||
No need to specify fp16 or fp32.
|
|
||||||
2. FalconForCausalLM or RWForCausalLM:
|
|
||||||
run `python3 convert-falcon-hf-to-gguf.py <modelfilename> <fpsize>`
|
|
||||||
fpsize depends on the weight size. 1 for fp16, 0 for fp32
|
|
||||||
3. GPTNeoXForCausalLM:
|
|
||||||
run `python3 convert-gptneox-hf-to-gguf.py <modelfilename> <fpsize>`
|
|
||||||
fpsize depends on the weight size. 1 for fp16, 0 for fp32
|
|
||||||
4. GPTBigCodeForCausalLM:
|
|
||||||
run `python3 convert-starcoder-hf-to-gguf.py <modelfilename> <fpsize>`
|
|
||||||
fpsize depends on the weight size. 1 for fp16, 0 for fp32
|
|
||||||
|
|
||||||
### Quantize the model
|
|
||||||
|
|
||||||
If the model converted successfully, there is a good chance it will also quantize successfully. Now you need to decide on the quantization to use. We will always try to create all the quantizations and upload them to the library. You should decide which level is more important to you and quantize accordingly.
|
|
||||||
|
|
||||||
The quantization options are as follows. Note that some architectures such as Falcon do not support K quants.
|
|
||||||
|
|
||||||
- Q4_0
|
|
||||||
- Q4_1
|
|
||||||
- Q5_0
|
|
||||||
- Q5_1
|
|
||||||
- Q2_K
|
|
||||||
- Q3_K
|
|
||||||
- Q3_K_S
|
|
||||||
- Q3_K_M
|
|
||||||
- Q3_K_L
|
|
||||||
- Q4_K
|
|
||||||
- Q4_K_S
|
|
||||||
- Q4_K_M
|
|
||||||
- Q5_K
|
|
||||||
- Q5_K_S
|
|
||||||
- Q5_K_M
|
|
||||||
- Q6_K
|
|
||||||
- Q8_0
|
|
||||||
|
|
||||||
Run the following command `quantize <converted model from above> <output file> <quantization type>`
|
|
||||||
|
|
||||||
## Now Create the Model
|
|
||||||
|
|
||||||
Now you can create the Ollama model. Refer to the [modelfile](./modelfile.md) doc for more information on doing that.
|
|
@@ -3,10 +3,10 @@ package main
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -16,7 +16,7 @@ func main() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Print(err.Error())
|
fmt.Print(err.Error())
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
responseData, err := io.ReadAll(resp.Body)
|
responseData, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@@ -6,7 +6,6 @@ PERSIST_DIRECTORY = os.environ.get('PERSIST_DIRECTORY', 'db')
|
|||||||
|
|
||||||
# Define the Chroma settings
|
# Define the Chroma settings
|
||||||
CHROMA_SETTINGS = Settings(
|
CHROMA_SETTINGS = Settings(
|
||||||
chroma_db_impl='duckdb+parquet',
|
|
||||||
persist_directory=PERSIST_DIRECTORY,
|
persist_directory=PERSIST_DIRECTORY,
|
||||||
anonymized_telemetry=False
|
anonymized_telemetry=False
|
||||||
)
|
)
|
||||||
|
@@ -150,7 +150,7 @@ def main():
|
|||||||
print("Creating new vectorstore")
|
print("Creating new vectorstore")
|
||||||
texts = process_documents()
|
texts = process_documents()
|
||||||
print(f"Creating embeddings. May take some minutes...")
|
print(f"Creating embeddings. May take some minutes...")
|
||||||
db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory, client_settings=CHROMA_SETTINGS)
|
db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory)
|
||||||
db.persist()
|
db.persist()
|
||||||
db = None
|
db = None
|
||||||
|
|
||||||
|
@@ -4,6 +4,7 @@ from langchain.embeddings import HuggingFaceEmbeddings
|
|||||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||||
from langchain.vectorstores import Chroma
|
from langchain.vectorstores import Chroma
|
||||||
from langchain.llms import Ollama
|
from langchain.llms import Ollama
|
||||||
|
import chromadb
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
@@ -22,7 +23,9 @@ def main():
|
|||||||
# Parse the command line arguments
|
# Parse the command line arguments
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
|
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
|
||||||
db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
|
|
||||||
|
db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
|
||||||
|
|
||||||
retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
|
retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
|
||||||
# activate/deactivate the streaming StdOut callback for LLMs
|
# activate/deactivate the streaming StdOut callback for LLMs
|
||||||
callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
|
callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
|
||||||
|
File diff suppressed because it is too large
Load Diff
22
examples/python-rag-newssummary/README.md
Normal file
22
examples/python-rag-newssummary/README.md
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# News Summarizer
|
||||||
|
|
||||||
|
This example goes through a series of steps:
|
||||||
|
|
||||||
|
1. You choose a topic area (e.g., "news", "NVidia", "music", etc.).
|
||||||
|
2. Gets the most recent articles on that topic from various sources.
|
||||||
|
3. Uses Ollama to summarize each article.
|
||||||
|
4. Creates chunks of sentences from each article.
|
||||||
|
5. Uses Sentence Transformers to generate embeddings for each of those chunks.
|
||||||
|
6. You enter a question regarding the summaries shown.
|
||||||
|
7. Uses Sentence Transformers to generate an embedding for that question.
|
||||||
|
8. Uses the embedded question to find the most similar chunks.
|
||||||
|
9. Feeds all that to Ollama to generate a good answer to your question based on these news articles.
|
||||||
|
|
||||||
|
This example lets you pick from a few different topic areas, then summarize the most recent x articles for that topic. It then creates chunks of sentences from each article and then generates embeddings for each of those chunks.
|
||||||
|
|
||||||
|
You can run the example like this:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
python summ.py
|
||||||
|
```
|
9
examples/python-rag-newssummary/requirements.txt
Normal file
9
examples/python-rag-newssummary/requirements.txt
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
beautifulsoup4==4.12.2
|
||||||
|
feedparser==6.0.10
|
||||||
|
mattsollamatools==0.0.8
|
||||||
|
newspaper3k==0.2.8
|
||||||
|
nltk==3.8.1
|
||||||
|
numpy==1.24.3
|
||||||
|
Requests==2.31.0
|
||||||
|
scikit_learn==1.3.0
|
||||||
|
sentence_transformers==2.2.2
|
86
examples/python-rag-newssummary/summ.py
Normal file
86
examples/python-rag-newssummary/summ.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import curses
|
||||||
|
import json
|
||||||
|
from utils import get_url_for_topic, topic_urls, menu, getUrls, get_summary, getArticleText, knn_search
|
||||||
|
import requests
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from mattsollamatools import chunker
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
chosen_topic = curses.wrapper(menu)
|
||||||
|
print("Here is your news summary:\n")
|
||||||
|
urls = getUrls(chosen_topic, n=5)
|
||||||
|
model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||||
|
allEmbeddings = []
|
||||||
|
|
||||||
|
for url in urls:
|
||||||
|
article={}
|
||||||
|
article['embeddings'] = []
|
||||||
|
article['url'] = url
|
||||||
|
text = getArticleText(url)
|
||||||
|
summary = get_summary(text)
|
||||||
|
chunks = chunker(text) # Use the chunk_text function from web_utils
|
||||||
|
embeddings = model.encode(chunks)
|
||||||
|
for (chunk, embedding) in zip(chunks, embeddings):
|
||||||
|
item = {}
|
||||||
|
item['source'] = chunk
|
||||||
|
item['embedding'] = embedding.tolist() # Convert NumPy array to list
|
||||||
|
item['sourcelength'] = len(chunk)
|
||||||
|
article['embeddings'].append(item)
|
||||||
|
|
||||||
|
allEmbeddings.append(article)
|
||||||
|
|
||||||
|
print(f"{summary}\n")
|
||||||
|
|
||||||
|
|
||||||
|
while True:
|
||||||
|
context = []
|
||||||
|
# Input a question from the user
|
||||||
|
question = input("Enter your question about the news, or type quit: ")
|
||||||
|
|
||||||
|
if question.lower() == 'quit':
|
||||||
|
break
|
||||||
|
|
||||||
|
# Embed the user's question
|
||||||
|
question_embedding = model.encode([question])
|
||||||
|
|
||||||
|
# Perform KNN search to find the best matches (indices and source text)
|
||||||
|
best_matches = knn_search(question_embedding, allEmbeddings, k=10)
|
||||||
|
|
||||||
|
|
||||||
|
sourcetext=""
|
||||||
|
for i, (index, source_text) in enumerate(best_matches, start=1):
|
||||||
|
sourcetext += f"{i}. Index: {index}, Source Text: {source_text}"
|
||||||
|
|
||||||
|
systemPrompt = f"Only use the following information to answer the question. Do not use anything else: {sourcetext}"
|
||||||
|
|
||||||
|
url = "http://localhost:11434/api/generate"
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": "mistral-openorca",
|
||||||
|
"prompt": question,
|
||||||
|
"system": systemPrompt,
|
||||||
|
"stream": False,
|
||||||
|
"context": context
|
||||||
|
}
|
||||||
|
|
||||||
|
# Convert the payload to a JSON string
|
||||||
|
payload_json = json.dumps(payload)
|
||||||
|
|
||||||
|
# Set the headers to specify JSON content
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Send the POST request
|
||||||
|
response = requests.post(url, data=payload_json, headers=headers)
|
||||||
|
|
||||||
|
# Check the response
|
||||||
|
if response.status_code == 200:
|
||||||
|
output = json.loads(response.text)
|
||||||
|
context = output['context']
|
||||||
|
print(output['response']+ "\n")
|
||||||
|
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f"Request failed with status code {response.status_code}")
|
||||||
|
|
108
examples/python-rag-newssummary/utils.py
Normal file
108
examples/python-rag-newssummary/utils.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
import curses
|
||||||
|
import feedparser
|
||||||
|
import requests
|
||||||
|
import unicodedata
|
||||||
|
import json
|
||||||
|
from newspaper import Article
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
from nltk.tokenize import sent_tokenize, word_tokenize
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.neighbors import NearestNeighbors
|
||||||
|
from mattsollamatools import chunker
|
||||||
|
|
||||||
|
# Create a dictionary to store topics and their URLs
|
||||||
|
topic_urls = {
|
||||||
|
"Mac": "https://9to5mac.com/guides/mac/feed",
|
||||||
|
"News": "http://www.npr.org/rss/rss.php?id=1001",
|
||||||
|
"Nvidia": "https://nvidianews.nvidia.com/releases.xml",
|
||||||
|
"Raspberry Pi": "https://www.raspberrypi.com/news/feed/",
|
||||||
|
"Music": "https://www.billboard.com/c/music/music-news/feed/"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Use curses to create a menu of topics
|
||||||
|
def menu(stdscr):
|
||||||
|
chosen_topic = get_url_for_topic(stdscr)
|
||||||
|
url = topic_urls[chosen_topic] if chosen_topic in topic_urls else "Topic not found"
|
||||||
|
|
||||||
|
stdscr.addstr(len(topic_urls) + 3, 0, f"Selected URL for {chosen_topic}: {url}")
|
||||||
|
stdscr.refresh()
|
||||||
|
|
||||||
|
return chosen_topic
|
||||||
|
|
||||||
|
# You have chosen a topic. Now return the url for that topic
|
||||||
|
def get_url_for_topic(stdscr):
|
||||||
|
curses.curs_set(0) # Hide the cursor
|
||||||
|
stdscr.clear()
|
||||||
|
|
||||||
|
stdscr.addstr(0, 0, "Choose a topic using the arrow keys (Press Enter to select):")
|
||||||
|
|
||||||
|
# Create a list of topics
|
||||||
|
topics = list(topic_urls.keys())
|
||||||
|
current_topic = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
for i, topic in enumerate(topics):
|
||||||
|
if i == current_topic:
|
||||||
|
stdscr.addstr(i + 2, 2, f"> {topic}")
|
||||||
|
else:
|
||||||
|
stdscr.addstr(i + 2, 2, f" {topic}")
|
||||||
|
|
||||||
|
stdscr.refresh()
|
||||||
|
|
||||||
|
key = stdscr.getch()
|
||||||
|
|
||||||
|
if key == curses.KEY_DOWN and current_topic < len(topics) - 1:
|
||||||
|
current_topic += 1
|
||||||
|
elif key == curses.KEY_UP and current_topic > 0:
|
||||||
|
current_topic -= 1
|
||||||
|
elif key == 10: # Enter key
|
||||||
|
return topic_urls[topics[current_topic]]
|
||||||
|
|
||||||
|
# Get the last N URLs from an RSS feed
|
||||||
|
def getUrls(feed_url, n=20):
|
||||||
|
feed = feedparser.parse(feed_url)
|
||||||
|
entries = feed.entries[-n:]
|
||||||
|
urls = [entry.link for entry in entries]
|
||||||
|
return urls
|
||||||
|
|
||||||
|
# Often there are a bunch of ads and menus on pages for a news article. This uses newspaper3k to get just the text of just the article.
|
||||||
|
def getArticleText(url):
|
||||||
|
article = Article(url)
|
||||||
|
article.download()
|
||||||
|
article.parse()
|
||||||
|
return article.text
|
||||||
|
|
||||||
|
def get_summary(text):
|
||||||
|
systemPrompt = "Write a concise summary of the text, return your responses with 5 lines that cover the key points of the text given."
|
||||||
|
prompt = text
|
||||||
|
|
||||||
|
url = "http://localhost:11434/api/generate"
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": "mistral-openorca",
|
||||||
|
"prompt": prompt,
|
||||||
|
"system": systemPrompt,
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
payload_json = json.dumps(payload)
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
response = requests.post(url, data=payload_json, headers=headers)
|
||||||
|
|
||||||
|
return json.loads(response.text)["response"]
|
||||||
|
|
||||||
|
# Perform K-nearest neighbors (KNN) search
|
||||||
|
def knn_search(question_embedding, embeddings, k=5):
|
||||||
|
X = np.array([item['embedding'] for article in embeddings for item in article['embeddings']])
|
||||||
|
source_texts = [item['source'] for article in embeddings for item in article['embeddings']]
|
||||||
|
|
||||||
|
# Fit a KNN model on the embeddings
|
||||||
|
knn = NearestNeighbors(n_neighbors=k, metric='cosine')
|
||||||
|
knn.fit(X)
|
||||||
|
|
||||||
|
# Find the indices and distances of the k-nearest neighbors
|
||||||
|
distances, indices = knn.kneighbors(question_embedding, n_neighbors=k)
|
||||||
|
|
||||||
|
# Get the indices and source texts of the best matches
|
||||||
|
best_matches = [(indices[0][i], source_texts[indices[0][i]]) for i in range(k)]
|
||||||
|
|
||||||
|
return best_matches
|
@@ -2,14 +2,21 @@ package format
|
|||||||
|
|
||||||
import "fmt"
|
import "fmt"
|
||||||
|
|
||||||
|
const (
|
||||||
|
Byte = 1
|
||||||
|
KiloByte = Byte * 1000
|
||||||
|
MegaByte = KiloByte * 1000
|
||||||
|
GigaByte = MegaByte * 1000
|
||||||
|
)
|
||||||
|
|
||||||
func HumanBytes(b int64) string {
|
func HumanBytes(b int64) string {
|
||||||
switch {
|
switch {
|
||||||
case b > 1000*1000*1000:
|
case b > GigaByte:
|
||||||
return fmt.Sprintf("%d GB", b/1000/1000/1000)
|
return fmt.Sprintf("%d GB", b/GigaByte)
|
||||||
case b > 1000*1000:
|
case b > MegaByte:
|
||||||
return fmt.Sprintf("%d MB", b/1000/1000)
|
return fmt.Sprintf("%d MB", b/MegaByte)
|
||||||
case b > 1000:
|
case b > KiloByte:
|
||||||
return fmt.Sprintf("%d KB", b/1000)
|
return fmt.Sprintf("%d KB", b/KiloByte)
|
||||||
default:
|
default:
|
||||||
return fmt.Sprintf("%d B", b)
|
return fmt.Sprintf("%d B", b)
|
||||||
}
|
}
|
||||||
|
@@ -29,7 +29,7 @@ func TestHumanTime(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("soon", func(t *testing.T) {
|
t.Run("soon", func(t *testing.T) {
|
||||||
v := now.Add(800*time.Millisecond)
|
v := now.Add(800 * time.Millisecond)
|
||||||
assertEqual(t, HumanTime(v, ""), "Less than a second from now")
|
assertEqual(t, HumanTime(v, ""), "Less than a second from now")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
13
go.mod
13
go.mod
@@ -4,11 +4,11 @@ go 1.20
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/dustin/go-humanize v1.0.1
|
github.com/dustin/go-humanize v1.0.1
|
||||||
|
github.com/emirpasic/gods v1.18.1
|
||||||
github.com/gin-gonic/gin v1.9.1
|
github.com/gin-gonic/gin v1.9.1
|
||||||
github.com/mattn/go-runewidth v0.0.14
|
github.com/mattn/go-runewidth v0.0.14
|
||||||
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db
|
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db
|
||||||
github.com/olekukonko/tablewriter v0.0.5
|
github.com/olekukonko/tablewriter v0.0.5
|
||||||
github.com/pdevine/readline v1.5.2
|
|
||||||
github.com/spf13/cobra v1.7.0
|
github.com/spf13/cobra v1.7.0
|
||||||
golang.org/x/sync v0.3.0
|
golang.org/x/sync v0.3.0
|
||||||
)
|
)
|
||||||
@@ -39,13 +39,12 @@ require (
|
|||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||||
golang.org/x/arch v0.3.0 // indirect
|
golang.org/x/arch v0.3.0 // indirect
|
||||||
golang.org/x/crypto v0.10.0
|
golang.org/x/crypto v0.14.0
|
||||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
|
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
|
||||||
golang.org/x/net v0.10.0 // indirect
|
golang.org/x/net v0.17.0 // indirect
|
||||||
golang.org/x/sys v0.11.0 // indirect
|
golang.org/x/sys v0.13.0 // indirect
|
||||||
golang.org/x/term v0.10.0
|
golang.org/x/term v0.13.0
|
||||||
golang.org/x/text v0.10.0 // indirect
|
golang.org/x/text v0.13.0 // indirect
|
||||||
gonum.org/v1/gonum v0.13.0
|
|
||||||
google.golang.org/protobuf v1.30.0 // indirect
|
google.golang.org/protobuf v1.30.0 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
31
go.sum
31
go.sum
@@ -4,10 +4,6 @@ github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZX
|
|||||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||||
github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM=
|
|
||||||
github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ=
|
|
||||||
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
|
|
||||||
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
|
|
||||||
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
@@ -15,6 +11,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
|||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
|
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||||
|
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
|
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
|
||||||
github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
|
github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
|
||||||
@@ -78,8 +76,6 @@ github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N
|
|||||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||||
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0=
|
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0=
|
||||||
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhMYhSNPKjeNKa5WY9YCIEBRbNzFFPJbWO6Y=
|
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhMYhSNPKjeNKa5WY9YCIEBRbNzFFPJbWO6Y=
|
||||||
github.com/pdevine/readline v1.5.2 h1:oz6Y5GdTmhPG+08hhxcAvtHitSANWuA2100Sppb38xI=
|
|
||||||
github.com/pdevine/readline v1.5.2/go.mod h1:na/LbuE5PYwxI7GyopWdIs3U8HVe89lYlNTFTXH3wOw=
|
|
||||||
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
|
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
|
||||||
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
||||||
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
||||||
@@ -118,35 +114,32 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
|
|||||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM=
|
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||||
golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I=
|
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ=
|
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ=
|
||||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8=
|
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
|
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||||
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
|
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
|
||||||
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
|
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c=
|
golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek=
|
||||||
golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
|
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
|
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||||
golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
gonum.org/v1/gonum v0.13.0 h1:a0T3bh+7fhRyqeNbiC3qVHYmkiQgit3wnNan/2c0HMM=
|
|
||||||
gonum.org/v1/gonum v0.13.0/go.mod h1:/WPYRckkfWrhWefxyYTfrTtQR0KH4iyHNuzxqXAKyAU=
|
|
||||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||||
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
||||||
|
@@ -175,7 +175,8 @@ const (
|
|||||||
// Magic constant for `ggla` files (LoRA adapter).
|
// Magic constant for `ggla` files (LoRA adapter).
|
||||||
FILE_MAGIC_GGLA = 0x67676C61
|
FILE_MAGIC_GGLA = 0x67676C61
|
||||||
// Magic constant for `gguf` files (versioned, gguf)
|
// Magic constant for `gguf` files (versioned, gguf)
|
||||||
FILE_MAGIC_GGUF = 0x46554747
|
FILE_MAGIC_GGUF_LE = 0x46554747
|
||||||
|
FILE_MAGIC_GGUF_BE = 0x47475546
|
||||||
)
|
)
|
||||||
|
|
||||||
func DecodeGGML(r io.ReadSeeker) (*GGML, error) {
|
func DecodeGGML(r io.ReadSeeker) (*GGML, error) {
|
||||||
@@ -191,8 +192,10 @@ func DecodeGGML(r io.ReadSeeker) (*GGML, error) {
|
|||||||
ggml.container = &containerGGJT{}
|
ggml.container = &containerGGJT{}
|
||||||
case FILE_MAGIC_GGLA:
|
case FILE_MAGIC_GGLA:
|
||||||
ggml.container = &containerLORA{}
|
ggml.container = &containerLORA{}
|
||||||
case FILE_MAGIC_GGUF:
|
case FILE_MAGIC_GGUF_LE:
|
||||||
ggml.container = &containerGGUF{}
|
ggml.container = &containerGGUF{bo: binary.LittleEndian}
|
||||||
|
case FILE_MAGIC_GGUF_BE:
|
||||||
|
ggml.container = &containerGGUF{bo: binary.BigEndian}
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("invalid file magic")
|
return nil, errors.New("invalid file magic")
|
||||||
}
|
}
|
||||||
|
61
llm/gguf.go
61
llm/gguf.go
@@ -3,12 +3,13 @@ package llm
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
type containerGGUF struct {
|
type containerGGUF struct {
|
||||||
|
bo binary.ByteOrder
|
||||||
|
|
||||||
Version uint32
|
Version uint32
|
||||||
|
|
||||||
V1 struct {
|
V1 struct {
|
||||||
@@ -27,15 +28,13 @@ func (c *containerGGUF) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *containerGGUF) Decode(r io.Reader) (model, error) {
|
func (c *containerGGUF) Decode(r io.Reader) (model, error) {
|
||||||
binary.Read(r, binary.LittleEndian, &c.Version)
|
binary.Read(r, c.bo, &c.Version)
|
||||||
|
|
||||||
switch c.Version {
|
switch c.Version {
|
||||||
case 1:
|
case 1:
|
||||||
binary.Read(r, binary.LittleEndian, &c.V1)
|
binary.Read(r, c.bo, &c.V1)
|
||||||
case 2:
|
|
||||||
binary.Read(r, binary.LittleEndian, &c.V2)
|
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("invalid version")
|
binary.Read(r, c.bo, &c.V2)
|
||||||
}
|
}
|
||||||
|
|
||||||
model := newGGUFModel(c)
|
model := newGGUFModel(c)
|
||||||
@@ -209,75 +208,75 @@ func (llm *ggufModel) NumLayers() int64 {
|
|||||||
return int64(v)
|
return int64(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ggufModel) readU8(r io.Reader) uint8 {
|
func (llm ggufModel) readU8(r io.Reader) uint8 {
|
||||||
var u8 uint8
|
var u8 uint8
|
||||||
binary.Read(r, binary.LittleEndian, &u8)
|
binary.Read(r, llm.bo, &u8)
|
||||||
return u8
|
return u8
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ggufModel) readI8(r io.Reader) int8 {
|
func (llm ggufModel) readI8(r io.Reader) int8 {
|
||||||
var i8 int8
|
var i8 int8
|
||||||
binary.Read(r, binary.LittleEndian, &i8)
|
binary.Read(r, llm.bo, &i8)
|
||||||
return i8
|
return i8
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ggufModel) readU16(r io.Reader) uint16 {
|
func (llm ggufModel) readU16(r io.Reader) uint16 {
|
||||||
var u16 uint16
|
var u16 uint16
|
||||||
binary.Read(r, binary.LittleEndian, &u16)
|
binary.Read(r, llm.bo, &u16)
|
||||||
return u16
|
return u16
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ggufModel) readI16(r io.Reader) int16 {
|
func (llm ggufModel) readI16(r io.Reader) int16 {
|
||||||
var i16 int16
|
var i16 int16
|
||||||
binary.Read(r, binary.LittleEndian, &i16)
|
binary.Read(r, llm.bo, &i16)
|
||||||
return i16
|
return i16
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ggufModel) readU32(r io.Reader) uint32 {
|
func (llm ggufModel) readU32(r io.Reader) uint32 {
|
||||||
var u32 uint32
|
var u32 uint32
|
||||||
binary.Read(r, binary.LittleEndian, &u32)
|
binary.Read(r, llm.bo, &u32)
|
||||||
return u32
|
return u32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ggufModel) readI32(r io.Reader) int32 {
|
func (llm ggufModel) readI32(r io.Reader) int32 {
|
||||||
var i32 int32
|
var i32 int32
|
||||||
binary.Read(r, binary.LittleEndian, &i32)
|
binary.Read(r, llm.bo, &i32)
|
||||||
return i32
|
return i32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ggufModel) readU64(r io.Reader) uint64 {
|
func (llm ggufModel) readU64(r io.Reader) uint64 {
|
||||||
var u64 uint64
|
var u64 uint64
|
||||||
binary.Read(r, binary.LittleEndian, &u64)
|
binary.Read(r, llm.bo, &u64)
|
||||||
return u64
|
return u64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ggufModel) readI64(r io.Reader) int64 {
|
func (llm ggufModel) readI64(r io.Reader) int64 {
|
||||||
var i64 int64
|
var i64 int64
|
||||||
binary.Read(r, binary.LittleEndian, &i64)
|
binary.Read(r, llm.bo, &i64)
|
||||||
return i64
|
return i64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ggufModel) readF32(r io.Reader) float32 {
|
func (llm ggufModel) readF32(r io.Reader) float32 {
|
||||||
var f32 float32
|
var f32 float32
|
||||||
binary.Read(r, binary.LittleEndian, &f32)
|
binary.Read(r, llm.bo, &f32)
|
||||||
return f32
|
return f32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ggufModel) readF64(r io.Reader) float64 {
|
func (llm ggufModel) readF64(r io.Reader) float64 {
|
||||||
var f64 float64
|
var f64 float64
|
||||||
binary.Read(r, binary.LittleEndian, &f64)
|
binary.Read(r, llm.bo, &f64)
|
||||||
return f64
|
return f64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ggufModel) readBool(r io.Reader) bool {
|
func (llm ggufModel) readBool(r io.Reader) bool {
|
||||||
var b bool
|
var b bool
|
||||||
binary.Read(r, binary.LittleEndian, &b)
|
binary.Read(r, llm.bo, &b)
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ggufModel) readStringV1(r io.Reader) (string, error) {
|
func (llm ggufModel) readStringV1(r io.Reader) (string, error) {
|
||||||
var nameLength uint32
|
var nameLength uint32
|
||||||
binary.Read(r, binary.LittleEndian, &nameLength)
|
binary.Read(r, llm.bo, &nameLength)
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if _, err := io.CopyN(&b, r, int64(nameLength)); err != nil {
|
if _, err := io.CopyN(&b, r, int64(nameLength)); err != nil {
|
||||||
@@ -292,7 +291,7 @@ func (ggufModel) readStringV1(r io.Reader) (string, error) {
|
|||||||
|
|
||||||
func (llm ggufModel) readString(r io.Reader) (string, error) {
|
func (llm ggufModel) readString(r io.Reader) (string, error) {
|
||||||
var nameLength uint64
|
var nameLength uint64
|
||||||
binary.Read(r, binary.LittleEndian, &nameLength)
|
binary.Read(r, llm.bo, &nameLength)
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if _, err := io.CopyN(&b, r, int64(nameLength)); err != nil {
|
if _, err := io.CopyN(&b, r, int64(nameLength)); err != nil {
|
||||||
|
@@ -12,7 +12,8 @@ package llm
|
|||||||
//go:generate mv ggml/build/cpu/bin/server ggml/build/cpu/bin/ollama-runner
|
//go:generate mv ggml/build/cpu/bin/server ggml/build/cpu/bin/ollama-runner
|
||||||
|
|
||||||
//go:generate git submodule update --force gguf
|
//go:generate git submodule update --force gguf
|
||||||
//go:generate git -C gguf apply ../patches/0001-remove-warm-up-logging.patch
|
//go:generate git -C gguf apply ../patches/0001-update-default-log-target.patch
|
||||||
|
//go:generate git -C gguf apply ../patches/0001-metal-handle-ggml_scale-for-n-4-0-close-3754.patch
|
||||||
//go:generate cmake -S gguf -B gguf/build/cpu -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DCMAKE_SYSTEM_PROCESSOR=x86_64 -DCMAKE_OSX_ARCHITECTURES=x86_64 -DCMAKE_OSX_DEPLOYMENT_TARGET=11.0
|
//go:generate cmake -S gguf -B gguf/build/cpu -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DCMAKE_SYSTEM_PROCESSOR=x86_64 -DCMAKE_OSX_ARCHITECTURES=x86_64 -DCMAKE_OSX_DEPLOYMENT_TARGET=11.0
|
||||||
//go:generate cmake --build gguf/build/cpu --target server --config Release
|
//go:generate cmake --build gguf/build/cpu --target server --config Release
|
||||||
//go:generate mv gguf/build/cpu/bin/server gguf/build/cpu/bin/ollama-runner
|
//go:generate mv gguf/build/cpu/bin/server gguf/build/cpu/bin/ollama-runner
|
||||||
|
@@ -12,7 +12,8 @@ package llm
 //go:generate mv ggml/build/metal/bin/server ggml/build/metal/bin/ollama-runner

 //go:generate git submodule update --force gguf
-//go:generate git -C gguf apply ../patches/0001-remove-warm-up-logging.patch
+//go:generate git -C gguf apply ../patches/0001-update-default-log-target.patch
+//go:generate git -C gguf apply ../patches/0001-metal-handle-ggml_scale-for-n-4-0-close-3754.patch
 //go:generate cmake -S gguf -B gguf/build/metal -DLLAMA_METAL=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DCMAKE_SYSTEM_PROCESSOR=arm64 -DCMAKE_OSX_ARCHITECTURES=arm64 -DCMAKE_OSX_DEPLOYMENT_TARGET=11.0
 //go:generate cmake --build gguf/build/metal --target server --config Release
 //go:generate mv gguf/build/metal/bin/server gguf/build/metal/bin/ollama-runner
@@ -13,14 +13,14 @@ package llm

 //go:generate git submodule update --force gguf
 //go:generate git -C gguf apply ../patches/0001-copy-cuda-runtime-libraries.patch
-//go:generate git -C gguf apply ../patches/0001-remove-warm-up-logging.patch
-//go:generate cmake -S gguf -B gguf/build/cpu -DLLAMA_K_QUANTS=on
+//go:generate git -C gguf apply ../patches/0001-update-default-log-target.patch
+//go:generate cmake -S gguf -B gguf/build/cpu -DLLAMA_K_QUANTS=on -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off
 //go:generate cmake --build gguf/build/cpu --target server --config Release
 //go:generate mv gguf/build/cpu/bin/server gguf/build/cpu/bin/ollama-runner

 //go:generate cmake -S ggml -B ggml/build/cuda -DLLAMA_CUBLAS=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on
 //go:generate cmake --build ggml/build/cuda --target server --config Release
 //go:generate mv ggml/build/cuda/bin/server ggml/build/cuda/bin/ollama-runner
-//go:generate cmake -S gguf -B gguf/build/cuda -DLLAMA_CUBLAS=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on
+//go:generate cmake -S gguf -B gguf/build/cuda -DLLAMA_CUBLAS=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off
 //go:generate cmake --build gguf/build/cuda --target server --config Release
 //go:generate mv gguf/build/cuda/bin/server gguf/build/cuda/bin/ollama-runner
@@ -10,7 +10,7 @@ package llm
 //go:generate cmd /c move ggml\build\cpu\bin\Release\server.exe ggml\build\cpu\bin\Release\ollama-runner.exe

 //go:generate git submodule update --force gguf
-//go:generate git -C gguf apply ../patches/0001-remove-warm-up-logging.patch
-//go:generate cmake -S gguf -B gguf/build/cpu -DLLAMA_K_QUANTS=on
+//go:generate git -C gguf apply ../patches/0001-update-default-log-target.patch
+//go:generate cmake -S gguf -B gguf/build/cpu -DLLAMA_K_QUANTS=on -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off
 //go:generate cmake --build gguf/build/cpu --target server --config Release
 //go:generate cmd /c move gguf\build\cpu\bin\Release\server.exe gguf\build\cpu\bin\Release\ollama-runner.exe
Submodule llm/llama.cpp/gguf updated: bc9d3e3971...9e70cc0322
@@ -0,0 +1,91 @@
From 469c9addef75893e6be12edda852d12e840bf064 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Tue, 24 Oct 2023 09:46:50 +0300
Subject: [PATCH 1/2] metal : handle ggml_scale for n%4 != 0 (close #3754)

ggml-ci
---
 ggml-metal.m     | 18 +++++++++++++-----
 ggml-metal.metal | 10 +++++++++-
 2 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/ggml-metal.m b/ggml-metal.m
index c908106..c1901dc 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -62,6 +62,7 @@
     GGML_METAL_DECL_KERNEL(mul);
     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(scale);
+    GGML_METAL_DECL_KERNEL(scale_4);
     GGML_METAL_DECL_KERNEL(silu);
     GGML_METAL_DECL_KERNEL(relu);
     GGML_METAL_DECL_KERNEL(gelu);
@@ -249,6 +250,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
         GGML_METAL_ADD_KERNEL(mul);
         GGML_METAL_ADD_KERNEL(mul_row);
         GGML_METAL_ADD_KERNEL(scale);
+        GGML_METAL_ADD_KERNEL(scale_4);
         GGML_METAL_ADD_KERNEL(silu);
         GGML_METAL_ADD_KERNEL(relu);
         GGML_METAL_ADD_KERNEL(gelu);
@@ -347,6 +349,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(mul);
     GGML_METAL_DEL_KERNEL(mul_row);
     GGML_METAL_DEL_KERNEL(scale);
+    GGML_METAL_DEL_KERNEL(scale_4);
     GGML_METAL_DEL_KERNEL(silu);
     GGML_METAL_DEL_KERNEL(relu);
     GGML_METAL_DEL_KERNEL(gelu);
@@ -923,15 +926,20 @@ void ggml_metal_graph_compute(

     const float scale = *(const float *) src1->data;

-    [encoder setComputePipelineState:ctx->pipeline_scale];
+    int64_t n = ggml_nelements(dst);
+
+    if (n % 4 == 0) {
+        n /= 4;
+        [encoder setComputePipelineState:ctx->pipeline_scale_4];
+    } else {
+        [encoder setComputePipelineState:ctx->pipeline_scale];
+    }
+
     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
     [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
     [encoder setBytes:&scale length:sizeof(scale) atIndex:2];

-    const int64_t n = ggml_nelements(dst);
-    GGML_ASSERT(n % 4 == 0);
-
-    [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
 } break;
 case GGML_OP_UNARY:
     switch (ggml_get_unary_op(gf->nodes[i])) {
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 69fc713..f4b4605 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -125,9 +125,17 @@ kernel void kernel_mul_row(
 }

 kernel void kernel_scale(
+    device const float * src0,
+    device       float * dst,
+    constant     float & scale,
+    uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_scale_4(
     device const float4 * src0,
     device       float4 * dst,
-    constant     float  & scale,
+    constant     float  & scale,
     uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * scale;
 }
--
2.39.3 (Apple Git-145)
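The patch above swaps a hard GGML_ASSERT(n % 4 == 0) for a runtime choice between a scalar kernel and a 4-wide kernel. The same align-then-vectorize dispatch, sketched in Go purely as an analogy; none of these names exist in ggml:

package main

import "fmt"

// scaleScalar handles any length, one element at a time.
func scaleScalar(dst, src []float32, s float32) {
	for i := range src {
		dst[i] = src[i] * s
	}
}

// scale4 assumes len(src)%4 == 0 and works in 4-element chunks,
// standing in for the float4 Metal kernel.
func scale4(dst, src []float32, s float32) {
	for i := 0; i < len(src); i += 4 {
		dst[i], dst[i+1] = src[i]*s, src[i+1]*s
		dst[i+2], dst[i+3] = src[i+2]*s, src[i+3]*s
	}
}

func scale(dst, src []float32, s float32) {
	// dispatch: take the wide path only when the element count divides
	// evenly, instead of asserting that it always does
	if len(src)%4 == 0 {
		scale4(dst, src, s)
		return
	}
	scaleScalar(dst, src, s)
}

func main() {
	src := []float32{1, 2, 3, 4, 5} // n%4 != 0, previously an assertion failure
	dst := make([]float32, len(src))
	scale(dst, src, 0.5)
	fmt.Println(dst)
}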
@@ -1,25 +0,0 @@
From 07993bdc35345b67b27aa649a7c099ad42d80c4c Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Thu, 21 Sep 2023 14:43:21 -0700
Subject: [PATCH] remove warm up logging

---
 common/common.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 2597ba0..b56549b 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -780,8 +780,6 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 }

 {
-    LOG("warming up the model with an empty run\n");
-
     const std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
     llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads);
     llama_reset_timings(lctx);
--
2.42.0
25 llm/llama.cpp/patches/0001-update-default-log-target.patch Normal file
@@ -0,0 +1,25 @@
From 6465fec6290f0a7f5d4d0fbe6bcf634e4810dde6 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Mon, 23 Oct 2023 10:39:34 -0700
Subject: [PATCH] default log stderr

---
 common/log.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/common/log.h b/common/log.h
index b8953fd..25522cd 100644
--- a/common/log.h
+++ b/common/log.h
@@ -90,7 +90,7 @@
 // }
 //
 #ifndef LOG_TARGET
-    #define LOG_TARGET log_handler()
+    #define LOG_TARGET nullptr
 #endif

 #ifndef LOG_TEE_TARGET
--
2.42.0
298 llm/llama.go
@@ -24,51 +24,53 @@ import (
 	"time"

 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/format"
 )

 //go:embed llama.cpp/*/build/*/bin/*
 var llamaCppEmbed embed.FS

 type ModelRunner struct {
 	Path string // path to the model runner executable
+	Accelerated bool
 }

 func chooseRunners(workDir, runnerType string) []ModelRunner {
 	buildPath := path.Join("llama.cpp", runnerType, "build")
-	var runners []string
+	var runners []ModelRunner

 	// set the runners based on the OS
 	// IMPORTANT: the order of the runners in the array is the priority order
 	switch runtime.GOOS {
 	case "darwin":
-		runners = []string{
-			path.Join(buildPath, "metal", "bin", "ollama-runner"),
-			path.Join(buildPath, "cpu", "bin", "ollama-runner"),
+		runners = []ModelRunner{
+			{Path: path.Join(buildPath, "metal", "bin", "ollama-runner")},
+			{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
 		}
 	case "linux":
-		runners = []string{
-			path.Join(buildPath, "cuda", "bin", "ollama-runner"),
-			path.Join(buildPath, "cpu", "bin", "ollama-runner"),
+		runners = []ModelRunner{
+			{Path: path.Join(buildPath, "cuda", "bin", "ollama-runner"), Accelerated: true},
+			{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
 		}
 	case "windows":
 		// TODO: select windows GPU runner here when available
-		runners = []string{
-			path.Join(buildPath, "cpu", "bin", "Release", "ollama-runner.exe"),
+		runners = []ModelRunner{
+			{Path: path.Join(buildPath, "cpu", "bin", "Release", "ollama-runner.exe")},
 		}
 	default:
 		log.Printf("unknown OS, running on CPU: %s", runtime.GOOS)
-		runners = []string{
-			path.Join(buildPath, "cpu", "bin", "ollama-runner"),
+		runners = []ModelRunner{
+			{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
 		}
 	}

 	runnerAvailable := false // if no runner files are found in the embed, this flag will cause a fast fail
 	for _, r := range runners {
 		// find all the files in the runner's bin directory
-		files, err := fs.Glob(llamaCppEmbed, path.Join(path.Dir(r), "*"))
+		files, err := fs.Glob(llamaCppEmbed, path.Join(path.Dir(r.Path), "*"))
 		if err != nil {
 			// this is expected, ollama may be compiled without all runners packed in
-			log.Printf("%s runner not found: %v", r, err)
+			log.Printf("%s runner not found: %v", r.Path, err)
 			continue
 		}
@@ -115,7 +117,10 @@ func chooseRunners(workDir, runnerType string) []ModelRunner {
 	localRunnersByPriority := []ModelRunner{}
 	for _, r := range runners {
 		// clean the ModelRunner paths so that they match the OS we are running on
-		localRunnersByPriority = append(localRunnersByPriority, ModelRunner{Path: filepath.Clean(path.Join(workDir, r))})
+		localRunnersByPriority = append(localRunnersByPriority, ModelRunner{
+			Path:        filepath.Clean(path.Join(workDir, r.Path)),
+			Accelerated: r.Accelerated,
+		})
 	}

 	return localRunnersByPriority
@@ -178,12 +183,12 @@ type llamaHyperparameters struct {
 }

 type Running struct {
 	Port   int
 	Cmd    *exec.Cmd
 	Cancel context.CancelFunc
 	exitOnce sync.Once
 	exitCh   chan error // channel to receive the exit status of the subprocess
-	exitErr  error      // error returned by the subprocess
+	*StatusWriter // captures error messages from the llama runner process
 }

 type llama struct {
@@ -193,7 +198,7 @@ type llama struct {

 var errNoGPU = errors.New("nvidia-smi command failed")

-// CheckVRAM returns the available VRAM in MiB on Linux machines with NVIDIA GPUs
+// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
 func CheckVRAM() (int64, error) {
 	cmd := exec.Command("nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits")
 	var stdout bytes.Buffer
@@ -203,19 +208,29 @@ func CheckVRAM() (int64, error) {
 		return 0, errNoGPU
 	}

-	var free int64
+	var freeMiB int64
 	scanner := bufio.NewScanner(&stdout)
 	for scanner.Scan() {
 		line := scanner.Text()
+		if strings.Contains(line, "[Insufficient Permissions]") {
+			return 0, fmt.Errorf("GPU support may not enabled, check you have installed GPU drivers and have the necessary permissions to run nvidia-smi")
+		}
+
 		vram, err := strconv.ParseInt(strings.TrimSpace(line), 10, 64)
 		if err != nil {
 			return 0, fmt.Errorf("failed to parse available VRAM: %v", err)
 		}

-		free += vram
+		freeMiB += vram
 	}

-	return free, nil
+	freeBytes := freeMiB * 1024 * 1024
+	if freeBytes < 2*format.GigaByte {
+		log.Printf("less than 2 GB VRAM available, falling back to CPU only")
+		freeMiB = 0
+	}
+
+	return freeBytes, nil
 }

 func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
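CheckVRAM now sums nvidia-smi's per-GPU memory.free readings (reported in MiB), converts to bytes, and treats anything under 2 GB as not worth offloading. A standalone sketch of that parsing, assuming the same csv,noheader,nounits output and decimal GB units as in the format package referenced above; where the shipped code logs and falls through, this simplified version just returns 0:

package main

import (
	"bufio"
	"fmt"
	"strconv"
	"strings"
)

const gigaByte = 1000 * 1000 * 1000 // decimal GB, matching ollama's format package

// parseFreeVRAM mirrors the new CheckVRAM arithmetic on captured
// `nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits` output.
func parseFreeVRAM(out string) (int64, error) {
	var freeMiB int64
	scanner := bufio.NewScanner(strings.NewReader(out))
	for scanner.Scan() {
		mib, err := strconv.ParseInt(strings.TrimSpace(scanner.Text()), 10, 64)
		if err != nil {
			return 0, fmt.Errorf("failed to parse available VRAM: %v", err)
		}
		freeMiB += mib // one line per GPU; sum them
	}

	freeBytes := freeMiB * 1024 * 1024 // MiB -> bytes
	if freeBytes < 2*gigaByte {
		// too little to be worth offloading; caller falls back to CPU
		return 0, nil
	}
	return freeBytes, nil
}

func main() {
	free, err := parseFreeVRAM("8192\n4096\n")
	fmt.Println(free, err) // 12884901888 <nil>
}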
@@ -223,7 +238,7 @@ func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
 		return opts.NumGPU
 	}
 	if runtime.GOOS == "linux" {
-		vramMib, err := CheckVRAM()
+		freeBytes, err := CheckVRAM()
 		if err != nil {
 			if err.Error() != "nvidia-smi command failed" {
 				log.Print(err.Error())
@@ -232,15 +247,16 @@ func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
 			return 0
 		}

-		freeVramBytes := int64(vramMib) * 1024 * 1024 // 1 MiB = 1024^2 bytes
-
-		// Calculate bytes per layer
-		// TODO: this is a rough heuristic, better would be to calculate this based on number of layers and context size
+		/*
+			Calculate bytes per layer, this will roughly be the size of the model file divided by the number of layers.
+			We can store the model weights and the kv cache in vram,
+			to enable kv chache vram storage add two additional layers to the number of layers retrieved from the model file.
+		*/
 		bytesPerLayer := fileSizeBytes / numLayer

-		// max number of layers we can fit in VRAM, subtract 5% to prevent consuming all available VRAM and running out of memory
-		layers := int(freeVramBytes/bytesPerLayer) * 95 / 100
-		log.Printf("%d MiB VRAM available, loading up to %d GPU layers", vramMib, layers)
+		// 75% of the absolute max number of layers we can fit in available VRAM, off-loading too many layers to the GPU can cause OOM errors
+		layers := int(freeBytes/bytesPerLayer) * 3 / 4
+		log.Printf("%d MB VRAM available, loading up to %d GPU layers", freeBytes/(1024*1024), layers)

 		return layers
 	}
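The offload heuristic drops from 95% to 75% of the theoretical maximum layer count, trading some speed for OOM resistance. The arithmetic in isolation, with toy numbers that are not from the commit:

package main

import "fmt"

// layersToOffload mirrors the revised NumGPU arithmetic: approximate the
// per-layer cost from the model file size, then offload only 75% of the
// theoretical maximum to leave headroom for the KV cache and scratch buffers.
func layersToOffload(freeBytes, fileSizeBytes, numLayer int64) int {
	bytesPerLayer := fileSizeBytes / numLayer
	return int(freeBytes/bytesPerLayer) * 3 / 4
}

func main() {
	var (
		freeBytes int64 = 8 << 30                  // 8 GiB free VRAM
		fileSize  int64 = 7 * 1000 * 1000 * 1000   // ~7 GB model file
		layers    int64 = 32
	)
	// bytesPerLayer ~= 219 MB; 8 GiB fits ~39 layers; 75% of that is 29
	fmt.Println(layersToOffload(freeBytes, fileSize, layers)) // 29
}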
@@ -250,7 +266,8 @@ func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {

 // StatusWriter is a writer that captures error messages from the llama runner process
 type StatusWriter struct {
 	ErrCh chan error
+	LastErrMsg string
 }

 func NewStatusWriter() *StatusWriter {
@@ -260,10 +277,18 @@ func NewStatusWriter() *StatusWriter {
 }

 func (w *StatusWriter) Write(b []byte) (int, error) {
+	var errMsg string
 	if _, after, ok := bytes.Cut(b, []byte("error:")); ok {
-		err := fmt.Errorf("llama runner: %s", after)
-		w.ErrCh <- err
+		errMsg = string(bytes.TrimSpace(after))
+	} else if _, after, ok := bytes.Cut(b, []byte("CUDA error")); ok {
+		errMsg = string(bytes.TrimSpace(after))
 	}

+	if errMsg != "" {
+		w.LastErrMsg = errMsg
+		w.ErrCh <- fmt.Errorf("llama runner: %s", errMsg)
+	}
+
 	return os.Stderr.Write(b)
 }
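StatusWriter is what makes the better exit messages below possible: an io.Writer that inspects each chunk for error markers, remembers the last one, and still forwards everything to stderr. A reduced sketch with the channel plumbing omitted:

package main

import (
	"bytes"
	"fmt"
	"os"
)

// statusWriter scans everything the subprocess writes, keeping the most
// recent error line while still forwarding all output to stderr.
type statusWriter struct {
	LastErrMsg string
}

func (w *statusWriter) Write(b []byte) (int, error) {
	var errMsg string
	if _, after, ok := bytes.Cut(b, []byte("error:")); ok {
		errMsg = string(bytes.TrimSpace(after))
	} else if _, after, ok := bytes.Cut(b, []byte("CUDA error")); ok {
		errMsg = string(bytes.TrimSpace(after))
	}
	if errMsg != "" {
		w.LastErrMsg = errMsg // consulted later, when the process exits
	}
	return os.Stderr.Write(b)
}

func main() {
	w := &statusWriter{}
	fmt.Fprintln(w, "loading model...")
	fmt.Fprintln(w, "error: out of memory")
	fmt.Println("captured:", w.LastErrMsg) // captured: out of memory
}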
@@ -277,16 +302,23 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
 		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
 	}

+	numGPU := NumGPU(numLayers, fileInfo.Size(), opts)
 	params := []string{
 		"--model", model,
 		"--ctx-size", fmt.Sprintf("%d", opts.NumCtx),
-		"--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase),
-		"--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale),
 		"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
-		"--n-gpu-layers", fmt.Sprintf("%d", NumGPU(numLayers, fileInfo.Size(), opts)),
+		"--n-gpu-layers", fmt.Sprintf("%d", numGPU),
 		"--embedding",
 	}

+	if opts.RopeFrequencyBase > 0 {
+		params = append(params, "--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase))
+	}
+
+	if opts.RopeFrequencyScale > 0 {
+		params = append(params, "--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale))
+	}
+
 	if opts.NumGQA > 0 {
 		params = append(params, "--gqa", fmt.Sprintf("%d", opts.NumGQA))
 	}
@@ -317,6 +349,11 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
 	// start the llama.cpp server with a retry in case the port is already in use
 	for _, runner := range runners {
+		if runner.Accelerated && numGPU == 0 {
+			log.Printf("skipping accelerated runner because num_gpu=0")
+			continue
+		}
+
 		if _, err := os.Stat(runner.Path); err != nil {
 			log.Printf("llama runner not found: %v", err)
 			continue
@@ -329,7 +366,15 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
 			runner.Path,
 			append(params, "--port", strconv.Itoa(port))...,
 		)
-		cmd.Env = append(os.Environ(), fmt.Sprintf("LD_LIBRARY_PATH=%s", filepath.Dir(runner.Path)))
+
+		var libraryPaths []string
+		if libraryPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
+			libraryPaths = append(libraryPaths, libraryPath)
+		}
+
+		libraryPaths = append(libraryPaths, filepath.Dir(runner.Path))
+
+		cmd.Env = append(os.Environ(), fmt.Sprintf("LD_LIBRARY_PATH=%s", strings.Join(libraryPaths, ":")))
 		cmd.Stdout = os.Stderr
 		statusWriter := NewStatusWriter()
 		cmd.Stderr = statusWriter
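Rather than overwriting LD_LIBRARY_PATH with the runner's directory, the new code preserves any existing value and appends the runner directory. The environment handling on its own, with an illustrative path:

package main

import (
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
)

func main() {
	runnerPath := "/tmp/ollama/runners/cuda/ollama-runner" // illustrative path

	// keep whatever the user already had on LD_LIBRARY_PATH, then add the
	// directory holding the runner's bundled libraries
	var libraryPaths []string
	if existing, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
		libraryPaths = append(libraryPaths, existing)
	}
	libraryPaths = append(libraryPaths, filepath.Dir(runnerPath))

	cmd := exec.Command(runnerPath)
	cmd.Env = append(os.Environ(), fmt.Sprintf("LD_LIBRARY_PATH=%s", strings.Join(libraryPaths, ":")))
	fmt.Println(cmd.Env[len(cmd.Env)-1])
}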
@@ -345,7 +390,13 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
 		// monitor the llama runner process and signal when it exits
 		go func() {
 			err := llm.Cmd.Wait()
-			llm.exitErr = err
+			// default to printing the exit message of the command process, it will probably just say 'exit staus 1'
+			errMsg := err.Error()
+			// try to set a better error message if llama runner logs captured an error
+			if statusWriter.LastErrMsg != "" {
+				errMsg = statusWriter.LastErrMsg
+			}
+			log.Println(errMsg)
 			// llm.Cmd.Wait() can only be called once, use this exit channel to signal that the process has exited
 			llm.exitOnce.Do(func() {
 				close(llm.exitCh)
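The monitor goroutine prefers the StatusWriter capture over cmd.Wait's generic message, and still signals exit exactly once through a once-guarded channel close, since Wait may only be called one time. The concurrency pattern in isolation (the demo assumes a POSIX false binary on PATH):

package main

import (
	"fmt"
	"os/exec"
	"sync"
)

type running struct {
	cmd      *exec.Cmd
	exitOnce sync.Once
	exitCh   chan struct{}
}

func main() {
	r := &running{cmd: exec.Command("false"), exitCh: make(chan struct{})}
	if err := r.cmd.Start(); err != nil {
		panic(err)
	}

	go func() {
		// Wait may only be called once; broadcast the result by closing a channel
		err := r.cmd.Wait()
		errMsg := "exited cleanly"
		if err != nil {
			errMsg = err.Error() // e.g. "exit status 1"
		}
		fmt.Println("runner:", errMsg)
		r.exitOnce.Do(func() { close(r.exitCh) })
	}()

	<-r.exitCh // any number of readers can block here safely
}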
@@ -415,10 +466,9 @@ func (llm *llama) Close() {
 	// wait for the command to exit to prevent race conditions with the next run
 	<-llm.exitCh
-	err := llm.exitErr

-	if err != nil {
-		log.Printf("llama runner stopped with error: %v", err)
+	if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
+		log.Printf("llama runner stopped with error: %v", llm.StatusWriter.LastErrMsg)
 	} else {
 		log.Print("llama runner stopped successfully")
 	}
@@ -428,71 +478,21 @@ func (llm *llama) SetOptions(opts api.Options) {
 	llm.Options = opts
 }

-type GenerationSettings struct {
-	FrequencyPenalty float64       `json:"frequency_penalty"`
-	IgnoreEOS        bool          `json:"ignore_eos"`
-	LogitBias        []interface{} `json:"logit_bias"`
-	Mirostat         int           `json:"mirostat"`
-	MirostatEta      float64       `json:"mirostat_eta"`
-	MirostatTau      float64       `json:"mirostat_tau"`
-	Model            string        `json:"model"`
-	NCtx             int           `json:"n_ctx"`
-	NKeep            int           `json:"n_keep"`
-	NPredict         int           `json:"n_predict"`
-	NProbs           int           `json:"n_probs"`
-	PenalizeNl       bool          `json:"penalize_nl"`
-	PresencePenalty  float64       `json:"presence_penalty"`
-	RepeatLastN      int           `json:"repeat_last_n"`
-	RepeatPenalty    float64       `json:"repeat_penalty"`
-	Seed             uint32        `json:"seed"`
-	Stop             []string      `json:"stop"`
-	Stream           bool          `json:"stream"`
-	Temp             float64       `json:"temp"`
-	TfsZ             float64       `json:"tfs_z"`
-	TopK             int           `json:"top_k"`
-	TopP             float64       `json:"top_p"`
-	TypicalP         float64       `json:"typical_p"`
-}
-
-type Timings struct {
-	PredictedN  int     `json:"predicted_n"`
-	PredictedMS float64 `json:"predicted_ms"`
-	PromptN     int     `json:"prompt_n"`
-	PromptMS    float64 `json:"prompt_ms"`
-}
-
-type Prediction struct {
+type prediction struct {
 	Content string `json:"content"`
 	Model   string `json:"model"`
 	Prompt  string `json:"prompt"`
 	Stop    bool   `json:"stop"`

-	Timings `json:"timings"`
+	Timings struct {
+		PredictedN  int     `json:"predicted_n"`
+		PredictedMS float64 `json:"predicted_ms"`
+		PromptN     int     `json:"prompt_n"`
+		PromptMS    float64 `json:"prompt_ms"`
+	}
 }

-type PredictRequest struct {
-	Prompt           string   `json:"prompt"`
-	Stream           bool     `json:"stream"`
-	NPredict         int      `json:"n_predict"`
-	NKeep            int      `json:"n_keep"`
-	Temperature      float32  `json:"temperature"`
-	TopK             int      `json:"top_k"`
-	TopP             float32  `json:"top_p"`
-	TfsZ             float32  `json:"tfs_z"`
-	TypicalP         float32  `json:"typical_p"`
-	RepeatLastN      int      `json:"repeat_last_n"`
-	RepeatPenalty    float32  `json:"repeat_penalty"`
-	PresencePenalty  float32  `json:"presence_penalty"`
-	FrequencyPenalty float32  `json:"frequency_penalty"`
-	Mirostat         int      `json:"mirostat"`
-	MirostatTau      float32  `json:"mirostat_tau"`
-	MirostatEta      float32  `json:"mirostat_eta"`
-	PenalizeNl       bool     `json:"penalize_nl"`
-	Seed             int      `json:"seed"`
-	Stop             []string `json:"stop,omitempty"`
-}
-
-const maxBufferSize = 512 * 1000 // 512KB
+const maxBufferSize = 512 * format.KiloByte

 func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
 	prevConvo, err := llm.Decode(ctx, prevContext)
@@ -500,39 +500,46 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 		return err
 	}

+	// Remove leading spaces from prevConvo if present
+	prevConvo = strings.TrimPrefix(prevConvo, " ")
+
 	var nextContext strings.Builder
 	nextContext.WriteString(prevConvo)
 	nextContext.WriteString(prompt)

+	request := map[string]any{
+		"prompt":            nextContext.String(),
+		"stream":            true,
+		"n_predict":         llm.NumPredict,
+		"n_keep":            llm.NumKeep,
+		"temperature":       llm.Temperature,
+		"top_k":             llm.TopK,
+		"top_p":             llm.TopP,
+		"tfs_z":             llm.TFSZ,
+		"typical_p":         llm.TypicalP,
+		"repeat_last_n":     llm.RepeatLastN,
+		"repeat_penalty":    llm.RepeatPenalty,
+		"presence_penalty":  llm.PresencePenalty,
+		"frequency_penalty": llm.FrequencyPenalty,
+		"mirostat":          llm.Mirostat,
+		"mirostat_tau":      llm.MirostatTau,
+		"mirostat_eta":      llm.MirostatEta,
+		"penalize_nl":       llm.PenalizeNewline,
+		"seed":              llm.Seed,
+		"stop":              llm.Stop,
+	}
+
+	// Handling JSON marshaling with special characters unescaped.
+	buffer := &bytes.Buffer{}
+	enc := json.NewEncoder(buffer)
+	enc.SetEscapeHTML(false)
+
+	if err := enc.Encode(request); err != nil {
+		return fmt.Errorf("failed to marshal data: %v", err)
+	}
+
 	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
-	predReq := PredictRequest{
-		Prompt:           nextContext.String(),
-		Stream:           true,
-		NPredict:         llm.NumPredict,
-		NKeep:            llm.NumKeep,
-		Temperature:      llm.Temperature,
-		TopK:             llm.TopK,
-		TopP:             llm.TopP,
-		TfsZ:             llm.TFSZ,
-		TypicalP:         llm.TypicalP,
-		RepeatLastN:      llm.RepeatLastN,
-		RepeatPenalty:    llm.RepeatPenalty,
-		PresencePenalty:  llm.PresencePenalty,
-		FrequencyPenalty: llm.FrequencyPenalty,
-		Mirostat:         llm.Mirostat,
-		MirostatTau:      llm.MirostatTau,
-		MirostatEta:      llm.MirostatEta,
-		PenalizeNl:       llm.PenalizeNewline,
-		Seed:             llm.Seed,
-		Stop:             llm.Stop,
-	}
-
-	data, err := json.Marshal(predReq)
-	if err != nil {
-		return fmt.Errorf("error marshaling data: %v", err)
-	}
-
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
 	if err != nil {
 		return fmt.Errorf("error creating POST request: %v", err)
 	}
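The typed PredictRequest gives way to a map encoded with SetEscapeHTML(false), because encoding/json escapes <, >, and & by default, which corrupts prompts containing those characters. A quick demonstration of the difference:

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

func main() {
	request := map[string]any{"prompt": "a < b && c > d"}

	// default marshaling escapes HTML-significant characters
	data, _ := json.Marshal(request)
	fmt.Println(string(data)) // {"prompt":"a \u003c b \u0026\u0026 c \u003e d"}

	// an encoder with HTML escaping off keeps the prompt byte-for-byte
	var buf bytes.Buffer
	enc := json.NewEncoder(&buf)
	enc.SetEscapeHTML(false)
	if err := enc.Encode(request); err != nil {
		panic(err)
	}
	fmt.Print(buf.String()) // {"prompt":"a < b && c > d"}
}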
@@ -563,16 +570,14 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 			// This handles the request cancellation
 			return ctx.Err()
 		default:
-			line := scanner.Text()
-			if line == "" {
+			line := scanner.Bytes()
+			if len(line) == 0 {
 				continue
 			}

-			// Read data from the server-side event stream
-			if strings.HasPrefix(line, "data: ") {
-				evt := line[6:]
-				var p Prediction
-				if err := json.Unmarshal([]byte(evt), &p); err != nil {
+			if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
+				var p prediction
+				if err := json.Unmarshal(evt, &p); err != nil {
 					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
 				}
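Switching from scanner.Text plus manual slicing to scanner.Bytes with bytes.CutPrefix (Go 1.20+) avoids a string copy per event and folds the prefix test and strip into one call. The stream-parsing loop in isolation, fed a made-up two-event stream:

package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"strings"
)

type prediction struct {
	Content string `json:"content"`
	Stop    bool   `json:"stop"`
}

func main() {
	stream := "data: {\"content\":\"Hel\",\"stop\":false}\n\ndata: {\"content\":\"lo\",\"stop\":true}\n"

	scanner := bufio.NewScanner(strings.NewReader(stream))
	for scanner.Scan() {
		line := scanner.Bytes() // no per-line string allocation
		if len(line) == 0 {
			continue
		}
		// CutPrefix both tests for and strips the SSE "data: " marker
		if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
			var p prediction
			if err := json.Unmarshal(evt, &p); err != nil {
				panic(err)
			}
			fmt.Print(p.Content)
			if p.Stop {
				fmt.Println()
			}
		}
	}
}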
@@ -590,10 +595,10 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 			fn(api.GenerateResponse{
 				Done:               true,
 				Context:            embd,
-				PromptEvalCount:    p.PromptN,
-				PromptEvalDuration: parseDurationMs(p.PromptMS),
-				EvalCount:          p.PredictedN,
-				EvalDuration:       parseDurationMs(p.PredictedMS),
+				PromptEvalCount:    p.Timings.PromptN,
+				PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
+				EvalCount:          p.Timings.PredictedN,
+				EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
 			})

 			return nil
@@ -603,6 +608,14 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 	}

 	if err := scanner.Err(); err != nil {
+		if strings.Contains(err.Error(), "unexpected EOF") {
+			// this means the llama runner subprocess crashed
+			llm.Close()
+			if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
+				return fmt.Errorf("llama runner exited: %v", llm.StatusWriter.LastErrMsg)
+			}
+			return fmt.Errorf("llama runner exited, you may not have enough available memory to run this model")
+		}
 		return fmt.Errorf("error reading llm response: %v", err)
 	}
@@ -699,9 +712,6 @@ func (llm *llama) Decode(ctx context.Context, tokens []int) (string, error) {
 		return "", fmt.Errorf("unmarshal encode response: %w", err)
 	}

-	// decoded content contains a leading whitespace
-	decoded.Content, _ = strings.CutPrefix(decoded.Content, " ")
-
 	return decoded.Content, nil
 }
59 llm/llm.go
@@ -10,6 +10,7 @@ import (
 	"github.com/pbnjay/memory"

 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/format"
 )

 type LLM interface {
@@ -55,45 +56,39 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error
 			opts.NumGPU = 0
 		}
 	}

-	totalResidentMemory := memory.TotalMemory()
-	switch ggml.ModelType() {
-	case "3B", "7B":
-		if ggml.FileType() == "F16" && totalResidentMemory < 16*1000*1000 {
-			return nil, fmt.Errorf("F16 model requires at least 16 GB of memory")
-		} else if totalResidentMemory < 8*1000*1000 {
-			return nil, fmt.Errorf("model requires at least 8 GB of memory")
-		}
-	case "13B":
-		if ggml.FileType() == "F16" && totalResidentMemory < 32*1000*1000 {
-			return nil, fmt.Errorf("F16 model requires at least 32 GB of memory")
-		} else if totalResidentMemory < 16*1000*1000 {
-			return nil, fmt.Errorf("model requires at least 16 GB of memory")
-		}
-	case "30B", "34B", "40B":
-		if ggml.FileType() == "F16" && totalResidentMemory < 64*1000*1000 {
-			return nil, fmt.Errorf("F16 model requires at least 64 GB of memory")
-		} else if totalResidentMemory < 32*1000*1000 {
-			return nil, fmt.Errorf("model requires at least 32 GB of memory")
-		}
-	case "65B", "70B":
-		if ggml.FileType() == "F16" && totalResidentMemory < 128*1000*1000 {
-			return nil, fmt.Errorf("F16 model requires at least 128 GB of memory")
-		} else if totalResidentMemory < 64*1000*1000 {
-			return nil, fmt.Errorf("model requires at least 64 GB of memory")
-		}
-	case "180B":
-		if ggml.FileType() == "F16" && totalResidentMemory < 512*1000*1000 {
-			return nil, fmt.Errorf("F16 model requires at least 512GB of memory")
-		} else if totalResidentMemory < 128*1000*1000 {
-			return nil, fmt.Errorf("model requires at least 128GB of memory")
-		}
+	var requiredMemory int64
+	var f16Multiplier int64 = 2
+
+	switch ggml.ModelType() {
+	case "3B", "7B":
+		requiredMemory = 8 * format.GigaByte
+	case "13B":
+		requiredMemory = 16 * format.GigaByte
+	case "30B", "34B", "40B":
+		requiredMemory = 32 * format.GigaByte
+	case "65B", "70B":
+		requiredMemory = 64 * format.GigaByte
+	case "180B":
+		requiredMemory = 128 * format.GigaByte
+		f16Multiplier = 4
 	}

+	systemMemory := int64(memory.TotalMemory())
+
+	if ggml.FileType() == "F16" && requiredMemory*f16Multiplier > systemMemory {
+		return nil, fmt.Errorf("F16 model requires at least %s of total memory", format.HumanBytes(requiredMemory))
+	} else if requiredMemory > systemMemory {
+		return nil, fmt.Errorf("model requires at least %s of total memory", format.HumanBytes(requiredMemory))
+	}
+
 	switch ggml.Name() {
 	case "gguf":
-		opts.NumGQA = 0 // TODO: remove this when llama.cpp runners differ enough to need separate newLlama functions
+		// TODO: gguf will load these options automatically from the model binary
+		opts.NumGQA = 0
+		opts.RopeFrequencyBase = 0.0
+		opts.RopeFrequencyScale = 0.0
 		return newLlama(model, adapters, chooseRunners(workDir, "gguf"), ggml.NumLayers(), opts)
 	case "ggml", "ggmf", "ggjt", "ggla":
 		return newLlama(model, adapters, chooseRunners(workDir, "ggml"), ggml.NumLayers(), opts)
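Five near-identical if/else ladders collapse into a single table of base requirements, with F16 applying a multiplier (4x for 180B, 2x otherwise). A compact rendering of the same logic; the gigaByte constant stands in for the format package value:

package main

import "fmt"

const gigaByte int64 = 1000 * 1000 * 1000

// requiredMemory folds the per-size checks into a lookup: a base requirement
// keyed by parameter count, scaled up for F16 weights.
func requiredMemory(modelType string, f16 bool) int64 {
	base := map[string]int64{
		"3B": 8, "7B": 8,
		"13B": 16,
		"30B": 32, "34B": 32, "40B": 32,
		"65B": 64, "70B": 64,
		"180B": 128,
	}[modelType] * gigaByte

	if f16 {
		multiplier := int64(2)
		if modelType == "180B" {
			multiplier = 4
		}
		return base * multiplier
	}
	return base
}

func main() {
	fmt.Println(requiredMemory("7B", false))  // 8000000000
	fmt.Println(requiredMemory("180B", true)) // 512000000000
}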
@@ -40,7 +40,7 @@ func Parse(reader io.Reader) ([]Command, error) {
 		command.Args = string(fields[1])
 		// copy command for validation
 		modelCommand = command
-	case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "EMBED", "ADAPTER":
+	case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "ADAPTER":
 		command.Name = string(bytes.ToLower(fields[0]))
 		command.Args = string(fields[1])
 	case "PARAMETER":
@@ -51,6 +51,8 @@ func Parse(reader io.Reader) ([]Command, error) {

 		command.Name = string(fields[0])
 		command.Args = string(fields[1])
+	case "EMBED":
+		return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
 	default:
 		if !bytes.HasPrefix(fields[0], []byte("#")) {
 			// log a warning for unknown commands
372 readline/buffer.go Normal file
@@ -0,0 +1,372 @@
package readline

import (
	"fmt"
	"os"

	"github.com/emirpasic/gods/lists/arraylist"
	"golang.org/x/term"
)

type Buffer struct {
	Pos       int
	Buf       *arraylist.List
	Prompt    *Prompt
	LineWidth int
	Width     int
	Height    int
}

func NewBuffer(prompt *Prompt) (*Buffer, error) {
	fd := int(os.Stdout.Fd())
	width, height, err := term.GetSize(fd)
	if err != nil {
		fmt.Println("Error getting size:", err)
		return nil, err
	}

	lwidth := width - len(prompt.Prompt)
	if prompt.UseAlt {
		lwidth = width - len(prompt.AltPrompt)
	}

	b := &Buffer{
		Pos:       0,
		Buf:       arraylist.New(),
		Prompt:    prompt,
		Width:     width,
		Height:    height,
		LineWidth: lwidth,
	}

	return b, nil
}

func (b *Buffer) MoveLeft() {
	if b.Pos > 0 {
		if b.Pos%b.LineWidth == 0 {
			fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
		} else {
			fmt.Print(CursorLeft)
		}
		b.Pos -= 1
	}
}

func (b *Buffer) MoveLeftWord() {
	if b.Pos > 0 {
		var foundNonspace bool
		for {
			v, _ := b.Buf.Get(b.Pos - 1)
			if v == ' ' {
				if foundNonspace {
					break
				}
			} else {
				foundNonspace = true
			}
			b.MoveLeft()

			if b.Pos == 0 {
				break
			}
		}
	}
}

func (b *Buffer) MoveRight() {
	if b.Pos < b.Size() {
		b.Pos += 1
		if b.Pos%b.LineWidth == 0 {
			fmt.Printf(CursorDown + CursorBOL + cursorRightN(b.PromptSize()))
		} else {
			fmt.Print(CursorRight)
		}
	}
}

func (b *Buffer) MoveRightWord() {
	if b.Pos < b.Size() {
		for {
			b.MoveRight()
			v, _ := b.Buf.Get(b.Pos)
			if v == ' ' {
				break
			}

			if b.Pos == b.Size() {
				break
			}
		}
	}
}

func (b *Buffer) MoveToStart() {
	if b.Pos > 0 {
		currLine := b.Pos / b.LineWidth
		if currLine > 0 {
			for cnt := 0; cnt < currLine; cnt++ {
				fmt.Print(CursorUp)
			}
		}
		fmt.Printf(CursorBOL + cursorRightN(b.PromptSize()))
		b.Pos = 0
	}
}

func (b *Buffer) MoveToEnd() {
	if b.Pos < b.Size() {
		currLine := b.Pos / b.LineWidth
		totalLines := b.Size() / b.LineWidth
		if currLine < totalLines {
			for cnt := 0; cnt < totalLines-currLine; cnt++ {
				fmt.Print(CursorDown)
			}
			remainder := b.Size() % b.LineWidth
			fmt.Printf(CursorBOL + cursorRightN(b.PromptSize()+remainder))
		} else {
			fmt.Print(cursorRightN(b.Size() - b.Pos))
		}

		b.Pos = b.Size()
	}
}

func (b *Buffer) Size() int {
	return b.Buf.Size()
}

func min(n, m int) int {
	if n > m {
		return m
	}
	return n
}

func (b *Buffer) PromptSize() int {
	if b.Prompt.UseAlt {
		return len(b.Prompt.AltPrompt)
	}
	return len(b.Prompt.Prompt)
}

func (b *Buffer) Add(r rune) {
	if b.Pos == b.Buf.Size() {
		fmt.Printf("%c", r)
		b.Buf.Add(r)
		b.Pos += 1
		if b.Pos > 0 && b.Pos%b.LineWidth == 0 {
			fmt.Printf("\n%s", b.Prompt.AltPrompt)
		}
	} else {
		fmt.Printf("%c", r)
		b.Buf.Insert(b.Pos, r)
		b.Pos += 1
		if b.Pos > 0 && b.Pos%b.LineWidth == 0 {
			fmt.Printf("\n%s", b.Prompt.AltPrompt)
		}
		b.drawRemaining()
	}
}

func (b *Buffer) drawRemaining() {
	var place int
	remainingText := b.StringN(b.Pos)
	if b.Pos > 0 {
		place = b.Pos % b.LineWidth
	}
	fmt.Print(CursorHide)

	// render the rest of the current line
	currLine := remainingText[:min(b.LineWidth-place, len(remainingText))]
	if len(currLine) > 0 {
		fmt.Printf(ClearToEOL + currLine)
		fmt.Print(cursorLeftN(len(currLine)))
	} else {
		fmt.Print(ClearToEOL)
	}

	// render the other lines
	if len(remainingText) > len(currLine) {
		remaining := []rune(remainingText[len(currLine):])
		var totalLines int
		for i, c := range remaining {
			if i%b.LineWidth == 0 {
				fmt.Printf("\n%s", b.Prompt.AltPrompt)
				totalLines += 1
			}
			fmt.Printf("%c", c)
		}
		fmt.Print(ClearToEOL)
		fmt.Print(cursorUpN(totalLines))
		fmt.Printf(CursorBOL + cursorRightN(b.Width-len(currLine)))
	}

	fmt.Print(CursorShow)
}

func (b *Buffer) Remove() {
	if b.Buf.Size() > 0 && b.Pos > 0 {
		if b.Pos%b.LineWidth == 0 {
			// if the user backspaces over the word boundary, do this magic to clear the line
			// and move to the end of the previous line
			fmt.Printf(CursorBOL + ClearToEOL)
			fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width) + " " + CursorLeft)
		} else {
			fmt.Printf(CursorLeft + " " + CursorLeft)
		}

		var eraseExtraLine bool
		if (b.Size()-1)%b.LineWidth == 0 {
			eraseExtraLine = true
		}

		b.Pos -= 1
		b.Buf.Remove(b.Pos)

		if b.Pos < b.Size() {
			b.drawRemaining()
			// this erases a line which is left over when backspacing in the middle of a line and there
			// are trailing characters which go over the line width boundary
			if eraseExtraLine {
				remainingLines := (b.Size() - b.Pos) / b.LineWidth
				fmt.Printf(cursorDownN(remainingLines+1) + CursorBOL + ClearToEOL)
				place := b.Pos % b.LineWidth
				fmt.Printf(cursorUpN(remainingLines+1) + cursorRightN(place+len(b.Prompt.Prompt)))
			}
		}
	}
}

func (b *Buffer) Delete() {
	if b.Size() > 0 && b.Pos < b.Size() {
		b.Buf.Remove(b.Pos)
		b.drawRemaining()
		if b.Size()%b.LineWidth == 0 {
			if b.Pos != b.Size() {
				remainingLines := (b.Size() - b.Pos) / b.LineWidth
				fmt.Printf(cursorDownN(remainingLines) + CursorBOL + ClearToEOL)
				place := b.Pos % b.LineWidth
				fmt.Printf(cursorUpN(remainingLines) + cursorRightN(place+len(b.Prompt.Prompt)))
			}
		}
	}
}

func (b *Buffer) DeleteBefore() {
	if b.Pos > 0 {
		for cnt := b.Pos - 1; cnt >= 0; cnt-- {
			b.Remove()
		}
	}
}

func (b *Buffer) DeleteRemaining() {
	if b.Size() > 0 && b.Pos < b.Size() {
		charsToDel := b.Size() - b.Pos
		for cnt := 0; cnt < charsToDel; cnt++ {
			b.Delete()
		}
	}
}

func (b *Buffer) DeleteWord() {
	if b.Buf.Size() > 0 && b.Pos > 0 {
		var foundNonspace bool
		for {
			v, _ := b.Buf.Get(b.Pos - 1)
			if v == ' ' {
				if !foundNonspace {
					b.Remove()
				} else {
					break
				}
			} else {
				foundNonspace = true
				b.Remove()
			}

			if b.Pos == 0 {
				break
			}
		}
	}
}

func (b *Buffer) ClearScreen() {
	fmt.Printf(ClearScreen + CursorReset + b.Prompt.Prompt)
	if b.IsEmpty() {
		ph := b.Prompt.Placeholder
		fmt.Printf(ColorGrey + ph + cursorLeftN(len(ph)) + ColorDefault)
	} else {
		currPos := b.Pos
		b.Pos = 0
		b.drawRemaining()
		fmt.Printf(CursorReset + cursorRightN(len(b.Prompt.Prompt)))
		if currPos > 0 {
			targetLine := currPos / b.LineWidth
			if targetLine > 0 {
				for cnt := 0; cnt < targetLine; cnt++ {
					fmt.Print(CursorDown)
				}
			}
			remainder := currPos % b.LineWidth
			if remainder > 0 {
				fmt.Print(cursorRightN(remainder))
			}
			if currPos%b.LineWidth == 0 {
				fmt.Printf(CursorBOL + b.Prompt.AltPrompt)
			}
		}
		b.Pos = currPos
	}
}

func (b *Buffer) IsEmpty() bool {
	return b.Buf.Empty()
}

func (b *Buffer) Replace(r []rune) {
	b.Pos = 0
	b.Buf.Clear()
	fmt.Printf(ClearLine + CursorBOL + b.Prompt.Prompt)
	for _, c := range r {
		b.Add(c)
	}
}

func (b *Buffer) String() string {
	return b.StringN(0)
}

func (b *Buffer) StringN(n int) string {
	return b.StringNM(n, 0)
}

func (b *Buffer) StringNM(n, m int) string {
	var s string
	if m == 0 {
		m = b.Size()
	}
	for cnt := n; cnt < m; cnt++ {
		c, _ := b.Buf.Get(cnt)
		s += string(c.(rune))
	}
	return s
}

func cursorLeftN(n int) string {
	return fmt.Sprintf(CursorLeftN, n)
}

func cursorRightN(n int) string {
	return fmt.Sprintf(CursorRightN, n)
}

func cursorUpN(n int) string {
	return fmt.Sprintf(CursorUpN, n)
}

func cursorDownN(n int) string {
	return fmt.Sprintf(CursorDownN, n)
}
17 readline/errors.go Normal file
@@ -0,0 +1,17 @@
package readline

import (
	"errors"
)

var (
	ErrInterrupt = errors.New("Interrupt")
)

type InterruptError struct {
	Line []rune
}

func (*InterruptError) Error() string {
	return "Interrupted"
}
152 readline/history.go Normal file
@@ -0,0 +1,152 @@
package readline

import (
	"bufio"
	"errors"
	"io"
	"os"
	"path/filepath"
	"strings"

	"github.com/emirpasic/gods/lists/arraylist"
)

type History struct {
	Buf      *arraylist.List
	Autosave bool
	Pos      int
	Limit    int
	Filename string
	Enabled  bool
}

func NewHistory() (*History, error) {
	h := &History{
		Buf:      arraylist.New(),
		Limit:    100, //resizeme
		Autosave: true,
		Enabled:  true,
	}

	err := h.Init()
	if err != nil {
		return nil, err
	}

	return h, nil
}

func (h *History) Init() error {
	home, err := os.UserHomeDir()
	if err != nil {
		return err
	}

	path := filepath.Join(home, ".ollama", "history")
	h.Filename = path

	//todo check if the file exists
	f, err := os.OpenFile(path, os.O_CREATE|os.O_RDONLY, 0600)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return nil
		}
		return err
	}
	defer f.Close()

	r := bufio.NewReader(f)
	for {
		line, err := r.ReadString('\n')
		if err != nil {
			if err == io.EOF {
				break
			}
			return err
		}

		line = strings.TrimSpace(line)
		if len(line) == 0 {
			continue
		}

		h.Add([]rune(line))
	}

	return nil
}

func (h *History) Add(l []rune) {
	h.Buf.Add(l)
	h.Compact()
	h.Pos = h.Size()
	if h.Autosave {
		h.Save()
	}
}

func (h *History) Compact() {
	s := h.Buf.Size()
	if s > h.Limit {
		for cnt := 0; cnt < s-h.Limit; cnt++ {
			h.Buf.Remove(0)
		}
	}
}

func (h *History) Clear() {
	h.Buf.Clear()
}

func (h *History) Prev() []rune {
	var line []rune
	if h.Pos > 0 {
		h.Pos -= 1
	}
	v, _ := h.Buf.Get(h.Pos)
	line, _ = v.([]rune)
	return line
}

func (h *History) Next() []rune {
	var line []rune
	if h.Pos < h.Buf.Size() {
		h.Pos += 1
		v, _ := h.Buf.Get(h.Pos)
		line, _ = v.([]rune)
	}
	return line
}

func (h *History) Size() int {
	return h.Buf.Size()
}

func (h *History) Save() error {
	if !h.Enabled {
		return nil
	}

	tmpFile := h.Filename + ".tmp"

	f, err := os.OpenFile(tmpFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC|os.O_APPEND, 0666)
	if err != nil {
		return err
	}
	defer f.Close()

	buf := bufio.NewWriter(f)
	for cnt := 0; cnt < h.Size(); cnt++ {
		v, _ := h.Buf.Get(cnt)
		line, _ := v.([]rune)
		buf.WriteString(string(line) + "\n")
	}
	buf.Flush()
	f.Close()

	if err = os.Rename(tmpFile, h.Filename); err != nil {
		return err
	}

	return nil
}
254
readline/readline.go
Normal file
254
readline/readline.go
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
package readline

import (
    "bufio"
    "fmt"
    "io"
    "os"
    "syscall"
)

type Prompt struct {
    Prompt         string
    AltPrompt      string
    Placeholder    string
    AltPlaceholder string
    UseAlt         bool
}

type Terminal struct {
    outchan chan rune
}

type Instance struct {
    Prompt   *Prompt
    Terminal *Terminal
    History  *History
}

func New(prompt Prompt) (*Instance, error) {
    term, err := NewTerminal()
    if err != nil {
        return nil, err
    }

    history, err := NewHistory()
    if err != nil {
        return nil, err
    }

    return &Instance{
        Prompt:   &prompt,
        Terminal: term,
        History:  history,
    }, nil
}

func (i *Instance) Readline() (string, error) {
    prompt := i.Prompt.Prompt
    if i.Prompt.UseAlt {
        prompt = i.Prompt.AltPrompt
    }
    fmt.Print(prompt)

    fd := int(syscall.Stdin)
    termios, err := SetRawMode(fd)
    if err != nil {
        return "", err
    }
    defer UnsetRawMode(fd, termios)

    buf, _ := NewBuffer(i.Prompt)

    var esc bool
    var escex bool
    var metaDel bool
    var pasteMode PasteMode

    var currentLineBuf []rune

    for {
        if buf.IsEmpty() {
            ph := i.Prompt.Placeholder
            if i.Prompt.UseAlt {
                ph = i.Prompt.AltPlaceholder
            }
            fmt.Printf(ColorGrey + ph + fmt.Sprintf(CursorLeftN, len(ph)) + ColorDefault)
        }

        r, err := i.Terminal.Read()

        if buf.IsEmpty() {
            fmt.Print(ClearToEOL)
        }

        if err != nil {
            return "", io.EOF
        }

        if escex {
            escex = false

            switch r {
            case KeyUp:
                if i.History.Pos > 0 {
                    if i.History.Pos == i.History.Size() {
                        currentLineBuf = []rune(buf.String())
                    }
                    buf.Replace(i.History.Prev())
                }
            case KeyDown:
                if i.History.Pos < i.History.Size() {
                    buf.Replace(i.History.Next())
                    if i.History.Pos == i.History.Size() {
                        buf.Replace(currentLineBuf)
                    }
                }
            case KeyLeft:
                buf.MoveLeft()
            case KeyRight:
                buf.MoveRight()
            case CharBracketedPaste:
                var code string
                for cnt := 0; cnt < 3; cnt++ {
                    r, err = i.Terminal.Read()
                    if err != nil {
                        return "", io.EOF
                    }

                    code += string(r)
                }
                if code == CharBracketedPasteStart {
                    pasteMode = PasteModeStart
                } else if code == CharBracketedPasteEnd {
                    pasteMode = PasteModeEnd
                }
            case KeyDel:
                if buf.Size() > 0 {
                    buf.Delete()
                }
                metaDel = true
            case MetaStart:
                buf.MoveToStart()
            case MetaEnd:
                buf.MoveToEnd()
            default:
                // skip any keys we don't know about
                continue
            }
            continue
        } else if esc {
            esc = false

            switch r {
            case 'b':
                buf.MoveLeftWord()
            case 'f':
                buf.MoveRightWord()
            case CharEscapeEx:
                escex = true
            }
            continue
        }

        switch r {
        case CharNull:
            continue
        case CharEsc:
            esc = true
        case CharInterrupt:
            return "", ErrInterrupt
        case CharLineStart:
            buf.MoveToStart()
        case CharLineEnd:
            buf.MoveToEnd()
        case CharBackward:
            buf.MoveLeft()
        case CharForward:
            buf.MoveRight()
        case CharBackspace, CharCtrlH:
            buf.Remove()
        case CharTab:
            // todo: convert back to real tabs
            for cnt := 0; cnt < 8; cnt++ {
                buf.Add(' ')
            }
        case CharDelete:
            if buf.Size() > 0 {
                buf.Delete()
            } else {
                return "", io.EOF
            }
        case CharKill:
            buf.DeleteRemaining()
        case CharCtrlU:
            buf.DeleteBefore()
        case CharCtrlL:
            buf.ClearScreen()
        case CharCtrlW:
            buf.DeleteWord()
        case CharEnter:
            output := buf.String()
            if output != "" {
                i.History.Add([]rune(output))
            }
            buf.MoveToEnd()
            fmt.Println()
            switch pasteMode {
            case PasteModeStart:
                output = `"""` + output
            case PasteModeEnd:
                output = output + `"""`
            }
            return output, nil
        default:
            if metaDel {
                metaDel = false
                continue
            }
            if r >= CharSpace || r == CharEnter {
                buf.Add(r)
            }
        }
    }
}

func (i *Instance) HistoryEnable() {
    i.History.Enabled = true
}

func (i *Instance) HistoryDisable() {
    i.History.Enabled = false
}

func NewTerminal() (*Terminal, error) {
    t := &Terminal{
        outchan: make(chan rune),
    }

    go t.ioloop()

    return t, nil
}

func (t *Terminal) ioloop() {
    buf := bufio.NewReader(os.Stdin)

    for {
        r, _, err := buf.ReadRune()
        if err != nil {
            close(t.outchan)
            break
        }
        t.outchan <- r
    }
}

func (t *Terminal) Read() (rune, error) {
    r, ok := <-t.outchan
    if !ok {
        return 0, io.EOF
    }

    return r, nil
}
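readline.go above is the heart of the package: New wires a Prompt, a raw-mode Terminal, and a History together, and Readline runs the key-handling loop. A minimal REPL sketch of how a caller might drive it (the import path mirrors this repository; error values other than io.EOF, including ErrInterrupt from Ctrl+C, are handled the same way here for brevity):

    package main

    import (
        "errors"
        "fmt"
        "io"

        "github.com/jmorganca/ollama/readline"
    )

    func main() {
        rl, err := readline.New(readline.Prompt{
            Prompt:      ">>> ",
            AltPrompt:   "... ",
            Placeholder: "Send a message (/? for help)",
        })
        if err != nil {
            panic(err)
        }

        for {
            line, err := rl.Readline()
            if errors.Is(err, io.EOF) { // Ctrl+D on an empty line
                fmt.Println()
                return
            } else if err != nil { // includes ErrInterrupt from Ctrl+C
                return
            }
            fmt.Println("you typed:", line)
        }
    }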
readline/term.go (new file, 36 lines)
@@ -0,0 +1,36 @@
//go:build aix || darwin || dragonfly || freebsd || (linux && !appengine) || netbsd || openbsd || os400 || solaris

package readline

import (
    "syscall"
)

type Termios syscall.Termios

func SetRawMode(fd int) (*Termios, error) {
    termios, err := getTermios(fd)
    if err != nil {
        return nil, err
    }

    newTermios := *termios
    newTermios.Iflag &^= syscall.IGNBRK | syscall.BRKINT | syscall.PARMRK | syscall.ISTRIP | syscall.INLCR | syscall.IGNCR | syscall.ICRNL | syscall.IXON
    newTermios.Lflag &^= syscall.ECHO | syscall.ECHONL | syscall.ICANON | syscall.ISIG | syscall.IEXTEN
    newTermios.Cflag &^= syscall.CSIZE | syscall.PARENB
    newTermios.Cflag |= syscall.CS8
    newTermios.Cc[syscall.VMIN] = 1
    newTermios.Cc[syscall.VTIME] = 0

    return termios, setTermios(fd, &newTermios)
}

func UnsetRawMode(fd int, termios *Termios) error {
    return setTermios(fd, termios)
}

// IsTerminal returns true if the given file descriptor is a terminal.
func IsTerminal(fd int) bool {
    _, err := getTermios(fd)
    return err == nil
}
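SetRawMode hands back the original termios so callers can restore the terminal; the usual pattern is to guard with IsTerminal and defer the restore. A sketch (assuming the fd comes from os.Stdin; on Windows the saved state is a *State rather than a *Termios, but the shape is the same):

    fd := int(os.Stdin.Fd())
    if IsTerminal(fd) {
        state, err := SetRawMode(fd)
        if err != nil {
            return err
        }
        // restore cooked mode however the function exits
        defer UnsetRawMode(fd, state)
    }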
readline/term_bsd.go (new file, 25 lines)
@@ -0,0 +1,25 @@
//go:build darwin || freebsd || netbsd || openbsd

package readline

import (
    "syscall"
    "unsafe"
)

func getTermios(fd int) (*Termios, error) {
    termios := new(Termios)
    _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), syscall.TIOCGETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
    if err != 0 {
        return nil, err
    }
    return termios, nil
}

func setTermios(fd int, termios *Termios) error {
    _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), syscall.TIOCSETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
    if err != 0 {
        return err
    }
    return nil
}
readline/term_linux.go (new file, 28 lines)
@@ -0,0 +1,28 @@
//go:build linux || solaris

package readline

import (
    "syscall"
    "unsafe"
)

const tcgets = 0x5401
const tcsets = 0x5402

func getTermios(fd int) (*Termios, error) {
    termios := new(Termios)
    _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), tcgets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
    if err != 0 {
        return nil, err
    }
    return termios, nil
}

func setTermios(fd int, termios *Termios) error {
    _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), tcsets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
    if err != 0 {
        return err
    }
    return nil
}
readline/term_windows.go (new file, 62 lines)
@@ -0,0 +1,62 @@
package readline

import (
    "syscall"
    "unsafe"
)

const (
    enableLineInput       = 2
    enableWindowInput     = 8
    enableMouseInput      = 16
    enableInsertMode      = 32
    enableQuickEditMode   = 64
    enableExtendedFlags   = 128
    enableProcessedOutput = 1
    enableWrapAtEolOutput = 2
    enableAutoPosition    = 256 // Cursor position is not affected by writing data to the console.
    enableEchoInput       = 4   // Characters are written to the console as they're read.
    enableProcessedInput  = 1   // Enables input processing (like recognizing Ctrl+C).
)

var kernel32 = syscall.NewLazyDLL("kernel32.dll")

var (
    procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
    procSetConsoleMode = kernel32.NewProc("SetConsoleMode")
)

type State struct {
    mode uint32
}

// IsTerminal checks if the given file descriptor is associated with a terminal
func IsTerminal(fd int) bool {
    var st uint32
    r, _, e := syscall.SyscallN(procGetConsoleMode.Addr(), uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
    // if the call succeeds and doesn't produce an error, it's a terminal
    return r != 0 && e == 0
}

func SetRawMode(fd int) (*State, error) {
    var st uint32
    // retrieve the current mode of the terminal
    _, _, e := syscall.SyscallN(procGetConsoleMode.Addr(), uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
    if e != 0 {
        return nil, error(e)
    }
    // modify the mode to set it to raw
    raw := st &^ (enableEchoInput | enableProcessedInput | enableLineInput | enableProcessedOutput)
    // apply the new mode to the terminal
    _, _, e = syscall.SyscallN(procSetConsoleMode.Addr(), uintptr(fd), uintptr(raw), 0)
    if e != 0 {
        return nil, error(e)
    }
    // return the original state so that it can be restored later
    return &State{st}, nil
}

func UnsetRawMode(fd int, state *State) error {
    _, _, err := syscall.SyscallN(procSetConsoleMode.Addr(), uintptr(fd), uintptr(state.mode), 0)
    return err
}
readline/types.go (new file, 86 lines)
@@ -0,0 +1,86 @@
package readline

const (
    CharNull      = 0
    CharLineStart = 1
    CharBackward  = 2
    CharInterrupt = 3
    CharDelete    = 4
    CharLineEnd   = 5
    CharForward   = 6
    CharBell      = 7
    CharCtrlH     = 8
    CharTab       = 9
    CharCtrlJ     = 10
    CharKill      = 11
    CharCtrlL     = 12
    CharEnter     = 13
    CharNext      = 14
    CharPrev      = 16
    CharBckSearch = 18
    CharFwdSearch = 19
    CharTranspose = 20
    CharCtrlU     = 21
    CharCtrlW     = 23
    CharCtrlY     = 25
    CharCtrlZ     = 26
    CharEsc       = 27
    CharSpace     = 32
    CharEscapeEx  = 91
    CharBackspace = 127
)

const (
    KeyDel    = 51
    KeyUp     = 65
    KeyDown   = 66
    KeyRight  = 67
    KeyLeft   = 68
    MetaEnd   = 70
    MetaStart = 72
)

const (
    CursorUp    = "\033[1A"
    CursorDown  = "\033[1B"
    CursorRight = "\033[1C"
    CursorLeft  = "\033[1D"

    CursorSave    = "\033[s"
    CursorRestore = "\033[u"

    CursorUpN    = "\033[%dA"
    CursorDownN  = "\033[%dB"
    CursorRightN = "\033[%dC"
    CursorLeftN  = "\033[%dD"

    CursorEOL  = "\033[E"
    CursorBOL  = "\033[1G"
    CursorHide = "\033[?25l"
    CursorShow = "\033[?25h"

    ClearToEOL  = "\033[K"
    ClearLine   = "\033[2K"
    ClearScreen = "\033[2J"
    CursorReset = "\033[0;0f"

    ColorGrey    = "\033[38;5;245m"
    ColorDefault = "\033[0m"

    StartBracketedPaste = "\033[?2004h"
    EndBracketedPaste   = "\033[?2004l"
)

const (
    CharBracketedPaste      = 50
    CharBracketedPasteStart = "00~"
    CharBracketedPasteEnd   = "01~"
)

type PasteMode int

const (
    PasteModeOff = iota
    PasteModeStart
    PasteModeEnd
)
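A subtlety in these constants: CharBracketedPaste is 50, the ASCII code for '2', because bracketed-paste markers arrive as ESC [ 2 0 0 ~ (start) and ESC [ 2 0 1 ~ (end). The Readline loop matches the '2' in its escape-extension state and then reads three more runes, comparing them against CharBracketedPasteStart ("00~") and CharBracketedPasteEnd ("01~").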
@@ -26,7 +26,8 @@ require() {
 [ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'

-case "$(uname -m)" in
+ARCH=$(uname -m)
+case "$ARCH" in
 x86_64) ARCH="amd64" ;;
 aarch64|arm64) ARCH="arm64" ;;
 *) error "Unsupported architecture: $ARCH" ;;
@@ -62,7 +63,10 @@ status "Installing ollama to $BINDIR..."
 $SUDO install -o0 -g0 -m755 -d $BINDIR
 $SUDO install -o0 -g0 -m755 $TEMP_DIR/ollama $BINDIR/ollama

-install_success() { status 'Install complete. Run "ollama" from the command line.'; }
+install_success() {
+    status 'The Ollama API is now available at 0.0.0.0:11434.'
+    status 'Install complete. Run "ollama" from the command line.'
+}
 trap install_success EXIT

 # Everything from this point onwards is optional.
@@ -73,6 +77,9 @@ configure_systemd() {
     $SUDO useradd -r -s /bin/false -m -d /usr/share/ollama ollama
 fi

+status "Adding current user to ollama group..."
+$SUDO usermod -a -G ollama $(whoami)
+
 status "Creating ollama systemd service..."
 cat <<EOF | $SUDO tee /etc/systemd/system/ollama.service >/dev/null
 [Unit]
@@ -85,7 +92,6 @@ User=ollama
 Group=ollama
 Restart=always
 RestartSec=3
-Environment="HOME=/usr/share/ollama"
 Environment="PATH=$PATH"

 [Install]
@@ -127,6 +133,7 @@ if check_gpu nvidia-smi; then
 fi

 if ! check_gpu lspci && ! check_gpu lshw; then
+    install_success
     warning "No NVIDIA GPU detected. Ollama will run in CPU-only mode."
     exit 0
 fi
scripts/push_docker.sh (new executable file, 15 lines)
@@ -0,0 +1,15 @@
#!/bin/sh

set -eu

export VERSION=${VERSION:-0.0.0}
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version=$VERSION\" \"-X=github.com/jmorganca/ollama/server.mode=release\"'"

docker buildx build \
    --push \
    --platform=linux/arm64,linux/amd64 \
    --build-arg=VERSION \
    --build-arg=GOFLAGS \
    -f Dockerfile \
    -t ollama/ollama -t ollama/ollama:$VERSION \
    .
@@ -91,7 +91,7 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
 	}

 	s := SignatureData{
-		Method: "GET",
+		Method: http.MethodGet,
 		Path:   redirectURL.String(),
 		Data:   nil,
 	}
@@ -103,9 +103,10 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {

 	headers := make(http.Header)
 	headers.Set("Authorization", sig)
-	resp, err := makeRequest(ctx, "GET", redirectURL, headers, nil, nil)
+	resp, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
 	if err != nil {
 		log.Printf("couldn't get token: %q", err)
+		return "", err
 	}
 	defer resp.Body.Close()
@@ -15,6 +15,7 @@ import (
 	"strings"
 	"sync"
 	"sync/atomic"
+	"syscall"
 	"time"

 	"golang.org/x/sync/errgroup"
@@ -88,17 +89,12 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
 	}

 	if len(b.Parts) == 0 {
-		resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, opts)
+		resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
 		if err != nil {
 			return err
 		}
 		defer resp.Body.Close()

-		if resp.StatusCode >= http.StatusBadRequest {
-			body, _ := io.ReadAll(resp.Body)
-			return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
-		}
-
 		b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)

 		var size = b.Total / numDownloadParts
@@ -133,7 +129,6 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis
 func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
 	defer blobDownloadManager.Delete(b.Digest)
-
 	ctx, b.CancelFunc = context.WithCancel(ctx)

 	file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644)
@@ -158,7 +153,8 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
 			w := io.NewOffsetWriter(file, part.StartsAt())
 			err := b.downloadChunk(inner, requestURL, w, part, opts)
 			switch {
-			case errors.Is(err, context.Canceled):
+			case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
+				// return immediately if the context is canceled or the device is out of space
 				return err
 			case err != nil:
 				log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], i, try, err)
@@ -168,7 +164,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
 			}
 		}

-		return errors.New("max retries exceeded")
+		return errMaxRetriesExceeded
 	})
 }
@@ -198,7 +194,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
 func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
 	headers := make(http.Header)
 	headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
-	resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts)
+	resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
 	if err != nil {
 		return err
 	}
@@ -306,6 +302,8 @@ type downloadOpts struct {

 const maxRetries = 3

+var errMaxRetriesExceeded = errors.New("max retries exceeded")
+
 // downloadBlob downloads a blob from the registry and stores it in the blobs directory
 func downloadBlob(ctx context.Context, opts downloadOpts) error {
 	fp, err := GetBlobsPath(opts.digest)
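The change above treats a full disk like cancellation: retrying a chunk cannot succeed once the device is out of space. A condensed sketch of the per-part retry shape this produces (names taken from the diff; the surrounding errgroup plumbing is omitted):

    for try := 0; try < maxRetries; try++ {
        err := b.downloadChunk(inner, requestURL, w, part, opts)
        switch {
        case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
            return err // unrecoverable: give up immediately
        case err != nil:
            continue // transient: log and retry
        default:
            return nil
        }
    }
    return errMaxRetriesExceeded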
server/images.go (378 lines changed)
@@ -1,7 +1,6 @@
 package server

 import (
-	"bufio"
 	"bytes"
 	"context"
 	"crypto/sha256"
@@ -26,7 +25,6 @@ import (
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/llm"
 	"github.com/jmorganca/ollama/parser"
-	"github.com/jmorganca/ollama/vector"
 	"github.com/jmorganca/ollama/version"
 )
@@ -47,12 +45,10 @@ type Model struct {
 	System       string
 	License      []string
 	Digest       string
-	ConfigDigest string
 	Options      map[string]interface{}
-	Embeddings   []vector.Embedding
 }

-func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
+func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
 	t := m.Template
 	if request.Template != "" {
 		t = request.Template
@@ -67,17 +63,11 @@ func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, e
 		First  bool
 		System string
 		Prompt string
-		Embed  string
-
-		// deprecated: versions <= 0.0.7 used this to omit the system prompt
-		Context []int
 	}

 	vars.First = len(request.Context) == 0
 	vars.System = m.System
 	vars.Prompt = request.Prompt
-	vars.Context = request.Context
-	vars.Embed = embedding

 	if request.System != "" {
 		vars.System = request.System
@@ -137,7 +127,7 @@ func (m *ManifestV2) GetTotalSize() (total int64) {
 }

 func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
-	fp, err := mp.GetManifestPath(false)
+	fp, err := mp.GetManifestPath()
 	if err != nil {
 		return nil, "", err
 	}
@@ -171,12 +161,11 @@ func GetModel(name string) (*Model, error) {
 	}

 	model := &Model{
 		Name:      mp.GetFullTagname(),
 		ShortName: mp.GetShortTagname(),
 		Digest:    digest,
-		ConfigDigest: manifest.Config.Digest,
 		Template:  "{{ .Prompt }}",
 		License:   []string{},
 	}

 	for _, layer := range manifest.Layers {
@@ -190,15 +179,9 @@ func GetModel(name string) (*Model, error) {
 			model.ModelPath = filename
 			model.OriginalModel = layer.From
 		case "application/vnd.ollama.image.embed":
-			file, err := os.Open(filename)
-			if err != nil {
-				return nil, fmt.Errorf("failed to open file: %s", filename)
-			}
-			defer file.Close()
-
-			if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
-				return nil, err
-			}
+			// Deprecated in versions > 0.1.2
+			// TODO: remove this warning in a future version
+			log.Print("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
 		case "application/vnd.ollama.image.adapter":
 			model.AdapterPaths = append(model.AdapterPaths, filename)
 		case "application/vnd.ollama.image.template":
@@ -265,7 +248,7 @@ func filenameWithPath(path, f string) (string, error) {
 	return f, nil
 }

-func CreateModel(ctx context.Context, workDir, name string, path string, fn func(resp api.ProgressResponse)) error {
+func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
 	mp := ParseModelPath(name)

 	var manifest *ManifestV2
@@ -310,13 +293,11 @@ func CreateModel(ctx context.Context, workDir, name string, path string, fn func
 	var layers []*LayerReader
 	params := make(map[string][]string)
 	var sourceParams map[string]any
-	embed := EmbeddingParams{fn: fn}
 	for _, c := range commands {
 		log.Printf("[%s] - %s\n", c.Name, c.Args)
 		switch c.Name {
 		case "model":
 			fn(api.ProgressResponse{Status: "looking for model"})
-			embed.model = c.Args

 			mp := ParseModelPath(c.Args)
 			mf, _, err := GetManifest(mp)
@@ -340,7 +321,6 @@ func CreateModel(ctx context.Context, workDir, name string, path string, fn func
 				return err
 			}
 		} else {
-			embed.model = modelFile
 			// create a model from this specified file
 			fn(api.ProgressResponse{Status: "creating model layer"})
 			file, err := os.Open(modelFile)
@@ -421,12 +401,6 @@ func CreateModel(ctx context.Context, workDir, name string, path string, fn func
 				layers = append(layers, newLayer)
 			}
 		}
-	case "embed":
-		embedFilePath, err := filenameWithPath(path, c.Args)
-		if err != nil {
-			return err
-		}
-		embed.files = append(embed.files, embedFilePath)
 	case "adapter":
 		fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})

@@ -517,18 +491,8 @@ func CreateModel(ctx context.Context, workDir, name string, path string, fn func
 		}
 		l.MediaType = "application/vnd.ollama.image.params"
 		layers = append(layers, l)
-
-		// apply these parameters to the embedding options, in case embeddings need to be generated using this model
-		embed.opts = formattedParams
 	}

-	// generate the embedding layers
-	embeddingLayers, err := embeddingLayers(workDir, embed)
-	if err != nil {
-		return err
-	}
-	layers = append(layers, embeddingLayers...)
-
 	digests, err := getLayerDigests(layers)
 	if err != nil {
 		return err
@@ -572,146 +536,6 @@ func CreateModel(ctx context.Context, workDir, name string, path string, fn func
 	return nil
 }

-type EmbeddingParams struct {
-	model string
-	opts  map[string]interface{}
-	files []string // paths to files to embed
-	fn    func(resp api.ProgressResponse)
-}
-
-// embeddingLayers loads the associated LLM and generates the embeddings to be stored from an input file
-func embeddingLayers(workDir string, e EmbeddingParams) ([]*LayerReader, error) {
-	layers := []*LayerReader{}
-	if len(e.files) > 0 {
-		// check if the model is a file path or a model name
-		model, err := GetModel(e.model)
-		if err != nil {
-			if !strings.Contains(err.Error(), "couldn't open file") {
-				return nil, fmt.Errorf("unexpected error opening model to generate embeddings: %v", err)
-			}
-			// the model may be a file path, create a model from this file
-			model = &Model{ModelPath: e.model}
-		}
-
-		if err := load(context.Background(), workDir, model, e.opts, defaultSessionDuration); err != nil {
-			return nil, fmt.Errorf("load model to generate embeddings: %v", err)
-		}
-
-		// this will be used to check if we already have embeddings for a file
-		modelInfo, err := os.Stat(model.ModelPath)
-		if err != nil {
-			return nil, fmt.Errorf("failed to get model file info: %v", err)
-		}
-
-		addedFiles := make(map[string]bool) // keep track of files that have already been added
-		for _, filePattern := range e.files {
-			matchingFiles, err := filepath.Glob(filePattern)
-			if err != nil {
-				return nil, fmt.Errorf("could not find files with pattern %s: %w", filePattern, err)
-			}
-
-			for _, filePath := range matchingFiles {
-				if addedFiles[filePath] {
-					continue
-				}
-				addedFiles[filePath] = true
-				// check if we already have embeddings for this file path
-				layerIdentifier := fmt.Sprintf("%s:%s:%s:%d", filePath, e.model, modelInfo.ModTime().Format("2006-01-02 15:04:05"), modelInfo.Size())
-				digest, _ := GetSHA256Digest(strings.NewReader(layerIdentifier))
-				existing, err := existingFileEmbeddings(digest)
-				if err != nil {
-					return nil, fmt.Errorf("failed to check existing embeddings for file %s: %v", filePath, err)
-				}
-
-				// TODO: check file type
-				f, err := os.Open(filePath)
-				if err != nil {
-					return nil, fmt.Errorf("could not open embed file: %w", err)
-				}
-				scanner := bufio.NewScanner(f)
-				scanner.Split(bufio.ScanLines)
-
-				data := []string{}
-				for scanner.Scan() {
-					data = append(data, scanner.Text())
-				}
-				f.Close()
-
-				// the digest of the file is set here so that the client knows a new operation is in progress
-				fileDigest, _ := GetSHA256Digest(bytes.NewReader([]byte(filePath)))
-
-				embeddings := []vector.Embedding{}
-				for i, d := range data {
-					if strings.TrimSpace(d) == "" {
-						continue
-					}
-					e.fn(api.ProgressResponse{
-						Status:    fmt.Sprintf("creating embeddings for file %s", filePath),
-						Digest:    fileDigest,
-						Total:     int64(len(data) - 1),
-						Completed: int64(i),
-					})
-					if len(existing[d]) > 0 {
-						// already have an embedding for this line
-						embeddings = append(embeddings, vector.Embedding{Data: d, Vector: existing[d]})
-						continue
-					}
-					embed, err := loaded.llm.Embedding(context.Background(), d)
-					if err != nil {
-						log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
-						continue
-					}
-					embeddings = append(embeddings, vector.Embedding{Data: d, Vector: embed})
-				}
-
-				b, err := json.Marshal(embeddings)
-				if err != nil {
-					return nil, fmt.Errorf("failed to encode embeddings: %w", err)
-				}
-				r := bytes.NewReader(b)
-
-				layer := &LayerReader{
-					Layer: Layer{
-						MediaType: "application/vnd.ollama.image.embed",
-						Digest:    digest,
-						Size:      r.Size(),
-					},
-					Reader: r,
-				}
-
-				layers = append(layers, layer)
-			}
-		}
-	}
-	return layers, nil
-}
-
-// existingFileEmbeddings checks if we already have embeddings for a file and loads them into a look-up map
-func existingFileEmbeddings(digest string) (map[string][]float64, error) {
-	path, err := GetBlobsPath(digest)
-	if err != nil {
-		return nil, fmt.Errorf("embeddings blobs path: %w", err)
-	}
-	existingFileEmbeddings := make(map[string][]float64)
-	if _, err := os.Stat(path); err == nil {
-		// already have some embeddings for this file, load embeddings previously generated
-		file, err := os.Open(path)
-		if err != nil {
-			return nil, fmt.Errorf("failed to open existing embedding file: %s", err)
-		}
-		defer file.Close()
-
-		existing := []vector.Embedding{}
-		if err = json.NewDecoder(file).Decode(&existing); err != nil {
-			return nil, err
-		}
-		for _, e := range existing {
-			existingFileEmbeddings[e.Data] = e.Vector
-		}
-	}
-	return existingFileEmbeddings, nil
-}
-
 func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
 	return slices.DeleteFunc(layers, func(layer *LayerReader) bool {
 		return layer.MediaType == mediaType
@@ -727,8 +551,7 @@ func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force
 	}

 	_, err = os.Stat(fp)
-	// note: embed layers are always written since their digest doesnt indicate anything about the contents
-	if os.IsNotExist(err) || force || layer.MediaType == "application/vnd.ollama.image.embed" {
+	if os.IsNotExist(err) || force {
 		fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})

 		out, err := os.Create(fp)
@@ -768,10 +591,13 @@ func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
 		return err
 	}

-	fp, err := mp.GetManifestPath(true)
+	fp, err := mp.GetManifestPath()
 	if err != nil {
 		return err
 	}
+	if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
+		return err
+	}
 	return os.WriteFile(fp, manifestJSON, 0o644)
 }
@@ -883,16 +709,19 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {

 func CopyModel(src, dest string) error {
 	srcModelPath := ParseModelPath(src)
-	srcPath, err := srcModelPath.GetManifestPath(false)
+	srcPath, err := srcModelPath.GetManifestPath()
 	if err != nil {
 		return err
 	}

 	destModelPath := ParseModelPath(dest)
-	destPath, err := destModelPath.GetManifestPath(true)
+	destPath, err := destModelPath.GetManifestPath()
 	if err != nil {
 		return err
 	}
+	if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil {
+		return err
+	}

 	// copy the file
 	input, err := os.ReadFile(srcPath)
@@ -1055,7 +884,7 @@ func DeleteModel(name string) error {
 		return err
 	}

-	fp, err := mp.GetManifestPath(false)
+	fp, err := mp.GetManifestPath()
 	if err != nil {
 		return err
 	}
@@ -1069,51 +898,27 @@ func DeleteModel(name string) error {
 }

 func ShowModelfile(model *Model) (string, error) {
-	type modelTemplate struct {
+	var mt struct {
 		*Model
-		From   string
-		Params string
+		From       string
+		Parameters map[string][]any
 	}

-	var params []string
+	mt.Parameters = make(map[string][]any)
 	for k, v := range model.Options {
-		switch val := v.(type) {
-		case string:
-			params = append(params, fmt.Sprintf("PARAMETER %s %s", k, val))
-		case int:
-			params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.Itoa(val)))
-		case float64:
-			params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.FormatFloat(val, 'f', 0, 64)))
-		case bool:
-			params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.FormatBool(val)))
-		case []interface{}:
-			for _, nv := range val {
-				switch nval := nv.(type) {
-				case string:
-					params = append(params, fmt.Sprintf("PARAMETER %s %s", k, nval))
-				case int:
-					params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.Itoa(nval)))
-				case float64:
-					params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.FormatFloat(nval, 'f', 0, 64)))
-				case bool:
-					params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.FormatBool(nval)))
-				default:
-					log.Printf("unknown type: %s", reflect.TypeOf(nv).String())
-				}
-			}
-		default:
-			log.Printf("unknown type: %s", reflect.TypeOf(v).String())
+		if s, ok := v.([]any); ok {
+			mt.Parameters[k] = s
+			continue
 		}
+
+		mt.Parameters[k] = []any{v}
 	}

-	mt := modelTemplate{
-		Model:  model,
-		From:   model.OriginalModel,
-		Params: strings.Join(params, "\n"),
-	}
-
-	if mt.From == "" {
-		mt.From = model.ModelPath
+	mt.Model = model
+	mt.From = model.ModelPath
+
+	if model.OriginalModel != "" {
+		mt.From = model.OriginalModel
 	}

 	modelFile := `# Modelfile generated by "ollama show"
@@ -1122,12 +927,20 @@ func ShowModelfile(model *Model) (string, error) {

 FROM {{ .From }}
 TEMPLATE """{{ .Template }}"""
+
+{{- if .System }}
 SYSTEM """{{ .System }}"""
-{{ .Params }}
-`
-	for _, l := range mt.Model.AdapterPaths {
-		modelFile += fmt.Sprintf("ADAPTER %s\n", l)
-	}
+{{- end }}
+
+{{- range $adapter := .AdapterPaths }}
+ADAPTER {{ $adapter }}
+{{- end }}
+
+{{- range $k, $v := .Parameters }}
+{{- range $parameter := $v }}
+PARAMETER {{ $k }} {{ printf "%#v" $parameter }}
+{{- end }}
+{{- end }}`

 	tmpl, err := template.New("").Parse(modelFile)
 	if err != nil {
@@ -1164,46 +977,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 	layers = append(layers, &manifest.Config)

 	for _, layer := range layers {
-		exists, err := checkBlobExistence(ctx, mp, layer.Digest, regOpts)
-		if err != nil {
-			return err
-		}
-
-		if exists {
-			fn(api.ProgressResponse{
-				Status:    "using existing layer",
-				Digest:    layer.Digest,
-				Total:     layer.Size,
-				Completed: layer.Size,
-			})
-			log.Printf("Layer %s already exists", layer.Digest)
-			continue
-		}
-
-		fn(api.ProgressResponse{
-			Status: "starting upload",
-			Digest: layer.Digest,
-			Total:  layer.Size,
-		})
-
-		location, chunkSize, err := startUpload(ctx, mp, layer, regOpts)
-		if err != nil {
-			log.Printf("couldn't start upload: %v", err)
-			return err
-		}
-
-		if strings.HasPrefix(filepath.Base(location.Path), "sha256:") {
-			layer.Digest = filepath.Base(location.Path)
-			fn(api.ProgressResponse{
-				Status:    "using existing layer",
-				Digest:    layer.Digest,
-				Total:     layer.Size,
-				Completed: layer.Size,
-			})
-			continue
-		}
-
-		if err := uploadBlob(ctx, location, layer, chunkSize, regOpts, fn); err != nil {
+		if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
 			log.Printf("error uploading blob: %v", err)
 			return err
 		}
@@ -1220,7 +994,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu

 	headers := make(http.Header)
 	headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
-	resp, err := makeRequestWithRetry(ctx, "PUT", requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
+	resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
 	if err != nil {
 		return err
 	}
@@ -1310,10 +1084,13 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 		return err
 	}

-	fp, err := mp.GetManifestPath(true)
+	fp, err := mp.GetManifestPath()
 	if err != nil {
 		return err
 	}
+	if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
+		return err
+	}

 	err = os.WriteFile(fp, manifestJSON, 0o644)
 	if err != nil {
@@ -1339,22 +1116,12 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptio

 	headers := make(http.Header)
 	headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
-	resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, regOpts)
+	resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts)
 	if err != nil {
-		log.Printf("couldn't get manifest: %v", err)
 		return nil, err
 	}
 	defer resp.Body.Close()

-	if resp.StatusCode >= http.StatusBadRequest {
-		if resp.StatusCode == http.StatusNotFound {
-			return nil, fmt.Errorf("model not found")
-		}
-
-		body, _ := io.ReadAll(resp.Body)
-		return nil, fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
-	}
-
 	var m *ManifestV2
 	if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
 		return nil, err
@@ -1398,24 +1165,7 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
 	return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
 }

-// Function to check if a blob already exists in the Docker registry
-func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
-	requestURL := mp.BaseURL()
-	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", digest)
-
-	resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, regOpts)
-	if err != nil {
-		log.Printf("couldn't check for blob: %v", err)
-		return false, err
-	}
-	defer resp.Body.Close()
-
-	// Check for success: If the blob exists, the Docker registry will respond with a 200 OK
-	return resp.StatusCode < http.StatusBadRequest, nil
-}
-
 func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
-	var status string
 	for try := 0; try < maxRetries; try++ {
 		resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
 		if err != nil {
@@ -1423,8 +1173,6 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
 			return nil, err
 		}

-		status = resp.Status
-
 		switch {
 		case resp.StatusCode == http.StatusUnauthorized:
 			auth := resp.Header.Get("www-authenticate")
@@ -1436,21 +1184,25 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR

 			regOpts.Token = token
 			if body != nil {
-				if _, err := body.Seek(0, io.SeekStart); err != nil {
-					return nil, err
-				}
+				body.Seek(0, io.SeekStart)
 			}

 			continue
+		case resp.StatusCode == http.StatusNotFound:
+			return nil, os.ErrNotExist
 		case resp.StatusCode >= http.StatusBadRequest:
-			body, _ := io.ReadAll(resp.Body)
-			return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
+			body, err := io.ReadAll(resp.Body)
+			if err != nil {
+				return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
+			}
+
+			return nil, fmt.Errorf("%d: %s", resp.StatusCode, body)
 		default:
 			return resp, nil
 		}
 	}

-	return nil, fmt.Errorf("max retry exceeded: %v", status)
+	return nil, errMaxRetriesExceeded
 }

 func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
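Mapping 404 responses to os.ErrNotExist lets callers branch with errors.Is instead of matching status text. A hypothetical caller sketch:

    resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, nil, nil, regOpts)
    if errors.Is(err, os.ErrNotExist) {
        return fmt.Errorf("model %q not found", name) // the registry returned 404
    }
    if err != nil {
        return err
    }
    defer resp.Body.Close()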
@@ -12,7 +12,7 @@ func TestModelPrompt(t *testing.T) {
 		Template: "a{{ .Prompt }}b",
 		Prompt:   "<h1>",
 	}
-	s, err := m.Prompt(req, "")
+	s, err := m.Prompt(req)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -85,20 +85,27 @@ func (mp ModelPath) GetShortTagname() string {
 	return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
 }

-func (mp ModelPath) GetManifestPath(createDir bool) (string, error) {
+// modelsDir returns the value of the OLLAMA_MODELS environment variable or the user's home directory if OLLAMA_MODELS is not set.
+// The models directory is where Ollama stores its model files and manifests.
+func modelsDir() (string, error) {
+	if models, exists := os.LookupEnv("OLLAMA_MODELS"); exists {
+		return models, nil
+	}
 	home, err := os.UserHomeDir()
 	if err != nil {
 		return "", err
 	}
+	return filepath.Join(home, ".ollama", "models"), nil
+}

-	path := filepath.Join(home, ".ollama", "models", "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
-	if createDir {
-		if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
-			return "", err
-		}
-	}
+// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
+func (mp ModelPath) GetManifestPath() (string, error) {
+	dir, err := modelsDir()
+	if err != nil {
+		return "", err
+	}

-	return path, nil
+	return filepath.Join(dir, "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil
 }

 func (mp ModelPath) BaseURL() *url.URL {
@@ -109,12 +116,12 @@ func (mp ModelPath) BaseURL() *url.URL {
 }

 func GetManifestPath() (string, error) {
-	home, err := os.UserHomeDir()
+	dir, err := modelsDir()
 	if err != nil {
 		return "", err
 	}

-	path := filepath.Join(home, ".ollama", "models", "manifests")
+	path := filepath.Join(dir, "manifests")
 	if err := os.MkdirAll(path, 0o755); err != nil {
 		return "", err
 	}
@@ -123,7 +130,7 @@ func GetManifestPath() (string, error) {
 }

 func GetBlobsPath(digest string) (string, error) {
-	home, err := os.UserHomeDir()
+	dir, err := modelsDir()
 	if err != nil {
 		return "", err
 	}
@@ -132,7 +139,7 @@ func GetBlobsPath(digest string) (string, error) {
 		digest = strings.ReplaceAll(digest, ":", "-")
 	}

-	path := filepath.Join(home, ".ollama", "models", "blobs", digest)
+	path := filepath.Join(dir, "blobs", digest)
 	dirPath := filepath.Dir(path)
 	if digest == "" {
 		dirPath = path
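The new modelsDir helper gives OLLAMA_MODELS precedence over the ~/.ollama/models default. A hypothetical same-package test sketching the override (t.Setenv requires Go 1.17+):

    package server

    import "testing"

    func TestModelsDirOverride(t *testing.T) {
        t.Setenv("OLLAMA_MODELS", "/tmp/ollama-models")

        dir, err := modelsDir()
        if err != nil {
            t.Fatal(err)
        }
        if dir != "/tmp/ollama-models" {
            t.Fatalf("expected the OLLAMA_MODELS override, got %s", dir)
        }
    }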
server/routes.go (278 lines changed)
@@ -23,11 +23,10 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-contrib/cors"
|
"github.com/gin-contrib/cors"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gonum.org/v1/gonum/mat"
|
|
||||||
|
|
||||||
"github.com/jmorganca/ollama/api"
|
"github.com/jmorganca/ollama/api"
|
||||||
"github.com/jmorganca/ollama/llm"
|
"github.com/jmorganca/ollama/llm"
|
||||||
"github.com/jmorganca/ollama/vector"
|
"github.com/jmorganca/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
var mode string = gin.DebugMode
|
var mode string = gin.DebugMode
|
||||||
@@ -47,14 +46,13 @@ func init() {
|
|||||||
var loaded struct {
|
var loaded struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
|
||||||
llm llm.LLM
|
runner llm.LLM
|
||||||
Embeddings []vector.Embedding
|
|
||||||
|
|
||||||
expireAt time.Time
|
expireAt time.Time
|
||||||
expireTimer *time.Timer
|
expireTimer *time.Timer
|
||||||
|
|
||||||
digest string
|
*Model
|
||||||
options api.Options
|
*api.Options
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultSessionDuration = 5 * time.Minute
|
var defaultSessionDuration = 5 * time.Minute
|
||||||
@@ -72,66 +70,52 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
 	}
 
 	// check if the loaded model is still running in a subprocess, in case something unexpected happened
-	if loaded.llm != nil {
-		if err := loaded.llm.Ping(ctx); err != nil {
+	if loaded.runner != nil {
+		if err := loaded.runner.Ping(ctx); err != nil {
 			log.Print("loaded llm process not responding, closing now")
 			// the subprocess is no longer running, so close it
-			loaded.llm.Close()
-			loaded.llm = nil
-			loaded.digest = ""
+			loaded.runner.Close()
+			loaded.runner = nil
+			loaded.Model = nil
+			loaded.Options = nil
 		}
 	}
 
-	if model.Digest != loaded.digest || !reflect.DeepEqual(loaded.options, opts) {
-		if loaded.llm != nil {
+	needLoad := loaded.runner == nil || // is there a model loaded?
+		loaded.ModelPath != model.ModelPath || // has the base model changed?
+		!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
+		!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) // have the runner options changed?
+
+	if needLoad {
+		if loaded.runner != nil {
 			log.Println("changing loaded model")
-			loaded.llm.Close()
-			loaded.llm = nil
-			loaded.digest = ""
+			loaded.runner.Close()
+			loaded.runner = nil
+			loaded.Model = nil
+			loaded.Options = nil
 		}
 
-		if model.Embeddings != nil && len(model.Embeddings) > 0 {
-			opts.EmbeddingOnly = true // this is requried to generate embeddings, completions will still work
-			loaded.Embeddings = model.Embeddings
-		}
-
-		llmModel, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts)
+		llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts)
 		if err != nil {
+			// some older models are not compatible with newer versions of llama.cpp
+			// show a generalized compatibility error until there is a better way to
+			// check for model compatibility
+			if strings.Contains(err.Error(), "failed to load model") {
+				err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
+			}
+
 			return err
 		}
 
-		// set cache values before modifying opts
-		loaded.llm = llmModel
-		loaded.digest = model.Digest
-		loaded.options = opts
-
-		if opts.NumKeep < 0 {
-			promptWithSystem, err := model.Prompt(api.GenerateRequest{}, "")
-			if err != nil {
-				return err
-			}
-
-			promptNoSystem, err := model.Prompt(api.GenerateRequest{Context: []int{0}}, "")
-			if err != nil {
-				return err
-			}
-
-			tokensWithSystem, err := llmModel.Encode(ctx, promptWithSystem)
-			if err != nil {
-				return err
-			}
-
-			tokensNoSystem, err := llmModel.Encode(ctx, promptNoSystem)
-			if err != nil {
-				return err
-			}
-
-			opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem)
-
-			llmModel.SetOptions(opts)
-		}
+		loaded.Model = model
+		loaded.runner = llmRunner
+		loaded.Options = &opts
 	}
 
+	// update options for the loaded llm
+	// TODO(mxyng): this isn't thread safe, but it should be fine for now
+	loaded.runner.SetOptions(opts)
+
 	loaded.expireAt = time.Now().Add(sessionDuration)
 
 	if loaded.expireTimer == nil {
@@ -143,13 +127,13 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
 			return
 		}
 
-		if loaded.llm == nil {
-			return
+		if loaded.runner != nil {
+			loaded.runner.Close()
 		}
 
-		loaded.llm.Close()
-		loaded.llm = nil
-		loaded.digest = ""
+		loaded.runner = nil
+		loaded.Model = nil
+		loaded.Options = nil
 		})
 	}
 
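
The unload above runs inside the expire timer's callback: every completed request pushes expireAt forward and resets the timer, so the runner is only closed after a full idle window with no traffic. A minimal standalone sketch of that pattern, assuming time.AfterFunc and the five-minute window of defaultSessionDuration; names here are illustrative, not the repository's.

    package main

    import (
    	"fmt"
    	"sync"
    	"time"
    )

    const sessionDuration = 5 * time.Minute

    var (
    	mu          sync.Mutex
    	expireAt    time.Time
    	expireTimer *time.Timer
    )

    // touch extends the session on every completed request. The callback
    // re-checks expireAt so a reset racing an expiring timer cannot unload early.
    func touch() {
    	mu.Lock()
    	defer mu.Unlock()

    	expireAt = time.Now().Add(sessionDuration)
    	if expireTimer == nil {
    		expireTimer = time.AfterFunc(sessionDuration, func() {
    			mu.Lock()
    			defer mu.Unlock()

    			if time.Now().Before(expireAt) {
    				return // another request arrived; stay loaded
    			}
    			fmt.Println("unloading model") // close the runner here
    		})
    		return
    	}
    	expireTimer.Reset(sessionDuration)
    }

    func main() {
    	touch() // simulate a completed request
    	time.Sleep(50 * time.Millisecond)
    }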
@@ -164,8 +148,18 @@ func GenerateHandler(c *gin.Context) {
 	checkpointStart := time.Now()
 
 	var req api.GenerateRequest
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	if req.Model == "" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
 		return
 	}
 
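
This bind-then-switch prologue now opens every handler in routes.go: io.EOF from ShouldBindJSON means the body was empty, which earns a clearer message than the raw decoder error. A sketch of the same pattern factored into a helper; the repository keeps it inline, and bindJSON here is hypothetical.

    package server

    import (
    	"errors"
    	"io"
    	"net/http"

    	"github.com/gin-gonic/gin"
    )

    // bindJSON decodes the request body into v and writes a 400 response on
    // failure; callers stop processing when it returns false.
    func bindJSON(c *gin.Context, v any) bool {
    	err := c.ShouldBindJSON(v)
    	switch {
    	case errors.Is(err, io.EOF):
    		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
    		return false
    	case err != nil:
    		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
    		return false
    	}

    	return true
    }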
@@ -195,22 +189,7 @@ func GenerateHandler(c *gin.Context) {
 	checkpointLoaded := time.Now()
 
-	embedding := ""
-	if model.Embeddings != nil && len(model.Embeddings) > 0 {
-		promptEmbed, err := loaded.llm.Embedding(c.Request.Context(), req.Prompt)
-		if err != nil {
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-			return
-		}
-		// TODO: set embed_top from specified parameters in modelfile
-		embed_top := 3
-		topK := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
-		for _, e := range topK {
-			embedding = fmt.Sprintf("%s %s", embedding, e.Embedding.Data)
-		}
-	}
-
-	prompt, err := model.Prompt(req, embedding)
+	prompt, err := model.Prompt(req)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -219,6 +198,12 @@ func GenerateHandler(c *gin.Context) {
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
+		// an empty request loads the model
+		if req.Prompt == "" && req.Template == "" && req.System == "" {
+			ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}
+			return
+		}
+
 		fn := func(r api.GenerateResponse) {
 			loaded.expireAt = time.Now().Add(sessionDuration)
 			loaded.expireTimer.Reset(sessionDuration)
@@ -233,13 +218,8 @@ func GenerateHandler(c *gin.Context) {
 			ch <- r
 		}
 
-		// an empty request loads the model
-		if req.Prompt == "" && req.Template == "" && req.System == "" {
-			ch <- api.GenerateResponse{Model: req.Model, Done: true}
-		} else {
-			if err := loaded.llm.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
-				ch <- gin.H{"error": err.Error()}
-			}
+		if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
+			ch <- gin.H{"error": err.Error()}
 		}
 	}()
 
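
With the empty-request check moved ahead of Predict, a request that names a model but carries no prompt, template, or system string just warms the model and immediately answers done: true. A minimal client sketch of that warm-up call, assuming the server's usual listen address; the exact address and model name are illustrative.

    package main

    import (
    	"bytes"
    	"fmt"
    	"io"
    	"net/http"
    )

    func main() {
    	// no prompt, template, or system: the server only loads the model
    	body := bytes.NewBufferString(`{"model": "llama2"}`)
    	resp, err := http.Post("http://127.0.0.1:11434/api/generate", "application/json", body)
    	if err != nil {
    		panic(err)
    	}
    	defer resp.Body.Close()

    	out, _ := io.ReadAll(resp.Body)
    	fmt.Println(string(out)) // expect a single response with "done": true
    }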
@@ -268,8 +248,18 @@ func EmbeddingHandler(c *gin.Context) {
 	defer loaded.mu.Unlock()
 
 	var req api.EmbeddingRequest
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	if req.Model == "" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
 		return
 	}
 
@@ -285,12 +275,12 @@ func EmbeddingHandler(c *gin.Context) {
 		return
 	}
 
-	if !loaded.options.EmbeddingOnly {
+	if !loaded.Options.EmbeddingOnly {
 		c.JSON(http.StatusBadRequest, gin.H{"error": "embedding option must be set to true"})
 		return
 	}
 
-	embedding, err := loaded.llm.Embedding(c.Request.Context(), req.Prompt)
+	embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		log.Printf("embedding generation failed: %v", err)
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@@ -305,8 +295,18 @@ func EmbeddingHandler(c *gin.Context) {
 
 func PullModelHandler(c *gin.Context) {
 	var req api.PullRequest
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	if req.Name == "" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
 		return
 	}
 
@@ -339,8 +339,18 @@ func PullModelHandler(c *gin.Context) {
 
 func PushModelHandler(c *gin.Context) {
 	var req api.PushRequest
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	if req.Name == "" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
 		return
 	}
 
@@ -355,7 +365,9 @@ func PushModelHandler(c *gin.Context) {
 		Insecure: req.Insecure,
 	}
 
-	ctx := context.Background()
+	ctx, cancel := context.WithCancel(c.Request.Context())
+	defer cancel()
 
 	if err := PushModel(ctx, req.Name, regOpts, fn); err != nil {
 		ch <- gin.H{"error": err.Error()}
 	}
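
Replacing context.Background() with a cancelable child of c.Request.Context() means a client that disconnects mid-push now aborts PushModel instead of leaking the upload. A small sketch of the effect, with doWork standing in for the long-running push; all names here are illustrative.

    package main

    import (
    	"context"
    	"fmt"
    	"time"
    )

    func doWork(ctx context.Context) error {
    	select {
    	case <-time.After(10 * time.Second): // stand-in for a long upload
    		return nil
    	case <-ctx.Done():
    		return ctx.Err() // canceled as soon as the parent is
    	}
    }

    func main() {
    	parent, stop := context.WithCancel(context.Background()) // c.Request.Context() in the handler
    	ctx, cancel := context.WithCancel(parent)
    	defer cancel()

    	go func() { time.Sleep(time.Second); stop() }() // the client disconnects
    	fmt.Println(doWork(ctx))                        // context.Canceled after about a second
    }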
@@ -371,12 +383,20 @@ func PushModelHandler(c *gin.Context) {
 
 func CreateModelHandler(c *gin.Context) {
 	var req api.CreateRequest
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
-	workDir := c.GetString("workDir")
+	if req.Name == "" || req.Path == "" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name and path are required"})
+		return
+	}
 
 	ch := make(chan any)
 	go func() {
@@ -388,7 +408,7 @@ func CreateModelHandler(c *gin.Context) {
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		defer cancel()
 
-		if err := CreateModel(ctx, workDir, req.Name, req.Path, fn); err != nil {
+		if err := CreateModel(ctx, req.Name, req.Path, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
@@ -403,8 +423,18 @@ func CreateModelHandler(c *gin.Context) {
 
 func DeleteModelHandler(c *gin.Context) {
 	var req api.DeleteRequest
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	if req.Name == "" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
 		return
 	}
 
@@ -433,8 +463,18 @@ func DeleteModelHandler(c *gin.Context) {
 
 func ShowModelHandler(c *gin.Context) {
 	var req api.ShowRequest
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	if req.Name == "" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
 		return
 	}
 
@@ -503,7 +543,7 @@ func GetModelInfo(name string) (*api.ShowResponse, error) {
 }
 
 func ListModelsHandler(c *gin.Context) {
-	var models []api.ModelResponse
+	models := make([]api.ModelResponse, 0)
 	fp, err := GetManifestPath()
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
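
models := make([]api.ModelResponse, 0) instead of a nil var declaration is deliberate: a nil slice marshals to JSON null while an empty one marshals to [], so clients listing zero models still get a well-formed array. Demonstrated in isolation:

    package main

    import (
    	"encoding/json"
    	"fmt"
    )

    func main() {
    	var nilSlice []int
    	emptySlice := make([]int, 0)

    	a, _ := json.Marshal(nilSlice)   // nil slice
    	b, _ := json.Marshal(emptySlice) // empty, non-nil slice
    	fmt.Println(string(a), string(b)) // prints: null []
    }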
@@ -544,8 +584,18 @@ func ListModelsHandler(c *gin.Context) {
 
 func CopyModelHandler(c *gin.Context) {
 	var req api.CopyRequest
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	if req.Source == "" || req.Destination == "" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "source add destination are required"})
 		return
 	}
 
@@ -566,6 +616,22 @@ var defaultAllowOrigins = []string{
 }
 
 func Serve(ln net.Listener, allowOrigins []string) error {
+	if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
+		// clean up unused layers and manifests
+		if err := PruneLayers(); err != nil {
+			return err
+		}
+
+		manifestsPath, err := GetManifestPath()
+		if err != nil {
+			return err
+		}
+
+		if err := PruneDirectory(manifestsPath); err != nil {
+			return err
+		}
+	}
+
 	config := cors.DefaultConfig()
 	config.AllowWildcard = true
 
@@ -611,7 +677,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
 		r.Handle(method, "/api/tags", ListModelsHandler)
 	}
 
-	log.Printf("Listening on %s", ln.Addr())
+	log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
 	s := &http.Server{
 		Handler: r,
 	}
@@ -621,8 +687,8 @@ func Serve(ln net.Listener, allowOrigins []string) error {
 	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
 	go func() {
 		<-signals
-		if loaded.llm != nil {
-			loaded.llm.Close()
+		if loaded.runner != nil {
+			loaded.runner.Close()
 		}
 		os.RemoveAll(workDir)
 		os.Exit(0)
@@ -631,7 +697,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
 	if runtime.GOOS == "linux" {
 		// check compatibility to log warnings
 		if _, err := llm.CheckVRAM(); err != nil {
-			log.Printf("Warning: GPU support may not enabled, check you have installed install GPU drivers: %v", err)
+			log.Printf("Warning: GPU support may not be enabled, check you have installed GPU drivers: %v", err)
 		}
 	}
 
server/upload.go (415 changed lines)
@@ -2,218 +2,367 @@ package server
 
 import (
 	"context"
+	"crypto/md5"
 	"errors"
 	"fmt"
+	"hash"
 	"io"
 	"log"
 	"net/http"
 	"net/url"
 	"os"
-	"strconv"
+	"strings"
 	"sync"
+	"sync/atomic"
+	"time"
 
 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/format"
+	"golang.org/x/sync/errgroup"
 )
 
+var blobUploadManager sync.Map
+
+type blobUpload struct {
+	*Layer
+
+	Total     int64
+	Completed atomic.Int64
+
+	Parts []blobUploadPart
+
+	nextURL chan *url.URL
+
+	context.CancelFunc
+
+	done       bool
+	err        error
+	references atomic.Int32
+}
+
+type blobUploadPart struct {
+	// N is the part number
+	N      int
+	Offset int64
+	Size   int64
+	hash.Hash
+}
+
 const (
-	redirectChunkSize int64 = 1024 * 1024 * 1024
-	regularChunkSize  int64 = 95 * 1024 * 1024
+	numUploadParts          = 64
+	minUploadPartSize int64 = 95 * 1000 * 1000
+	maxUploadPartSize int64 = 1000 * 1000 * 1000
 )
 
-func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, int64, error) {
-	requestURL := mp.BaseURL()
-	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
-	if layer.From != "" {
+func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
+	p, err := GetBlobsPath(b.Digest)
+	if err != nil {
+		return err
+	}
+
+	if b.From != "" {
 		values := requestURL.Query()
-		values.Add("mount", layer.Digest)
-		values.Add("from", layer.From)
+		values.Add("mount", b.Digest)
+		values.Add("from", b.From)
 		requestURL.RawQuery = values.Encode()
 	}
 
-	resp, err := makeRequestWithRetry(ctx, "POST", requestURL, nil, nil, regOpts)
+	resp, err := makeRequestWithRetry(ctx, http.MethodPost, requestURL, nil, nil, opts)
 	if err != nil {
-		log.Printf("couldn't start upload: %v", err)
-		return nil, 0, err
+		return err
 	}
 	defer resp.Body.Close()
 
 	location := resp.Header.Get("Docker-Upload-Location")
-	chunkSize := redirectChunkSize
 	if location == "" {
 		location = resp.Header.Get("Location")
-		chunkSize = regularChunkSize
 	}
 
-	locationURL, err := url.Parse(location)
+	fi, err := os.Stat(p)
 	if err != nil {
-		return nil, 0, err
+		return err
 	}
 
-	return locationURL, chunkSize, nil
+	b.Total = fi.Size()
+
+	var size = b.Total / numUploadParts
+	switch {
+	case size < minUploadPartSize:
+		size = minUploadPartSize
+	case size > maxUploadPartSize:
+		size = maxUploadPartSize
+	}
+
+	var offset int64
+	for offset < fi.Size() {
+		if offset+size > fi.Size() {
+			size = fi.Size() - offset
+		}
+
+		// set part.N to the current number of parts
+		b.Parts = append(b.Parts, blobUploadPart{N: len(b.Parts), Offset: offset, Size: size, Hash: md5.New()})
+		offset += size
+	}
+
+	log.Printf("uploading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size))
+
+	requestURL, err = url.Parse(location)
+	if err != nil {
+		return err
+	}
+
+	b.nextURL = make(chan *url.URL, 1)
+	b.nextURL <- requestURL
+	return nil
 }
 
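
Prepare aims for numUploadParts (64) pieces and clamps each part to between 95 MB and 1 GB, so small blobs use fewer, minimum-size parts and very large blobs spill past 64 capped parts. A self-contained sketch of the sizing loop with two worked examples (constants copied from the hunk above; the helper name is illustrative):

    package main

    import "fmt"

    const (
    	numUploadParts          = 64
    	minUploadPartSize int64 = 95 * 1000 * 1000
    	maxUploadPartSize int64 = 1000 * 1000 * 1000
    )

    // partSizes mirrors the sizing loop in Prepare: target 64 parts,
    // clamp each part to [95 MB, 1 GB], let the last part run short.
    func partSizes(total int64) []int64 {
    	size := total / numUploadParts
    	switch {
    	case size < minUploadPartSize:
    		size = minUploadPartSize
    	case size > maxUploadPartSize:
    		size = maxUploadPartSize
    	}

    	var parts []int64
    	for offset := int64(0); offset < total; offset += size {
    		n := size
    		if offset+n > total {
    			n = total - offset
    		}
    		parts = append(parts, n)
    	}
    	return parts
    }

    func main() {
    	fmt.Println(len(partSizes(3_000_000_000)))   // 3 GB -> 32 parts of at most 95 MB
    	fmt.Println(len(partSizes(200_000_000_000))) // 200 GB -> 200 parts of 1 GB
    }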
-func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSize int64, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
-	// TODO allow resumability
-	// TODO allow canceling uploads via DELETE
-	fp, err := GetBlobsPath(layer.Digest)
+// Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded
+// in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error.
+func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
+	defer blobUploadManager.Delete(b.Digest)
+	ctx, b.CancelFunc = context.WithCancel(ctx)
+
+	p, err := GetBlobsPath(b.Digest)
 	if err != nil {
-		return err
+		b.err = err
+		return
 	}
 
-	f, err := os.Open(fp)
+	f, err := os.Open(p)
 	if err != nil {
-		return err
+		b.err = err
+		return
 	}
 	defer f.Close()
 
-	pw := ProgressWriter{
-		status: fmt.Sprintf("uploading %s", layer.Digest),
-		digest: layer.Digest,
-		total:  layer.Size,
-		fn:     fn,
-	}
-
-	for offset := int64(0); offset < layer.Size; {
-		chunk := layer.Size - offset
-		if chunk > chunkSize {
-			chunk = chunkSize
-		}
-
-		resp, err := uploadBlobChunk(ctx, http.MethodPatch, requestURL, f, offset, chunk, regOpts, &pw)
-		if err != nil {
-			fn(api.ProgressResponse{
-				Status:    fmt.Sprintf("error uploading chunk: %v", err),
-				Digest:    layer.Digest,
-				Total:     layer.Size,
-				Completed: offset,
-			})
-
-			return err
-		}
-
-		offset += chunk
-		location := resp.Header.Get("Docker-Upload-Location")
-		if location == "" {
-			location = resp.Header.Get("Location")
-		}
-
-		requestURL, err = url.Parse(location)
-		if err != nil {
-			return err
-		}
-	}
+	g, inner := errgroup.WithContext(ctx)
+	g.SetLimit(numUploadParts)
+	for i := range b.Parts {
+		part := &b.Parts[i]
+		select {
+		case <-inner.Done():
+		case requestURL := <-b.nextURL:
+			g.Go(func() error {
+				for try := 0; try < maxRetries; try++ {
+					r := io.NewSectionReader(f, part.Offset, part.Size)
+					err := b.uploadChunk(inner, http.MethodPatch, requestURL, r, part, opts)
+					switch {
+					case errors.Is(err, context.Canceled):
+						return err
+					case errors.Is(err, errMaxRetriesExceeded):
+						return err
+					case err != nil:
+						log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], part.N, try, err)
+						continue
+					}
+
+					return nil
+				}
+
+				return errMaxRetriesExceeded
+			})
+		}
+	}
+
+	if err := g.Wait(); err != nil {
+		b.err = err
+		return
+	}
+
+	requestURL := <-b.nextURL
+
+	var sb strings.Builder
+	for _, part := range b.Parts {
+		sb.Write(part.Sum(nil))
+	}
+
+	md5sum := md5.Sum([]byte(sb.String()))
 
 	values := requestURL.Query()
-	values.Add("digest", layer.Digest)
+	values.Add("digest", b.Digest)
+	values.Add("etag", fmt.Sprintf("%x-%d", md5sum, len(b.Parts)))
 	requestURL.RawQuery = values.Encode()
 
 	headers := make(http.Header)
 	headers.Set("Content-Type", "application/octet-stream")
 	headers.Set("Content-Length", "0")
 
-	// finish the upload
-	resp, err := makeRequest(ctx, "PUT", requestURL, headers, nil, regOpts)
+	resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
 	if err != nil {
-		log.Printf("couldn't finish upload: %v", err)
-		return err
+		b.err = err
+		return
 	}
 	defer resp.Body.Close()
 
-	if resp.StatusCode >= http.StatusBadRequest {
-		body, _ := io.ReadAll(resp.Body)
-		return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
-	}
-	return nil
-}
+	b.done = true
+}
 
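
The completion PUT above carries an etag computed the way S3 builds multipart ETags: concatenate the per-part MD5 digests, MD5 the concatenation, and append -<part count>. The computation in isolation (the helper name is illustrative; in Run the per-part digests accumulate incrementally through part.Hash):

    package main

    import (
    	"crypto/md5"
    	"fmt"
    )

    func multipartEtag(parts [][]byte) string {
    	h := md5.New()
    	for _, part := range parts {
    		sum := md5.Sum(part) // per-part digest
    		h.Write(sum[:])      // digest of the concatenated digests
    	}
    	return fmt.Sprintf("%x-%d", h.Sum(nil), len(parts))
    }

    func main() {
    	fmt.Println(multipartEtag([][]byte{[]byte("part one"), []byte("part two")}))
    }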
-func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r io.ReaderAt, offset, limit int64, opts *RegistryOptions, pw *ProgressWriter) (*http.Response, error) {
-	sectionReader := io.NewSectionReader(r, offset, limit)
-
-	headers := make(http.Header)
-	headers.Set("Content-Type", "application/octet-stream")
-	headers.Set("Content-Length", strconv.Itoa(int(limit)))
-	headers.Set("X-Redirect-Uploads", "1")
-
-	if method == http.MethodPatch {
-		headers.Set("Content-Range", fmt.Sprintf("%d-%d", offset, offset+sectionReader.Size()-1))
-	}
-
-	for try := 0; try < maxRetries; try++ {
-		resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sectionReader, pw), opts)
-		if err != nil && !errors.Is(err, io.EOF) {
-			return nil, err
-		}
-		defer resp.Body.Close()
-
-		switch {
-		case resp.StatusCode == http.StatusTemporaryRedirect:
-			location, err := resp.Location()
-			if err != nil {
-				return nil, err
-			}
-
-			pw.completed = offset
-			if _, err := uploadBlobChunk(ctx, http.MethodPut, location, r, offset, limit, nil, pw); err != nil {
-				// retry
-				log.Printf("retrying redirected upload: %v", err)
-				continue
-			}
-
-			return resp, nil
-		case resp.StatusCode == http.StatusUnauthorized:
-			auth := resp.Header.Get("www-authenticate")
-			authRedir := ParseAuthRedirectString(auth)
-			token, err := getAuthToken(ctx, authRedir)
-			if err != nil {
-				return nil, err
-			}
-
-			opts.Token = token
-
-			pw.completed = offset
-			sectionReader = io.NewSectionReader(r, offset, limit)
-			continue
-		case resp.StatusCode >= http.StatusBadRequest:
-			body, _ := io.ReadAll(resp.Body)
-			return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
-		}
-
-		return resp, nil
-	}
-
-	return nil, fmt.Errorf("max retries exceeded")
-}
+func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL *url.URL, rs io.ReadSeeker, part *blobUploadPart, opts *RegistryOptions) error {
+	headers := make(http.Header)
+	headers.Set("Content-Type", "application/octet-stream")
+	headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
+	headers.Set("X-Redirect-Uploads", "1")
+
+	if method == http.MethodPatch {
+		headers.Set("Content-Range", fmt.Sprintf("%d-%d", part.Offset, part.Offset+part.Size-1))
+	}
+
+	buw := blobUploadWriter{blobUpload: b}
+	resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(rs, io.MultiWriter(&buw, part.Hash)), opts)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	location := resp.Header.Get("Docker-Upload-Location")
+	if location == "" {
+		location = resp.Header.Get("Location")
+	}
+
+	nextURL, err := url.Parse(location)
+	if err != nil {
+		return err
+	}
+
+	switch {
+	case resp.StatusCode == http.StatusTemporaryRedirect:
+		b.nextURL <- nextURL
+
+		redirectURL, err := resp.Location()
+		if err != nil {
+			return err
+		}
+
+		for try := 0; try < maxRetries; try++ {
+			rs.Seek(0, io.SeekStart)
+			b.Completed.Add(-buw.written)
+			buw.written = 0
+			part.Hash = md5.New()
+			err := b.uploadChunk(ctx, http.MethodPut, redirectURL, rs, part, nil)
+			switch {
+			case errors.Is(err, context.Canceled):
+				return err
+			case errors.Is(err, errMaxRetriesExceeded):
+				return err
+			case err != nil:
+				log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], part.N, try, err)
+				continue
+			}
+
+			return nil
+		}
+
+		return errMaxRetriesExceeded
+
+	case resp.StatusCode == http.StatusUnauthorized:
+		auth := resp.Header.Get("www-authenticate")
+		authRedir := ParseAuthRedirectString(auth)
+		token, err := getAuthToken(ctx, authRedir)
+		if err != nil {
+			return err
+		}
+
+		opts.Token = token
+		fallthrough
+	case resp.StatusCode >= http.StatusBadRequest:
+		body, err := io.ReadAll(resp.Body)
+		if err != nil {
+			return err
+		}
+
+		rs.Seek(0, io.SeekStart)
+		b.Completed.Add(-buw.written)
+		buw.written = 0
+		return fmt.Errorf("http status %d %s: %s", resp.StatusCode, resp.Status, body)
+	}
+
+	if method == http.MethodPatch {
+		b.nextURL <- nextURL
+	}
+
+	return nil
+}
 
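
Each chunk body goes through io.TeeReader(rs, io.MultiWriter(&buw, part.Hash)), so a single pass over the section both advances the shared Completed counter and feeds the per-part MD5; on retry, the writer's count is subtracted back out and the hash is reset. The tee-into-multiwriter idea in isolation (all names here are illustrative):

    package main

    import (
    	"bytes"
    	"crypto/md5"
    	"fmt"
    	"io"
    	"sync/atomic"
    )

    type counter struct{ n atomic.Int64 }

    func (c *counter) Write(p []byte) (int, error) {
    	c.n.Add(int64(len(p)))
    	return len(p), nil
    }

    func main() {
    	src := bytes.NewReader([]byte("chunk payload"))
    	var done counter
    	sum := md5.New()

    	// every byte the consumer reads is also counted and hashed
    	r := io.TeeReader(src, io.MultiWriter(&done, sum))
    	io.Copy(io.Discard, r) // stands in for sending the HTTP request body

    	fmt.Println(done.n.Load(), fmt.Sprintf("%x", sum.Sum(nil)))
    }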
-type ProgressWriter struct {
-	status    string
-	digest    string
-	bucket    int64
-	completed int64
-	total     int64
-	fn        func(api.ProgressResponse)
-	mu        sync.Mutex
-}
-
-func (pw *ProgressWriter) Write(b []byte) (int, error) {
-	pw.mu.Lock()
-	defer pw.mu.Unlock()
-
-	n := len(b)
-	pw.bucket += int64(n)
-
-	// throttle status updates to not spam the client
-	if pw.bucket >= 1024*1024 || pw.completed+pw.bucket >= pw.total {
-		pw.completed += pw.bucket
-		pw.fn(api.ProgressResponse{
-			Status:    pw.status,
-			Digest:    pw.digest,
-			Total:     pw.total,
-			Completed: pw.completed,
-		})
-
-		pw.bucket = 0
+func (b *blobUpload) acquire() {
+	b.references.Add(1)
+}
+
+func (b *blobUpload) release() {
+	if b.references.Add(-1) == 0 {
+		b.CancelFunc()
+	}
+}
+
+func (b *blobUpload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
+	b.acquire()
+	defer b.release()
+
+	ticker := time.NewTicker(60 * time.Millisecond)
+	for {
+		select {
+		case <-ticker.C:
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+
+		fn(api.ProgressResponse{
+			Status:    fmt.Sprintf("uploading %s", b.Digest),
+			Digest:    b.Digest,
+			Total:     b.Total,
+			Completed: b.Completed.Load(),
+		})
+
+		if b.done || b.err != nil {
+			return b.err
+		}
 	}
+}
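
acquire and release give the background Run a reference count: every Wait holds a reference, and when the last waiter leaves, CancelFunc tears the upload down. An upload therefore outlives one canceled request only while another caller is still watching it. The pattern in isolation (names illustrative):

    package main

    import (
    	"context"
    	"fmt"
    	"sync/atomic"
    	"time"
    )

    type job struct {
    	refs   atomic.Int32
    	cancel context.CancelFunc
    }

    func (j *job) acquire() { j.refs.Add(1) }

    func (j *job) release() {
    	if j.refs.Add(-1) == 0 {
    		j.cancel() // last watcher gone: stop the background work
    	}
    }

    func main() {
    	ctx, cancel := context.WithCancel(context.Background())
    	j := &job{cancel: cancel}

    	j.acquire() // first waiter
    	j.acquire() // second waiter
    	j.release() // first leaves: work continues
    	j.release() // last leaves: ctx is canceled

    	select {
    	case <-ctx.Done():
    		fmt.Println("canceled:", ctx.Err())
    	case <-time.After(time.Second):
    		fmt.Println("still running")
    	}
    }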
+
+type blobUploadWriter struct {
+	written int64
+	*blobUpload
+}
+
+func (b *blobUploadWriter) Write(p []byte) (n int, err error) {
+	n = len(p)
+	b.written += int64(n)
+	b.Completed.Add(int64(n))
 	return n, nil
 }
+
+func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryOptions, fn func(api.ProgressResponse)) error {
+	requestURL := mp.BaseURL()
+	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
+
+	resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
+	switch {
+	case errors.Is(err, os.ErrNotExist):
+	case err != nil:
+		return err
+	default:
+		defer resp.Body.Close()
+		fn(api.ProgressResponse{
+			Status:    fmt.Sprintf("uploading %s", layer.Digest),
+			Digest:    layer.Digest,
+			Total:     layer.Size,
+			Completed: layer.Size,
+		})
+
+		return nil
+	}
+
+	data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
+	upload := data.(*blobUpload)
+	if !ok {
+		requestURL := mp.BaseURL()
+		requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
+		if err := upload.Prepare(ctx, requestURL, opts); err != nil {
+			blobUploadManager.Delete(layer.Digest)
+			return err
+		}
+
+		go upload.Run(context.Background(), opts)
+	}
+
+	return upload.Wait(ctx, fn)
+}
 
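
uploadBlob funnels concurrent pushes of the same digest through blobUploadManager: LoadOrStore makes exactly one caller the owner that Prepares and Runs the upload, while every other caller Waits on the same shared record. The dedup pattern in isolation (names illustrative):

    package main

    import (
    	"fmt"
    	"sync"
    )

    var manager sync.Map // digest -> *upload, mirroring blobUploadManager

    type upload struct{ done chan struct{} }

    func (u *upload) run(digest string) {
    	fmt.Println("uploading", digest) // runs exactly once per digest
    	close(u.done)
    }

    func (u *upload) wait() { <-u.done }

    // start deduplicates concurrent uploads of the same digest: only the
    // goroutine whose LoadOrStore actually stored the record starts the work.
    func start(digest string) {
    	data, loaded := manager.LoadOrStore(digest, &upload{done: make(chan struct{})})
    	u := data.(*upload)
    	if !loaded {
    		go u.run(digest)
    	}
    	u.wait() // every caller waits on the same shared upload
    }

    func main() {
    	var wg sync.WaitGroup
    	for i := 0; i < 3; i++ {
    		wg.Add(1)
    		go func() { defer wg.Done(); start("sha256:abc") }()
    	}
    	wg.Wait()
    }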
@@ -1,69 +0,0 @@
-package vector
-
-import (
-	"container/heap"
-	"sort"
-
-	"gonum.org/v1/gonum/mat"
-)
-
-type Embedding struct {
-	Vector []float64 // the embedding vector
-	Data   string    // the data represted by the embedding
-}
-
-type EmbeddingSimilarity struct {
-	Embedding  Embedding // the embedding that was used to calculate the similarity
-	Similarity float64   // the similarity between the embedding and the query
-}
-
-type Heap []EmbeddingSimilarity
-
-func (h Heap) Len() int           { return len(h) }
-func (h Heap) Less(i, j int) bool { return h[i].Similarity < h[j].Similarity }
-func (h Heap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
-func (h *Heap) Push(e any) {
-	*h = append(*h, e.(EmbeddingSimilarity))
-}
-
-func (h *Heap) Pop() interface{} {
-	old := *h
-	n := len(old)
-	x := old[n-1]
-	*h = old[0 : n-1]
-	return x
-}
-
-// cosineSimilarity is a measure that calculates the cosine of the angle between two vectors.
-// This value will range from -1 to 1, where 1 means the vectors are identical.
-func cosineSimilarity(vec1, vec2 *mat.VecDense) float64 {
-	dotProduct := mat.Dot(vec1, vec2)
-	norms := mat.Norm(vec1, 2) * mat.Norm(vec2, 2)
-
-	if norms == 0 {
-		return 0
-	}
-	return dotProduct / norms
-}
-
-func TopK(k int, query *mat.VecDense, embeddings []Embedding) []EmbeddingSimilarity {
-	h := &Heap{}
-	heap.Init(h)
-	for _, emb := range embeddings {
-		similarity := cosineSimilarity(query, mat.NewVecDense(len(emb.Vector), emb.Vector))
-		heap.Push(h, EmbeddingSimilarity{Embedding: emb, Similarity: similarity})
-		if h.Len() > k {
-			heap.Pop(h)
-		}
-	}
-
-	topK := make([]EmbeddingSimilarity, 0, h.Len())
-	for h.Len() > 0 {
-		topK = append(topK, heap.Pop(h).(EmbeddingSimilarity))
-	}
-	sort.Slice(topK, func(i, j int) bool {
-		return topK[i].Similarity > topK[j].Similarity
-	})
-
-	return topK
-}