Compare commits

...

84 Commits

Author SHA1 Message Date
Matt Williams
0d4fa34aee Added mention of the NOPRUNE env var
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-12-08 17:37:12 -08:00
Jeffrey Morgan
b74580c913 Update api.md 2023-12-08 16:02:07 -08:00
Bruce MacDonald
7e9405fd07 fix: encode full previous prompt in context (#1424) 2023-12-08 16:53:51 -05:00
Bruce MacDonald
3b0b8930d4 fix: only flush template in chat when current role encountered (#1426) 2023-12-08 16:44:24 -05:00
Bruce MacDonald
e3f925fc1b fix: restore modelfile system in prompt template (#1425) 2023-12-08 14:20:19 -05:00
Jeffrey Morgan
2a2289fb6b Update api.md 2023-12-08 09:36:45 -08:00
Matt Williams
dd427f499a Merge pull request #1419 from jmorganca/mattw/typescript-simplechat
Simple chat example for typescript
2023-12-07 14:42:24 -08:00
Michael Yang
2ae573c7ed Merge pull request #1421 from jmorganca/mxyng/fix-newline
fix redundant newline
2023-12-07 13:47:23 -08:00
Matt Williams
02fe26c44b update the readme as per bruce
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-12-07 13:46:30 -08:00
Michael Yang
16c7548460 fix redundant newline 2023-12-07 13:44:45 -08:00
Matt Williams
fa75998c0d Update examples/typescript-simplechat/readme.md
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-07 13:40:54 -08:00
Matt Williams
5344f886c8 Update examples/typescript-simplechat/client.ts
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-07 13:40:37 -08:00
Matt Williams
6cc823c9b5 Update examples/typescript-simplechat/client.ts
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-07 13:39:59 -08:00
Matt Williams
b84d34e632 Update examples/typescript-simplechat/readme.md
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-07 13:39:33 -08:00
Matt Williams
30229a913c Update examples/typescript-simplechat/client.ts
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-07 13:39:24 -08:00
Matt Williams
1ade380bd7 Simple chat example for typescript
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-12-07 11:48:25 -08:00
Jeffrey Morgan
ba264e9da8 add future version note to chat api docs 2023-12-07 09:42:15 -08:00
Matt Williams
a2405ec831 Merge pull request #1409 from jmorganca/mattw/python-simplechat
Simple chat example
2023-12-06 15:49:45 -08:00
Matt Williams
ce809bb529 Merge branch 'mattw/python-simplechat' of github.com:jmorganca/ollama into mattw/python-simplechat 2023-12-06 15:48:42 -08:00
Matt Williams
76bc4d0458 Cleanup as per Bruce
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-12-06 15:44:40 -08:00
Bruce MacDonald
4a02945a15 Update examples/python-simplechat/client.py 2023-12-06 18:36:45 -05:00
Matt Williams
aec742b6d2 Update examples/python-simplechat/readme.md
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-06 15:30:45 -08:00
Matt Williams
f337642e94 Update examples/python-simplechat/readme.md
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-06 15:30:35 -08:00
Matt Williams
51131cc6e2 Update examples/python-simplechat/client.py
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-06 15:30:10 -08:00
Matt Williams
43027789dc Simple chat example
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-12-06 14:35:58 -08:00
Xe Iaso
f9b7d65e2b docs/tutorials: add bit on how to use Fly GPUs on-demand with Ollama (#1406)
Signed-off-by: Xe Iaso <xe@camellia.finch-kitefin.ts.net>
2023-12-06 14:14:02 -08:00
Michael Yang
1f05d77110 Merge pull request #1244 from jmorganca/brucemacd/no-fail-template
do not fail on unsupported template variables
2023-12-06 13:23:04 -08:00
Michael Yang
c3ff36088b Merge pull request #774 from jmorganca/mxyng/server-version
add version api and show server version in cli
2023-12-06 13:22:55 -08:00
Samuel Calderon
13524b5e72 List "Send chat messages" in table of contents (#1399)
Thank you @calderonsamuel
2023-12-06 12:34:27 -08:00
Michael Yang
f1b049fed8 Merge pull request #1377 from jmorganca/mxyng/qwen
update for qwen
2023-12-06 12:31:51 -08:00
Jeffrey Morgan
97c5696945 fix base urls in chat examples 2023-12-06 12:10:20 -08:00
Bruce MacDonald
47d4e22673 use missingkey in set empty interface when missing 2023-12-05 15:49:05 -08:00
Michael Yang
32f62fbb8e Merge pull request #1334 from jmorganca/mxyng/load-projectors
load projectors
2023-12-05 14:40:53 -08:00
Michael Yang
5d75505ebd return model configuration in generate 2023-12-05 14:39:02 -08:00
Michael Yang
b9495ea162 load projectors 2023-12-05 14:36:12 -08:00
Michael Yang
409bb9674e Merge pull request #1308 from jmorganca/mxyng/split-from
split from into one or more models
2023-12-05 14:33:03 -08:00
Michael Yang
d3479c07a1 Merge pull request #1250 from jmorganca/mxyng/create-layer
refactor layer creation
2023-12-05 14:32:52 -08:00
Michael Yang
b12f1b984f Merge pull request #1393 from jmorganca/mxyng/fix-whitespace
fix: trim space in modelfile fields
2023-12-05 12:18:01 -08:00
Bruce MacDonald
195e3d9dbd chat api endpoint (#1392) 2023-12-05 14:57:33 -05:00
Michael Yang
38fe1a368b fix: trim space in modelfile fields 2023-12-05 11:57:29 -08:00
Michael Yang
4b77fcb2b9 comments 2023-12-05 09:43:50 -08:00
Michael Yang
cde13bcdea cmd: only print server version when different 2023-12-05 09:36:01 -08:00
Michael Yang
0f0cd265a7 cmd: add server version 2023-12-05 09:36:01 -08:00
Michael Yang
0db4706ec2 api: add version api handler 2023-12-05 09:36:01 -08:00
Michael Yang
1ebdbd9694 server: add version handler 2023-12-05 09:36:01 -08:00
Michael Yang
5c59455b59 cmd: use existing cmd context 2023-12-05 09:36:01 -08:00
Jeffrey Morgan
00d06619a1 Revert "chat api (#991)" while context variable is fixed
This reverts commit 7a0899d62d.
2023-12-04 21:16:27 -08:00
Matt Williams
f1ef3f9947 remove mention of gpt-neox in import (#1381)
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-12-04 20:58:10 -08:00
Michael Yang
5a5dca13b2 comments 2023-12-04 16:59:23 -08:00
Michael Yang
7232f1fa41 go mod tidy 2023-12-04 16:59:23 -08:00
Michael Yang
72e7a49aa9 seek instead of copyn 2023-12-04 16:59:23 -08:00
Michael Yang
a3737cbd33 use NewLayer for CreateBlobHandler 2023-12-04 16:59:23 -08:00
Michael Yang
998f1785b6 add modelfamilies 2023-12-04 16:59:23 -08:00
Michael Yang
70a93057cd refactor layer creation
previous layer creation was not ideal because:

1. it required reading the input file multiple times, once to calculate
   the sha256 checksum, another to write it to disk, and potentially one
   more to decode the underlying gguf
2. it used io.ReadSeeker, which is prone to user error: if the file isn't
   reset correctly or to the right place, it could end up reading an
   empty file

there is also some brittleness when reading existing layers, since
writing the inherited layers can fail by reading an already closed file

this commit aims to fix these issues by restructuring layer creation.

1. it will now write the layer to a temporary file as well as the hash
   function and move it to the final location on Commit
2. layers are now read only once when copied to the destination; the
   exception is raw model files, which still require a second read to
   decode the model metadata
2023-12-04 16:59:23 -08:00
Michael Yang
2cb0fa7d40 split from into one or more models 2023-12-04 16:59:23 -08:00
Michael Yang
b2816bca67 unnecessary ReadSeeker for DecodeGGML 2023-12-04 16:59:23 -08:00
Patrick Devine
bf704423c5 revert cli to use /api/generate (#1383) 2023-12-04 16:35:29 -08:00
Bruce MacDonald
7a0899d62d chat api (#991)
- update chat docs
- add messages chat endpoint
- remove deprecated context and template generate parameters from docs
- context and template are still supported for the time being and will continue to work as expected
- add partial response to chat history
2023-12-04 18:01:06 -05:00
Michael Yang
0cca1486dd Merge pull request #1376 from jmorganca/mxyng/rocky-install
install: fix rocky kernel packages
2023-12-04 14:23:43 -08:00
Patrick Devine
2113c9d31a make linewrap still work when the terminal width has changed (#1350) 2023-12-04 14:14:56 -08:00
Michael Yang
6deebf2489 update for qwen 2023-12-04 11:38:05 -08:00
Michael Yang
95cb38ae47 install: fix rocky kernel packages 2023-12-04 11:10:42 -08:00
ruecat
1f126afb2d Ollama Telegram Bot (#1364)
* Add "ollama-telegram" to Extensions & Plugins

* Update README.md
2023-12-03 11:19:55 -08:00
Jeffrey Morgan
f6201a7a6c remove duplicate community integration in README.md 2023-12-02 21:18:13 -08:00
Michael Yang
b3f6c6598f Merge pull request #1349 from jmorganca/mxyng/ctrl-z
handle ctrl+z
2023-12-01 16:21:49 -08:00
Michael Yang
88620e983a handle ctrl+z 2023-12-01 16:15:20 -08:00
Michael Yang
cedae0d17a Merge pull request #1347 from jshph/adapter-hash
Fix adapter loading from SHA hash
2023-12-01 11:08:25 -08:00
Joshua Pham
bb80a597db Fix adapter loading from SHA hash 2023-12-01 13:50:55 -05:00
Patrick Devine
6681d37861 allow setting the system and template for prompts in the repl (#1335) 2023-12-01 09:28:35 -08:00
Michael Yang
0409c1fa59 docker: set PATH, LD_LIBRARY_PATH, and capabilities (#1336)
* docker: set PATH, LD_LIBRARY_PATH, and capabilities

* example: update k8s gpu manifest
2023-11-30 21:16:56 -08:00
Michael Yang
b56e92470a Merge pull request #1229 from jmorganca/mxyng/calculate-as-you-go
revert checksum calculation to calculate-as-you-go
2023-11-30 10:54:38 -08:00
Jeffrey Morgan
5687f1a0cf fix unexpected end of response errors when cancelling in ollama run 2023-11-30 00:30:21 -05:00
James Radtke
7eda3d0c55 Corrected transposed 129 to 192 for OLLAMA_ORIGINS example (#1325) 2023-11-29 22:44:17 -05:00
Bruce MacDonald
7194a07d4d Add chatd to example projects 2023-11-29 21:18:21 -05:00
Michael Yang
13efd5f218 upload: fix PUT retry 2023-11-29 16:38:35 -08:00
Michael Yang
c4bdfffd96 upload: separate progress tracking 2023-11-29 16:38:33 -08:00
Michael Yang
26c63418e0 new hasher 2023-11-29 14:52:41 -08:00
Michael Yang
2799784ac8 revert checksum calculation to calculate-as-you-go 2023-11-29 13:47:58 -08:00
Alec Hammond
91897a606f Add OllamaEmbeddings to python LangChain example (#994)
* Add OllamaEmbeddings to python LangChain example

* typo

---------

Co-authored-by: Alec Hammond <alechammond@fb.com>
2023-11-29 16:25:39 -05:00
Bruce MacDonald
96122b7271 validate model tags on copy (#1323) 2023-11-29 15:54:29 -05:00
jeremiahbuckley
39be7fdb98 fix rhel cuda install (#1321)
Co-authored-by: Cloud User <azureuser@testgpu2.hqzwom21okjenksna4y3c4ymjd.phxx.internal.cloudapp.net>
2023-11-29 14:55:15 -05:00
Timothy Jaeryang Baek
c2e3b89176 fix: disable ':' in tag names (#1280)
Co-authored-by: rootedbox
2023-11-29 13:33:45 -05:00
Patrick Devine
cde31cb220 Allow setting parameters in the REPL (#1294) 2023-11-29 09:56:42 -08:00
ToasterUwU
63097607b2 Correct MacOS Host port example (#1301) 2023-11-29 11:44:03 -05:00
33 changed files with 1811 additions and 639 deletions

View File

@@ -19,5 +19,11 @@ RUN apt-get update && apt-get install -y ca-certificates
COPY --from=0 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
EXPOSE 11434
ENV OLLAMA_HOST 0.0.0.0
# set some environment variables for better NVIDIA compatibility
ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
ENTRYPOINT ["/bin/ollama"]
CMD ["serve"]

View File

@@ -214,14 +214,21 @@ curl http://localhost:11434/api/generate -d '{
}'
```
Or send a chat message (coming in 0.1.14):
```
curl http://localhost:11434/api/chat -d '{
"model": "mistral",
"messages": [
{ "role": "user", "content": "why is the sky blue?" }
]
}'
```
See the [API documentation](./docs/api.md) for all endpoints.
## Community Integrations
### Mobile
- [Mobile Artificial Intelligence Distribution](https://github.com/MaidFoundation/Maid) (Maid)
### Web & Desktop
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
@@ -233,6 +240,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [big-AGI](https://github.com/enricoros/big-agi/blob/main/docs/config-ollama.md)
- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
- [Amica](https://github.com/semperai/amica)
- [chatd](https://github.com/BruceMacD/chatd)
### Terminal
@@ -276,6 +284,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Logseq Ollama plugin](https://github.com/omagdy7/ollama-logseq)
- [Dagger Chatbot](https://github.com/samalba/dagger-chatbot)
- [Discord AI Bot](https://github.com/mekb-turtle/discord-ai-bot)
- [Ollama Telegram Bot](https://github.com/ruecat/ollama-telegram)
- [Hass Ollama Conversation](https://github.com/ej52/hass-ollama-conversation)
- [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama)
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)

View File

@@ -221,6 +221,19 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate
})
}
type ChatResponseFunc func(ChatResponse) error
func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error {
var resp ChatResponse
if err := json.Unmarshal(bts, &resp); err != nil {
return err
}
return fn(resp)
})
}
type PullProgressFunc func(ProgressResponse) error
func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
@@ -311,3 +324,15 @@ func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) err
return nil
}
func (c *Client) Version(ctx context.Context) (string, error) {
var version struct {
Version string `json:"version"`
}
if err := c.do(ctx, http.MethodGet, "/api/version", nil, &version); err != nil {
return "", err
}
return version.Version, nil
}
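
For context, here is a minimal sketch of how a caller might use the new streaming `Chat` method shown above. The model name and message content are placeholders, and the module path is assumed from the repository layout; treat it as an illustration rather than the project's own example code.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/jmorganca/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.ChatRequest{
		Model: "llama2", // placeholder model name
		Messages: []api.Message{
			{Role: "user", Content: "why is the sky blue?"},
		},
	}

	// the callback runs once per streamed chunk; the final chunk has Done
	// set and carries the Metrics fields
	fn := func(resp api.ChatResponse) error {
		if resp.Message != nil {
			fmt.Print(resp.Message.Content)
		}
		return nil
	}

	if err := client.Chat(context.Background(), req, fn); err != nil {
		log.Fatal(err)
	}
	fmt.Println()
}
```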

View File

@@ -6,6 +6,7 @@ import (
"math"
"os"
"reflect"
"strconv"
"strings"
"time"
)
@@ -43,6 +44,39 @@ type GenerateRequest struct {
Options map[string]interface{} `json:"options"`
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream *bool `json:"stream,omitempty"`
Format string `json:"format"`
Options map[string]interface{} `json:"options"`
}
type Message struct {
Role string `json:"role"` // one of ["system", "user", "assistant"]
Content string `json:"content"`
}
type ChatResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Message *Message `json:"message,omitempty"`
Done bool `json:"done"`
Metrics
}
type Metrics struct {
TotalDuration time.Duration `json:"total_duration,omitempty"`
LoadDuration time.Duration `json:"load_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration time.Duration `json:"eval_duration,omitempty"`
}
// Options specified in GenerateRequest, if you add a new option here add it to the API docs also
type Options struct {
Runner
@@ -169,42 +203,47 @@ type GenerateResponse struct {
CreatedAt time.Time `json:"created_at"`
Response string `json:"response"`
ModelConfiguration ModelConfiguration `json:"model_configuration"`
Done bool `json:"done"`
Context []int `json:"context,omitempty"`
TotalDuration time.Duration `json:"total_duration,omitempty"`
LoadDuration time.Duration `json:"load_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration time.Duration `json:"eval_duration,omitempty"`
Metrics
}
func (r *GenerateResponse) Summary() {
if r.TotalDuration > 0 {
fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration)
type ModelConfiguration struct {
ModelFormat string `json:"model_format"`
ModelFamily string `json:"model_family"`
ModelFamilies []string `json:"model_families"`
ModelType string `json:"model_type"`
FileType string `json:"file_type"`
}
func (m *Metrics) Summary() {
if m.TotalDuration > 0 {
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
}
if r.LoadDuration > 0 {
fmt.Fprintf(os.Stderr, "load duration: %v\n", r.LoadDuration)
if m.LoadDuration > 0 {
fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
}
if r.PromptEvalCount > 0 {
fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount)
if m.PromptEvalCount > 0 {
fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", m.PromptEvalCount)
}
if r.PromptEvalDuration > 0 {
fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", r.PromptEvalDuration)
fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(r.PromptEvalCount)/r.PromptEvalDuration.Seconds())
if m.PromptEvalDuration > 0 {
fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", m.PromptEvalDuration)
fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(m.PromptEvalCount)/m.PromptEvalDuration.Seconds())
}
if r.EvalCount > 0 {
fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", r.EvalCount)
if m.EvalCount > 0 {
fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", m.EvalCount)
}
if r.EvalDuration > 0 {
fmt.Fprintf(os.Stderr, "eval duration: %s\n", r.EvalDuration)
fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(r.EvalCount)/r.EvalDuration.Seconds())
if m.EvalDuration > 0 {
fmt.Fprintf(os.Stderr, "eval duration: %s\n", m.EvalDuration)
fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(m.EvalCount)/m.EvalDuration.Seconds())
}
}
@@ -360,3 +399,63 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
return nil
}
// FormatParams converts specified parameter options to their correct types
func FormatParams(params map[string][]string) (map[string]interface{}, error) {
opts := Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
// build map of json struct tags to their types
jsonOpts := make(map[string]reflect.StructField)
for _, field := range reflect.VisibleFields(typeOpts) {
jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
if jsonTag != "" {
jsonOpts[jsonTag] = field
}
}
out := make(map[string]interface{})
// iterate params and set values based on json struct tags
for key, vals := range params {
if opt, ok := jsonOpts[key]; !ok {
return nil, fmt.Errorf("unknown parameter '%s'", key)
} else {
field := valueOpts.FieldByName(opt.Name)
if field.IsValid() && field.CanSet() {
switch field.Kind() {
case reflect.Float32:
floatVal, err := strconv.ParseFloat(vals[0], 32)
if err != nil {
return nil, fmt.Errorf("invalid float value %s", vals)
}
out[key] = float32(floatVal)
case reflect.Int:
intVal, err := strconv.ParseInt(vals[0], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid int value %s", vals)
}
out[key] = intVal
case reflect.Bool:
boolVal, err := strconv.ParseBool(vals[0])
if err != nil {
return nil, fmt.Errorf("invalid bool value %s", vals)
}
out[key] = boolVal
case reflect.String:
out[key] = vals[0]
case reflect.Slice:
// TODO: only string slices are supported right now
out[key] = vals
default:
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
}
}
}
}
return out, nil
}
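
As a rough illustration of what `FormatParams` does with string-valued inputs (for example the values collected by `/set parameter` in the REPL), here is a hedged sketch; the parameter names come from the `Options` json tags listed in the REPL help text, but the program itself is illustrative only.

```go
package main

import (
	"fmt"
	"log"

	"github.com/jmorganca/ollama/api"
)

func main() {
	// raw string values, as they might arrive from the command line or REPL
	params := map[string][]string{
		"temperature": {"0.7"},
		"num_ctx":     {"4096"},
		"stop":        {"<|end|>", "</s>"},
	}

	// FormatParams matches each key against the json tags on api.Options and
	// parses the value into that field's type
	opts, err := api.FormatParams(params)
	if err != nil {
		log.Fatal(err)
	}

	// temperature becomes a float32, num_ctx an int64, and stop stays a []string
	fmt.Println(opts)
}
```

The returned map can then be merged into the request options, which is what the `/set parameter` handler in `cmd` does with `opts.Options[args[2]] = fp[args[2]]`.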

View File

@@ -133,7 +133,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile)}
if err := client.Create(context.Background(), &request, fn); err != nil {
if err := client.Create(cmd.Context(), &request, fn); err != nil {
return err
}
@@ -148,7 +148,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
name := args[0]
// check if the model exists on the server
_, err = client.Show(context.Background(), &api.ShowRequest{Name: name})
_, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
var statusError api.StatusError
switch {
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
@@ -208,7 +208,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
}
request := api.PushRequest{Name: args[0], Insecure: insecure}
if err := client.Push(context.Background(), &request, fn); err != nil {
if err := client.Push(cmd.Context(), &request, fn); err != nil {
return err
}
@@ -222,7 +222,7 @@ func ListHandler(cmd *cobra.Command, args []string) error {
return err
}
models, err := client.List(context.Background())
models, err := client.List(cmd.Context())
if err != nil {
return err
}
@@ -257,7 +257,7 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
for _, name := range args {
req := api.DeleteRequest{Name: name}
if err := client.Delete(context.Background(), &req); err != nil {
if err := client.Delete(cmd.Context(), &req); err != nil {
return err
}
fmt.Printf("deleted '%s'\n", name)
@@ -322,7 +322,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
}
req := api.ShowRequest{Name: args[0]}
resp, err := client.Show(context.Background(), &req)
resp, err := client.Show(cmd.Context(), &req)
if err != nil {
return err
}
@@ -350,7 +350,7 @@ func CopyHandler(cmd *cobra.Command, args []string) error {
}
req := api.CopyRequest{Source: args[0], Destination: args[1]}
if err := client.Copy(context.Background(), &req); err != nil {
if err := client.Copy(cmd.Context(), &req); err != nil {
return err
}
fmt.Printf("copied '%s' to '%s'\n", args[0], args[1])
@@ -404,7 +404,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
}
request := api.PullRequest{Name: args[0], Insecure: insecure}
if err := client.Pull(context.Background(), &request, fn); err != nil {
if err := client.Pull(cmd.Context(), &request, fn); err != nil {
return err
}
@@ -412,10 +412,19 @@ func PullHandler(cmd *cobra.Command, args []string) error {
}
func RunGenerate(cmd *cobra.Command, args []string) error {
interactive := true
opts := generateOptions{
Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]interface{}{},
}
format, err := cmd.Flags().GetString("format")
if err != nil {
return err
}
opts.Format = format
prompts := args[1:]
@@ -427,34 +436,40 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
}
prompts = append([]string{string(in)}, prompts...)
opts.WordWrap = false
interactive = false
}
// output is being piped
if !term.IsTerminal(int(os.Stdout.Fd())) {
return generate(cmd, args[0], strings.Join(prompts, " "), false, format)
opts.Prompt = strings.Join(prompts, " ")
if len(prompts) > 0 {
interactive = false
}
wordWrap := os.Getenv("TERM") == "xterm-256color"
nowrap, err := cmd.Flags().GetBool("nowordwrap")
if err != nil {
return err
}
if nowrap {
wordWrap = false
opts.WordWrap = !nowrap
if !interactive {
return generate(cmd, opts)
}
// prompts are provided via stdin or args so don't enter interactive mode
if len(prompts) > 0 {
return generate(cmd, args[0], strings.Join(prompts, " "), wordWrap, format)
}
return generateInteractive(cmd, args[0], wordWrap, format)
return generateInteractive(cmd, opts)
}
type generateContextKey string
func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format string) error {
type generateOptions struct {
Model string
Prompt string
WordWrap bool
Format string
System string
Template string
Options map[string]interface{}
}
func generate(cmd *cobra.Command, opts generateOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
@@ -475,34 +490,39 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil {
wordWrap = false
opts.WordWrap = false
}
cancelCtx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT)
var abort bool
go func() {
<-sigChan
cancel()
abort = true
}()
var currentLineLength int
var wordBuffer string
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, Format: format}
fn := func(response api.GenerateResponse) error {
p.StopAndClear()
latest = response
if wordWrap {
termWidth, _, _ = term.GetSize(int(os.Stdout.Fd()))
if opts.WordWrap && termWidth >= 10 {
for _, ch := range response.Response {
if currentLineLength+1 > termWidth-5 {
if len(wordBuffer) > termWidth-10 {
fmt.Printf("%s%c", wordBuffer, ch)
wordBuffer = ""
currentLineLength = 0
continue
}
// backtrack the length of the last word and clear to the end of the line
fmt.Printf("\x1b[%dD\x1b[K\n", len(wordBuffer))
fmt.Printf("%s%c", wordBuffer, ch)
@@ -522,28 +542,38 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
}
}
} else {
fmt.Print(response.Response)
fmt.Printf("%s%s", wordBuffer, response.Response)
if len(wordBuffer) > 0 {
wordBuffer = ""
}
}
return nil
}
if err := client.Generate(cancelCtx, &request, fn); err != nil {
if strings.Contains(err.Error(), "context canceled") && abort {
request := api.GenerateRequest{
Model: opts.Model,
Prompt: opts.Prompt,
Context: generateContext,
Format: opts.Format,
System: opts.System,
Template: opts.Template,
Options: opts.Options,
}
if err := client.Generate(ctx, &request, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil
}
return err
}
if prompt != "" {
if opts.Prompt != "" {
fmt.Println()
fmt.Println()
}
if !latest.Done {
if abort {
return nil
}
return errors.New("unexpected end of response")
return nil
}
verbose, err := cmd.Flags().GetBool("verbose")
@@ -555,16 +585,26 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
latest.Summary()
}
ctx := cmd.Context()
ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
cmd.SetContext(ctx)
cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context))
return nil
}
func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format string) error {
type MultilineState int
const (
MultilineNone MultilineState = iota
MultilinePrompt
MultilineSystem
MultilineTemplate
)
func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
// load the model
if err := generate(cmd, model, "", false, ""); err != nil {
loadOpts := generateOptions{
Model: opts.Model,
Prompt: "",
}
if err := generate(cmd, loadOpts); err != nil {
return err
}
@@ -581,14 +621,17 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
usageSet := func() {
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set history Enable history")
fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
fmt.Fprintln(os.Stderr, " /set nowordwrap Disable wordwrap")
fmt.Fprintln(os.Stderr, " /set format json Enable JSON mode")
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter")
fmt.Fprintln(os.Stderr, " /set system <string> Set system prompt")
fmt.Fprintln(os.Stderr, " /set template <string> Set prompt template")
fmt.Fprintln(os.Stderr, " /set history Enable history")
fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
fmt.Fprintln(os.Stderr, " /set nowordwrap Disable wordwrap")
fmt.Fprintln(os.Stderr, " /set format json Enable JSON mode")
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
fmt.Fprintln(os.Stderr, "")
}
@@ -602,6 +645,22 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
fmt.Fprintln(os.Stderr, "")
}
// only list out the most common parameters
usageParameters := func() {
fmt.Fprintln(os.Stderr, "Available Parameters:")
fmt.Fprintln(os.Stderr, " /set parameter seed <int> Random number seed")
fmt.Fprintln(os.Stderr, " /set parameter num_predict <int> Max number of tokens to predict")
fmt.Fprintln(os.Stderr, " /set parameter top_k <int> Pick from top k num of tokens")
fmt.Fprintln(os.Stderr, " /set parameter top_p <float> Pick token based on sum of probabilities")
fmt.Fprintln(os.Stderr, " /set parameter num_ctx <int> Set the context size")
fmt.Fprintln(os.Stderr, " /set parameter temperature <float> Set creativity level")
fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty <float> How strongly to penalize repetitions")
fmt.Fprintln(os.Stderr, " /set parameter repeat_last_n <int> Set how far back to look for repetitions")
fmt.Fprintln(os.Stderr, " /set parameter num_gpu <int> The number of layers to send to the GPU")
fmt.Fprintln(os.Stderr, " /set parameter stop \"<string>\", ... Set the stop parameters")
fmt.Fprintln(os.Stderr, "")
}
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
AltPrompt: "... ",
@@ -615,6 +674,7 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
fmt.Print(readline.StartBracketedPaste)
defer fmt.Printf(readline.EndBracketedPaste)
var multiline MultilineState
var prompt string
for {
@@ -649,8 +709,21 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
prompt = strings.TrimPrefix(prompt, `"""`)
scanner.Prompt.UseAlt = false
switch multiline {
case MultilineSystem:
opts.System = prompt
prompt = ""
fmt.Println("Set system template.")
case MultilineTemplate:
opts.Template = prompt
prompt = ""
fmt.Println("Set model template.")
}
multiline = MultilineNone
case strings.HasPrefix(line, `"""`) && len(prompt) == 0:
scanner.Prompt.UseAlt = true
multiline = MultilinePrompt
prompt += line + "\n"
continue
case scanner.Pasting:
@@ -670,10 +743,10 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
case "nohistory":
scanner.HistoryDisable()
case "wordwrap":
wordWrap = true
opts.WordWrap = true
fmt.Println("Set 'wordwrap' mode.")
case "nowordwrap":
wordWrap = false
opts.WordWrap = false
fmt.Println("Set 'nowordwrap' mode.")
case "verbose":
cmd.Flags().Set("verbose", "true")
@@ -685,12 +758,59 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
} else {
format = args[2]
opts.Format = args[2]
fmt.Printf("Set format to '%s' mode.\n", args[2])
}
case "noformat":
format = ""
opts.Format = ""
fmt.Println("Disabled format.")
case "parameter":
if len(args) < 4 {
usageParameters()
continue
}
var params []string
for _, p := range args[3:] {
params = append(params, p)
}
fp, err := api.FormatParams(map[string][]string{args[2]: params})
if err != nil {
fmt.Printf("Couldn't set parameter: %q\n\n", err)
continue
}
fmt.Printf("Set parameter '%s' to '%s'\n\n", args[2], strings.Join(params, ", "))
opts.Options[args[2]] = fp[args[2]]
case "system", "template":
if len(args) < 3 {
usageSet()
continue
}
line := strings.Join(args[2:], " ")
line = strings.TrimPrefix(line, `"""`)
if strings.HasPrefix(args[2], `"""`) {
cut, found := strings.CutSuffix(line, `"""`)
prompt += cut + "\n"
if found {
opts.System = prompt
if args[1] == "system" {
fmt.Println("Set system template.")
} else {
fmt.Println("Set prompt template.")
}
prompt = ""
} else {
prompt = `"""` + prompt
if args[1] == "system" {
multiline = MultilineSystem
} else {
multiline = MultilineTemplate
}
scanner.Prompt.UseAlt = true
}
} else {
opts.System = line
fmt.Println("Set system template.")
}
default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
}
@@ -705,7 +825,7 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
fmt.Println("error: couldn't connect to ollama server")
return err
}
resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model})
resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: opts.Model})
if err != nil {
fmt.Println("error: couldn't get model")
return err
@@ -724,19 +844,33 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
if resp.Parameters == "" {
fmt.Print("No parameters were specified for this model.\n\n")
} else {
if len(opts.Options) > 0 {
fmt.Println("User defined parameters:")
for k, v := range opts.Options {
fmt.Printf("%-*s %v\n", 30, k, v)
}
fmt.Println()
}
fmt.Println("Model defined parameters:")
fmt.Println(resp.Parameters)
}
case "system":
if resp.System == "" {
switch {
case opts.System != "":
fmt.Println(opts.System + "\n")
case resp.System != "":
fmt.Println(resp.System + "\n")
default:
fmt.Print("No system prompt was specified for this model.\n\n")
} else {
fmt.Println(resp.System)
}
case "template":
if resp.Template == "" {
fmt.Print("No prompt template was specified for this model.\n\n")
} else {
switch {
case opts.Template != "":
fmt.Println(opts.Template + "\n")
case resp.Template != "":
fmt.Println(resp.Template)
default:
fmt.Print("No prompt template was specified for this model.\n\n")
}
default:
fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
@@ -766,8 +900,9 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
prompt += line
}
if len(prompt) > 0 && prompt[0] != '/' {
if err := generate(cmd, model, prompt, wordWrap, format); err != nil {
if len(prompt) > 0 && multiline == MultilineNone {
opts.Prompt = prompt
if err := generate(cmd, opts); err != nil {
return err
}
@@ -851,7 +986,7 @@ func initializeKeypair() error {
return nil
}
func startMacApp(client *api.Client) error {
func startMacApp(ctx context.Context, client *api.Client) error {
exe, err := os.Executable()
if err != nil {
return err
@@ -875,24 +1010,24 @@ func startMacApp(client *api.Client) error {
case <-timeout:
return errors.New("timed out waiting for server to start")
case <-tick:
if err := client.Heartbeat(context.Background()); err == nil {
if err := client.Heartbeat(ctx); err == nil {
return nil // server has started
}
}
}
}
func checkServerHeartbeat(_ *cobra.Command, _ []string) error {
func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
if err := client.Heartbeat(context.Background()); err != nil {
if err := client.Heartbeat(cmd.Context()); err != nil {
if !strings.Contains(err.Error(), "connection refused") {
return err
}
if runtime.GOOS == "darwin" {
if err := startMacApp(client); err != nil {
if err := startMacApp(cmd.Context(), client); err != nil {
return fmt.Errorf("could not connect to ollama app, is it running?")
}
} else {
@@ -902,8 +1037,29 @@ func checkServerHeartbeat(_ *cobra.Command, _ []string) error {
return nil
}
func versionHandler(cmd *cobra.Command, _ []string) {
client, err := api.ClientFromEnvironment()
if err != nil {
return
}
serverVersion, err := client.Version(cmd.Context())
if err != nil {
fmt.Println("Warning: could not connect to a running Ollama instance")
}
if serverVersion != "" {
fmt.Printf("ollama version is %s\n", serverVersion)
}
if serverVersion != version.Version {
fmt.Printf("Warning: client version is %s\n", version.Version)
}
}
func NewCLI() *cobra.Command {
log.SetFlags(log.LstdFlags | log.Lshortfile)
cobra.EnableCommandSorting = false
rootCmd := &cobra.Command{
Use: "ollama",
@@ -913,10 +1069,17 @@ func NewCLI() *cobra.Command {
CompletionOptions: cobra.CompletionOptions{
DisableDefaultCmd: true,
},
Version: version.Version,
Run: func(cmd *cobra.Command, args []string) {
if version, _ := cmd.Flags().GetBool("version"); version {
versionHandler(cmd, args)
return
}
cmd.Print(cmd.UsageString())
},
}
cobra.EnableCommandSorting = false
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
createCmd := &cobra.Command{
Use: "create MODEL",

View File

@@ -24,7 +24,7 @@ All durations are returned in nanoseconds.
### Streaming responses
Certain endpoints stream responses as JSON objects delineated with the newline (`\n`) character.
Certain endpoints stream responses as JSON objects.
## Generate a completion
@@ -32,7 +32,7 @@ Certain endpoints stream responses as JSON objects delineated with the newline (
POST /api/generate
```
Generate a response for a given prompt with a provided model. This is a streaming endpoint, so will be a series of responses. The final response object will include statistics and additional data from the request.
Generate a response for a given prompt with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request.
### Parameters
@@ -47,7 +47,7 @@ Advanced parameters (optional):
- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `raw`: if `true` no formatting will be applied to the prompt and no context will be returned. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API, and are managing history yourself.
- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API.
### JSON mode
@@ -114,6 +114,8 @@ To calculate how fast the response is generated in tokens per second (token/s),
#### Request (No streaming)
A response can be received in one reply when streaming is off.
```shell
curl http://localhost:11434/api/generate -d '{
"model": "llama2",
@@ -144,9 +146,9 @@ If `stream` is set to `false`, the response will be a single JSON object:
}
```
#### Request (Raw mode)
#### Request (Raw Mode)
In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting and context.
In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting.
```shell
curl http://localhost:11434/api/generate -d '{
@@ -164,6 +166,7 @@ curl http://localhost:11434/api/generate -d '{
"model": "mistral",
"created_at": "2023-11-03T15:36:02.583064Z",
"response": " The sky appears blue because of a phenomenon called Rayleigh scattering.",
"context": [1, 2, 3],
"done": true,
"total_duration": 14648695333,
"load_duration": 3302671417,
@@ -275,7 +278,6 @@ curl http://localhost:11434/api/generate -d '{
"model": "llama2",
"created_at": "2023-08-04T19:22:45.499127Z",
"response": "The sky is blue because it is the color of the sky.",
"context": [1, 2, 3],
"done": true,
"total_duration": 5589157167,
"load_duration": 3013701500,
@@ -288,6 +290,136 @@ curl http://localhost:11434/api/generate -d '{
}
```
## Send Chat Messages (coming in 0.1.14)
```shell
POST /api/chat
```
Generate the next message in a chat with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request.
### Parameters
- `model`: (required) the [model name](#model-names)
- `messages`: the messages of the chat, this can be used to keep a chat memory
Advanced parameters (optional):
- `format`: the format to return a response in. Currently the only accepted value is `json`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
### Examples
#### Request
Send a chat message with a streaming response.
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama2",
"messages": [
{
"role": "user",
"content": "why is the sky blue?"
}
]
}'
```
#### Response
A stream of JSON objects is returned:
```json
{
"model": "llama2",
"created_at": "2023-08-04T08:52:19.385406455-07:00",
"message": {
"role": "assisant",
"content": "The"
},
"done": false
}
```
Final response:
```json
{
"model": "llama2",
"created_at": "2023-08-04T19:22:45.499127Z",
"done": true,
"total_duration": 5589157167,
"load_duration": 3013701500,
"sample_count": 114,
"sample_duration": 81442000,
"prompt_eval_count": 46,
"prompt_eval_duration": 1160282000,
"eval_count": 113,
"eval_duration": 1325948000
}
```
#### Request (With History)
Send a chat message with a conversation history.
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama2",
"messages": [
{
"role": "user",
"content": "why is the sky blue?"
},
{
"role": "assistant",
"content": "due to rayleigh scattering."
},
{
"role": "user",
"content": "how is that different than mie scattering?"
}
]
}'
```
#### Response
A stream of JSON objects is returned:
```json
{
"model": "llama2",
"created_at": "2023-08-04T08:52:19.385406455-07:00",
"message": {
"role": "assisant",
"content": "The"
},
"done": false
}
```
Final response:
```json
{
"model": "llama2",
"created_at": "2023-08-04T19:22:45.499127Z",
"done": true,
"total_duration": 5589157167,
"load_duration": 3013701500,
"sample_count": 114,
"sample_duration": 81442000,
"prompt_eval_count": 46,
"prompt_eval_duration": 1160282000,
"eval_count": 113,
"eval_duration": 1325948000
}
```
## Create a Model
```shell

View File

@@ -23,7 +23,7 @@ Ollama binds to 127.0.0.1 port 11434 by default. Change the bind address with th
On macOS:
```bash
OLLAMA_HOST=0.0.0.0:11435 ollama serve
OLLAMA_HOST=0.0.0.0:11434 ollama serve
```
On Linux:
@@ -59,7 +59,7 @@ OLLAMA_ORIGINS=http://192.168.1.1:*,https://example.com ollama serve
On Linux:
```bash
echo 'Environment="OLLAMA_ORIGINS=http://129.168.1.1:*,https://example.com"' >>/etc/systemd/system/ollama.service.d/environment.conf
echo 'Environment="OLLAMA_ORIGINS=http://192.168.1.1:*,https://example.com"' >>/etc/systemd/system/ollama.service.d/environment.conf
```
Reload `systemd` and restart Ollama:
@@ -95,6 +95,10 @@ The manifest lists all the layers used in this model. You will see a `media type
To modify where models are stored, you can use the `OLLAMA_MODELS` environment variable. Note that on Linux this means defining `OLLAMA_MODELS` in a drop-in `/etc/systemd/system/ollama.service.d` service file, reloading systemd, and restarting the ollama service.
### I downloaded most of a model yesterday, but it's gone today. What happened?
When the Ollama server starts, it looks for fragments of models that still exist on the system and cleans them out. If you have an Internet connection that can't complete a model download all at once, this can be frustrating. Setting the `OLLAMA_NOPRUNE` environment variable will prevent the server from pruning incomplete files.
## Does Ollama send my prompts and answers back to Ollama.ai to use in any way?
No. Anything you do with Ollama, such as generating a response from the model, stays with you. We don't collect any data about how you use the model. You are always in control of your own data.

View File

@@ -43,7 +43,6 @@ Ollama supports a set of model architectures, with support for more coming soon:
- Llama & Mistral
- Falcon & RW
- GPT-NeoX
- BigCode
To view a model's architecture, check the `config.json` file in its HuggingFace repo. You should see an entry under `architectures` (e.g. `LlamaForCausalLM`).
@@ -184,9 +183,6 @@ python convert.py <path to model directory>
# FalconForCausalLM
python convert-falcon-hf-to-gguf.py <path to model directory>
# GPTNeoXForCausalLM
python convert-gptneox-hf-to-gguf.py <path to model directory>
# GPTBigCodeForCausalLM
python convert-starcoder-hf-to-gguf.py <path to model directory>
```

83
docs/tutorials/fly-gpu.md Normal file
View File

@@ -0,0 +1,83 @@
# Running Ollama on Fly.io GPU Instances
Ollama runs with little to no configuration on [Fly.io GPU instances](https://fly.io/docs/gpus/gpu-quickstart/). If you don't have access to GPUs yet, you'll need to [apply for access](https://fly.io/gpu/) on the waitlist. Once you're accepted, you'll get an email with instructions on how to get started.
Create a new app with `fly apps create`:
```bash
fly apps create
```
Then create a `fly.toml` file in a new folder that looks like this:
```toml
app = "sparkling-violet-709"
primary_region = "ord"
vm.size = "a100-40gb" # see https://fly.io/docs/gpus/gpu-quickstart/ for more info
[build]
image = "ollama/ollama"
[http_service]
internal_port = 11434
force_https = false
auto_stop_machines = true
auto_start_machines = true
min_machines_running = 0
processes = ["app"]
[mounts]
source = "models"
destination = "/root/.ollama"
initial_size = "100gb"
```
Then create a [new private IPv6 address](https://fly.io/docs/reference/private-networking/#flycast-private-load-balancing) for your app:
```bash
fly ips allocate-v6 --private
```
Then deploy your app:
```bash
fly deploy
```
And finally you can access it interactively with a new Fly.io Machine:
```
fly machine run -e OLLAMA_HOST=http://your-app-name.flycast --shell ollama/ollama
```
```bash
$ ollama run openchat:7b-v3.5-fp16
>>> How do I bake chocolate chip cookies?
To bake chocolate chip cookies, follow these steps:
1. Preheat the oven to 375°F (190°C) and line a baking sheet with parchment paper or silicone baking mat.
2. In a large bowl, mix together 1 cup of unsalted butter (softened), 3/4 cup granulated sugar, and 3/4
cup packed brown sugar until light and fluffy.
3. Add 2 large eggs, one at a time, to the butter mixture, beating well after each addition. Stir in 1
teaspoon of pure vanilla extract.
4. In a separate bowl, whisk together 2 cups all-purpose flour, 1/2 teaspoon baking soda, and 1/2 teaspoon
salt. Gradually add the dry ingredients to the wet ingredients, stirring until just combined.
5. Fold in 2 cups of chocolate chips (or chunks) into the dough.
6. Drop rounded tablespoons of dough onto the prepared baking sheet, spacing them about 2 inches apart.
7. Bake for 10-12 minutes, or until the edges are golden brown. The centers should still be slightly soft.
8. Allow the cookies to cool on the baking sheet for a few minutes before transferring them to a wire rack
to cool completely.
Enjoy your homemade chocolate chip cookies!
```
When you set it up like this, it will automatically turn off when you're done using it. Then when you access it again, it will automatically turn back on. This is a great way to save money on GPU instances when you're not using them. If you want a persistent wake-on-use connection to your Ollama instance, you can set up a [connection to your Fly network using WireGuard](https://fly.io/docs/reference/private-networking/#discovering-apps-through-dns-on-a-wireguard-connection). Then you can access your Ollama instance at `http://your-app-name.flycast`.
And that's it!

View File

@@ -42,12 +42,13 @@ text_splitter=RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
all_splits = text_splitter.split_documents(data)
```
It's split up, but we have to find the relevant splits and then submit those to the model. We can do this by creating embeddings and storing them in a vector database. For now, we don't have embeddings built in to Ollama, though we will be adding that soon, so for now, we can use the GPT4All library for that. We will use ChromaDB in this example for a vector database. `pip install GPT4All chromadb`
It's split up, but we have to find the relevant splits and then submit those to the model. We can do this by creating embeddings and storing them in a vector database. We can use Ollama directly to instantiate an embedding model. We will use ChromaDB in this example for a vector database. `pip install GPT4All chromadb`
```python
from langchain.embeddings import GPT4AllEmbeddings
from langchain.embeddings import OllamaEmbeddings
from langchain.vectorstores import Chroma
vectorstore = Chroma.from_documents(documents=all_splits, embedding=GPT4AllEmbeddings())
oembed = OllamaEmbeddings(base_url="http://localhost:11434", model="llama2")
vectorstore = Chroma.from_documents(documents=all_splits, embedding=oembed)
```
Now let's ask a question from the document. **Who was Neleus, and who is in his family?** Neleus is a character in the Odyssey, and the answer can be found in our text.

View File

@@ -25,9 +25,11 @@ spec:
image: ollama/ollama:latest
env:
- name: PATH
value: /usr/local/nvidia/bin:/usr/local/nvidia/lib64:/usr/bin:/usr/sbin:/bin:/sbin
value: /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
- name: LD_LIBRARY_PATH
value: /usr/local/nvidia/lib64
value: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
- name: NVIDIA_DRIVER_CAPABILITIES
value: compute,utility
ports:
- name: http
containerPort: 11434

View File

@@ -0,0 +1,46 @@
import json
import requests
# NOTE: ollama must be running for this to work, start the ollama app or run `ollama serve`
model = "llama2" # TODO: update this for whatever model you wish to use
def chat(messages):
r = requests.post(
"http://0.0.0.0:11434/api/chat",
json={"model": model, "messages": messages, "stream": True},
)
r.raise_for_status()
output = ""
for line in r.iter_lines():
body = json.loads(line)
if "error" in body:
raise Exception(body["error"])
if body.get("done") is False:
message = body.get("message", "")
content = message.get("content", "")
output += content
# the response streams one token at a time, print that as we receive it
print(content, end="", flush=True)
if body.get("done", False):
message["content"] = output
return message
def main():
messages = []
while True:
user_input = input("Enter a prompt: ")
print()
messages.append({"role": "user", "content": user_input})
message = chat(messages)
messages.append(message)
print("\n\n")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,24 @@
# Simple Chat Example
The **chat** endpoint is one of two ways to generate text from an LLM with Ollama. At a high level you provide the endpoint an array of objects with a role and content specified. Then with each output and prompt, you add more of those role/content objects, which builds up the history.
## Review the Code
You can see in the **chat** function that actually calling the endpoint is done simply with:
```python
r = requests.post(
"http://0.0.0.0:11434/api/chat",
json={"model": model, "messages": messages, "stream": True},
)
```
With the **generate** endpoint, you need to provide a `prompt`. But with **chat**, you provide `messages`. And the resulting stream of responses includes a `message` object with a `content` field.
The final JSON object doesn't provide the full content, so you will need to build the content yourself.
In the **main** function, we collect `user_input` and add it as a message to our messages and that is passed to the chat function. When the LLM is done responding the output is added as another message.
## Next Steps
In this example, all generations are kept. You might want to experiment with summarizing everything older than 10 conversations to enable longer history with less context being used.

View File

@@ -0,0 +1,77 @@
import * as readline from "readline";
const model = "llama2";
type Message = {
role: "assistant" | "user" | "system";
content: string;
}
const messages: Message[] = [{
role: "system",
content: "You are a helpful AI agent."
}]
const rl = readline.createInterface({
input: process.stdin,
output: process.stdout
})
async function chat(messages: Message[]): Promise<Message> {
const body = {
model: model,
messages: messages
}
const response = await fetch("http://localhost:11434/api/chat", {
method: "POST",
body: JSON.stringify(body)
})
const reader = response.body?.getReader()
if (!reader) {
throw new Error("Failed to read response body")
}
let content = ""
while (true) {
const { done, value } = await reader.read()
if (done) {
break;
}
const rawjson = new TextDecoder().decode(value);
const json = JSON.parse(rawjson)
if (json.done === false) {
process.stdout.write(json.message.content);
content += json.message.content
}
}
return { role: "assistant", content: content };
}
async function askQuestion(): Promise<void> {
return new Promise<void>((resolve) => {
rl.question("\n\nAsk a question: (press enter alone to quit)\n\n", async (user_input) => {
if (user_input.trim() === "") {
rl.close();
console.log("Thankyou. Goodbye.\n")
console.log("=======\nHere is the message history that was used in this conversation.\n=======\n")
messages.forEach(message => {
console.log(message)
})
resolve();
} else {
console.log();
messages.push({ role: "user", content: user_input });
messages.push(await chat(messages));
await askQuestion(); // Ask the next question
}
});
});
}
async function main() {
await askQuestion();
}
main();

View File

@@ -0,0 +1 @@
{ "dependencies": { "@types/node": "^20.10.4", "prompt-sync": "^4.2.0", "readline": "^1.3.0" } }

View File

@@ -0,0 +1,39 @@
# Simple Chat Example
The **chat** endpoint is one of two ways to generate text from an LLM with Ollama. At a high level you provide the endpoint an array of message objects with a role and content specified. Then with each output and prompt, you add more messages, which builds up the history.
## Run the Example
There are a few ways to run this, just like any TypeScript code:
1. Compile with `tsc` and then run it with `node client.js`.
2. Install `tsx` and run it with `tsx client.ts`.
3. Install `bun` and run it with `bun client.ts`.
## Review the Code
You can see in the **chat** function that actually calling the endpoint is done simply with:
```typescript
const body = {
model: model,
messages: messages
}
const response = await fetch("http://localhost:11434/api/chat", {
method: "POST",
body: JSON.stringify(body)
})
```
With the **generate** endpoint, you need to provide a `prompt`. But with **chat**, you provide `messages`. And the resulting stream of responses includes a `message` object with a `content` field.
The final JSON object doesn't provide the full content, so you will need to build the content yourself. In this example, **chat** takes the full array of messages and outputs the resulting message from this call of the chat endpoint.
In the **askQuestion** function, we collect `user_input` and add it as a message to our messages and that is passed to the chat function. When the LLM is done responding the output is added as another message to the messages array.
At the end, you will see a printout of all the messages.
## Next Steps
In this example, all generations are kept. You might want to experiment with summarizing everything older than 10 conversations to enable longer history with less context being used.

7
go.mod
View File

@@ -5,14 +5,15 @@ go 1.20
require (
github.com/emirpasic/gods v1.18.1
github.com/gin-gonic/gin v1.9.1
github.com/mattn/go-runewidth v0.0.14
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db
github.com/olekukonko/tablewriter v0.0.5
github.com/spf13/cobra v1.7.0
golang.org/x/sync v0.3.0
)
require github.com/rivo/uniseg v0.2.0 // indirect
require (
github.com/mattn/go-runewidth v0.0.14 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
)
require (
github.com/bytedance/sonic v1.9.1 // indirect

2
go.sum
View File

@@ -63,8 +63,6 @@ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ=
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=

View File

@@ -7,9 +7,10 @@ import (
)
type GGML struct {
magic uint32
container
model
Size int64
}
const (
@@ -82,7 +83,7 @@ type model interface {
type container interface {
Name() string
Decode(io.Reader) (model, error)
Decode(*readSeekOffset) (model, error)
}
type containerGGML struct{}
@@ -91,7 +92,7 @@ func (c *containerGGML) Name() string {
return "ggml"
}
func (c *containerGGML) Decode(r io.Reader) (model, error) {
func (c *containerGGML) Decode(ro *readSeekOffset) (model, error) {
return nil, nil
}
@@ -103,9 +104,9 @@ func (c *containerGGMF) Name() string {
return "ggmf"
}
func (c *containerGGMF) Decode(r io.Reader) (model, error) {
func (c *containerGGMF) Decode(ro *readSeekOffset) (model, error) {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
binary.Read(ro, binary.LittleEndian, &version)
switch version {
case 1:
@@ -125,9 +126,9 @@ func (c *containerGGJT) Name() string {
return "ggjt"
}
func (c *containerGGJT) Decode(r io.Reader) (model, error) {
func (c *containerGGJT) Decode(ro *readSeekOffset) (model, error) {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
binary.Read(ro, binary.LittleEndian, &version)
switch version {
case 1, 2, 3:
@@ -139,7 +140,7 @@ func (c *containerGGJT) Decode(r io.Reader) (model, error) {
// different model types may have different layouts for hyperparameters
var llama llamaModel
binary.Read(r, binary.LittleEndian, &llama.hyperparameters)
binary.Read(ro, binary.LittleEndian, &llama.hyperparameters)
return &llama, nil
}
@@ -151,9 +152,9 @@ func (c *containerLORA) Name() string {
return "ggla"
}
func (c *containerLORA) Decode(r io.Reader) (model, error) {
func (c *containerLORA) Decode(ro *readSeekOffset) (model, error) {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
binary.Read(ro, binary.LittleEndian, &version)
switch version {
case 1:
@@ -180,33 +181,61 @@ const (
)
func DecodeGGML(r io.ReadSeeker) (*GGML, error) {
var ggml GGML
binary.Read(r, binary.LittleEndian, &ggml.magic)
ro := readSeekOffset{ReadSeeker: r}
switch ggml.magic {
var magic uint32
if err := binary.Read(&ro, binary.LittleEndian, &magic); err != nil {
return nil, err
}
var c container
switch magic {
case FILE_MAGIC_GGML:
ggml.container = &containerGGML{}
c = &containerGGML{}
case FILE_MAGIC_GGMF:
ggml.container = &containerGGMF{}
c = &containerGGMF{}
case FILE_MAGIC_GGJT:
ggml.container = &containerGGJT{}
c = &containerGGJT{}
case FILE_MAGIC_GGLA:
ggml.container = &containerLORA{}
c = &containerLORA{}
case FILE_MAGIC_GGUF_LE:
ggml.container = &containerGGUF{bo: binary.LittleEndian}
c = &containerGGUF{bo: binary.LittleEndian}
case FILE_MAGIC_GGUF_BE:
ggml.container = &containerGGUF{bo: binary.BigEndian}
c = &containerGGUF{bo: binary.BigEndian}
default:
return nil, errors.New("invalid file magic")
}
model, err := ggml.Decode(r)
model, err := c.Decode(&ro)
if err != nil {
return nil, err
}
ggml.model = model
// final model type
return &ggml, nil
return &GGML{
container: c,
model: model,
Size: ro.offset,
}, nil
}
type readSeekOffset struct {
io.ReadSeeker
offset int64
}
func (rso *readSeekOffset) Seek(offset int64, whence int) (int64, error) {
offset, err := rso.ReadSeeker.Seek(offset, whence)
if err != nil {
return 0, err
}
rso.offset = offset
return offset, nil
}
func (rso *readSeekOffset) Read(p []byte) (int, error) {
n, err := rso.ReadSeeker.Read(p)
rso.offset += int64(n)
return n, err
}


@@ -23,26 +23,24 @@ type containerGGUF struct {
NumTensor uint64
NumKV uint64
}
parameters uint64
}
func (c *containerGGUF) Name() string {
return "gguf"
}
func (c *containerGGUF) Decode(r io.Reader) (model, error) {
binary.Read(r, c.bo, &c.Version)
func (c *containerGGUF) Decode(rso *readSeekOffset) (model, error) {
binary.Read(rso, c.bo, &c.Version)
switch c.Version {
case 1:
binary.Read(r, c.bo, &c.V1)
binary.Read(rso, c.bo, &c.V1)
default:
binary.Read(r, c.bo, &c.V2)
binary.Read(rso, c.bo, &c.V2)
}
model := newGGUFModel(c)
if err := model.Decode(r); err != nil {
if err := model.Decode(rso); err != nil {
return nil, err
}
@@ -67,9 +65,23 @@ const (
type kv map[string]any
type tensor struct {
name string
kind uint32
offset uint64
size uint64
// shape is the number of elements in each dimension
shape [4]uint64
}
type ggufModel struct {
*containerGGUF
kv
tensors []tensor
parameters uint64
}
func newGGUFModel(container *containerGGUF) *ggufModel {
@@ -96,8 +108,7 @@ func (llm *ggufModel) NumKV() uint64 {
}
func (llm *ggufModel) ModelFamily() string {
t, ok := llm.kv["general.architecture"].(string)
if ok {
if t, ok := llm.kv["general.architecture"].(string); ok {
return t
}
@@ -134,57 +145,56 @@ func (llm *ggufModel) ModelType() string {
}
func (llm *ggufModel) FileType() string {
t, ok := llm.kv["general.file_type"].(uint32)
if ok {
if t, ok := llm.kv["general.file_type"].(uint32); ok {
return fileType(t)
}
return "unknown"
}
func (llm *ggufModel) Decode(r io.Reader) error {
func (llm *ggufModel) Decode(rso *readSeekOffset) error {
// decode key-values
for i := 0; uint64(i) < llm.NumKV(); i++ {
k, err := llm.readString(r)
k, err := llm.readString(rso)
if err != nil {
return err
}
vtype := llm.readU32(r)
vtype := llm.readU32(rso)
var v any
switch vtype {
case ggufTypeUint8:
v = llm.readU8(r)
v = llm.readU8(rso)
case ggufTypeInt8:
v = llm.readI8(r)
v = llm.readI8(rso)
case ggufTypeUint16:
v = llm.readU16(r)
v = llm.readU16(rso)
case ggufTypeInt16:
v = llm.readI16(r)
v = llm.readI16(rso)
case ggufTypeUint32:
v = llm.readU32(r)
v = llm.readU32(rso)
case ggufTypeInt32:
v = llm.readI32(r)
v = llm.readI32(rso)
case ggufTypeUint64:
v = llm.readU64(r)
v = llm.readU64(rso)
case ggufTypeInt64:
v = llm.readI64(r)
v = llm.readI64(rso)
case ggufTypeFloat32:
v = llm.readF32(r)
v = llm.readF32(rso)
case ggufTypeFloat64:
v = llm.readF64(r)
v = llm.readF64(rso)
case ggufTypeBool:
v = llm.readBool(r)
v = llm.readBool(rso)
case ggufTypeString:
s, err := llm.readString(r)
s, err := llm.readString(rso)
if err != nil {
return err
}
v = s
case ggufTypeArray:
a, err := llm.readArray(r)
a, err := llm.readArray(rso)
if err != nil {
return err
}
@@ -199,21 +209,85 @@ func (llm *ggufModel) Decode(r io.Reader) error {
// decode tensors
for i := 0; uint64(i) < llm.NumTensor(); i++ {
if _, err := llm.readString(r); err != nil {
name, err := llm.readString(rso)
if err != nil {
return err
}
dimensions := llm.readU32(r)
// dims is the number of dimensions in the tensor
dims := llm.readU32(rso)
var elements uint64 = 1
for i := 0; uint32(i) < dimensions; i++ {
elements *= llm.readU64(r)
shape := [4]uint64{1, 1, 1, 1}
for i := 0; uint32(i) < dims; i++ {
shape[i] = llm.readU64(rso)
}
llm.readU32(r) // type
llm.readU64(r) // offset
kind := llm.readU32(rso)
offset := llm.readU64(rso)
llm.parameters += elements
var blockSize uint64
switch {
case kind < 2:
blockSize = 1
case kind < 10:
blockSize = 32
default:
blockSize = 256
}
var typeSize uint64
switch kind {
case 0: // FP32
typeSize = 4
case 1: // FP16
typeSize = 2
case 2: // Q4_0
typeSize = 2 + blockSize/2
case 3: // Q4_1
typeSize = 2 + 2 + blockSize/2
case 6: // Q5_0
typeSize = 2 + 4 + blockSize/2
case 7: // Q5_1
typeSize = 2 + 2 + 4 + blockSize/2
case 8: // Q8_0
typeSize = 2 + blockSize
case 9: // Q8_1
typeSize = 4 + 4 + blockSize
case 10: // Q2_K
typeSize = blockSize/16 + blockSize/4 + 2 + 2
case 11: // Q3_K
typeSize = blockSize/8 + blockSize/4 + 12 + 2
case 12: // Q4_K
typeSize = 2 + 2 + 12 + blockSize/2
case 13: // Q5_K
typeSize = 2 + 2 + 12 + blockSize/8 + blockSize/2
case 14: // Q6_K
typeSize = blockSize/2 + blockSize/4 + blockSize/16 + 2
}
parameters := shape[0] * shape[1] * shape[2] * shape[3]
size := parameters * typeSize / blockSize
llm.tensors = append(llm.tensors, tensor{
name: name,
kind: kind,
offset: offset,
size: size,
shape: shape,
})
llm.parameters += parameters
}
alignment, ok := llm.kv["general.alignment"].(uint32)
if !ok {
alignment = 32
}
rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent)
for _, tensor := range llm.tensors {
padded := (int64(tensor.size) + int64(alignment) - 1) & ^(int64(alignment) - 1)
rso.Seek(padded, io.SeekCurrent)
}
return nil


@@ -325,7 +325,7 @@ func (w *StatusWriter) Write(b []byte) (int, error) {
return os.Stderr.Write(b)
}
func newLlama(model string, adapters []string, runners []ModelRunner, numLayers int64, opts api.Options) (*llama, error) {
func newLlama(model string, adapters, projectors []string, runners []ModelRunner, numLayers int64, opts api.Options) (*llama, error) {
fileInfo, err := os.Stat(model)
if err != nil {
return nil, err
@@ -365,6 +365,11 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
params = append(params, "--lora", adapters[0])
}
if len(projectors) > 0 {
// TODO: applying multiple projectors is not supported by the llama.cpp server yet
params = append(params, "--mmproj", projectors[0])
}
if opts.NumThread > 0 {
params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread))
}
@@ -531,21 +536,30 @@ type prediction struct {
const maxBufferSize = 512 * format.KiloByte
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error {
prevConvo, err := llm.Decode(ctx, prevContext)
if err != nil {
return err
}
type PredictOpts struct {
Model string
Prompt string
Format string
CheckpointStart time.Time
CheckpointLoaded time.Time
}
// Remove leading spaces from prevConvo if present
prevConvo = strings.TrimPrefix(prevConvo, " ")
var nextContext strings.Builder
nextContext.WriteString(prevConvo)
nextContext.WriteString(prompt)
type PredictResult struct {
Model string
CreatedAt time.Time
TotalDuration time.Duration
LoadDuration time.Duration
Content string
Done bool
PromptEvalCount int
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration
}
func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
request := map[string]any{
"prompt": nextContext.String(),
"prompt": predict.Prompt,
"stream": true,
"n_predict": llm.NumPredict,
"n_keep": llm.NumKeep,
@@ -567,7 +581,7 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
"stop": llm.Stop,
}
if format == "json" {
if predict.Format == "json" {
request["grammar"] = jsonGrammar
}
@@ -624,25 +638,25 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
}
if p.Content != "" {
fn(api.GenerateResponse{Response: p.Content})
nextContext.WriteString(p.Content)
fn(PredictResult{
Model: predict.Model,
CreatedAt: time.Now().UTC(),
Content: p.Content,
})
}
if p.Stop {
embd, err := llm.Encode(ctx, nextContext.String())
if err != nil {
return fmt.Errorf("encoding context: %v", err)
}
fn(PredictResult{
Model: predict.Model,
CreatedAt: time.Now().UTC(),
TotalDuration: time.Since(predict.CheckpointStart),
fn(api.GenerateResponse{
Done: true,
Context: embd,
PromptEvalCount: p.Timings.PromptN,
PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
EvalCount: p.Timings.PredictedN,
EvalDuration: parseDurationMs(p.Timings.PredictedMS),
})
return nil
}
}


@@ -14,7 +14,7 @@ import (
)
type LLM interface {
Predict(context.Context, []int, string, string, func(api.GenerateResponse)) error
Predict(context.Context, PredictOpts, func(PredictResult)) error
Embedding(context.Context, string) ([]float64, error)
Encode(context.Context, string) ([]int, error)
Decode(context.Context, []int) (string, error)
@@ -23,7 +23,7 @@ type LLM interface {
Ping(context.Context) error
}
func New(workDir, model string, adapters []string, opts api.Options) (LLM, error) {
func New(workDir, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
if _, err := os.Stat(model); err != nil {
return nil, err
}
@@ -82,9 +82,9 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error
opts.NumGQA = 0
opts.RopeFrequencyBase = 0.0
opts.RopeFrequencyScale = 0.0
return newLlama(model, adapters, chooseRunners(workDir, "gguf"), ggml.NumLayers(), opts)
return newLlama(model, adapters, projectors, chooseRunners(workDir, "gguf"), ggml.NumLayers(), opts)
case "ggml", "ggmf", "ggjt", "ggla":
return newLlama(model, adapters, chooseRunners(workDir, "ggml"), ggml.NumLayers(), opts)
return newLlama(model, adapters, projectors, chooseRunners(workDir, "ggml"), ggml.NumLayers(), opts)
default:
return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily())
}


@@ -37,10 +37,13 @@ func Parse(reader io.Reader) ([]Command, error) {
switch string(bytes.ToUpper(fields[0])) {
case "FROM":
command.Name = "model"
command.Args = string(fields[1])
command.Args = string(bytes.TrimSpace(fields[1]))
// copy command for validation
modelCommand = command
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "ADAPTER":
case "ADAPTER":
command.Name = string(bytes.ToLower(fields[0]))
command.Args = string(bytes.TrimSpace(fields[1]))
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
command.Name = string(bytes.ToLower(fields[0]))
command.Args = string(fields[1])
case "PARAMETER":
@@ -50,7 +53,7 @@ func Parse(reader io.Reader) ([]Command, error) {
}
command.Name = string(fields[0])
command.Args = string(fields[1])
command.Args = string(bytes.TrimSpace(fields[1]))
case "EMBED":
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
default:


@@ -191,6 +191,15 @@ func (i *Instance) Readline() (string, error) {
buf.ClearScreen()
case CharCtrlW:
buf.DeleteWord()
case CharCtrlZ:
if err := UnsetRawMode(fd, termios); err != nil {
return "", err
}
syscall.Kill(0, syscall.SIGSTOP)
// on resume...
return "", nil
case CharEnter:
output := buf.String()
if output != "" {


@@ -217,7 +217,7 @@ fi
if ! check_gpu nvidia-smi || [ -z "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
case $OS_NAME in
centos|rhel) install_cuda_driver_yum 'rhel' $OS_VERSION ;;
centos|rhel) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -d '.' -f 1) ;;
rocky) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -c1) ;;
fedora) install_cuda_driver_yum $OS_NAME $OS_VERSION ;;
amzn) install_cuda_driver_yum 'fedora' '35' ;;
@@ -230,7 +230,8 @@ fi
if ! lsmod | grep -q nvidia; then
KERNEL_RELEASE="$(uname -r)"
case $OS_NAME in
centos|rhel|rocky|amzn) $SUDO $PACKAGE_MANAGER -y install kernel-devel-$KERNEL_RELEASE kernel-headers-$KERNEL_RELEASE ;;
rocky) $SUDO $PACKAGE_MANAGER -y install kernel-devel kernel-headers ;;
centos|rhel|amzn) $SUDO $PACKAGE_MANAGER -y install kernel-devel-$KERNEL_RELEASE kernel-headers-$KERNEL_RELEASE ;;
fedora) $SUDO $PACKAGE_MANAGER -y install kernel-devel-$KERNEL_RELEASE ;;
debian|ubuntu) $SUDO apt-get -y install linux-headers-$KERNEL_RELEASE ;;
*) exit ;;


@@ -14,7 +14,6 @@ import (
"net/url"
"os"
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
@@ -36,80 +35,153 @@ type RegistryOptions struct {
}
type Model struct {
Name string `json:"name"`
ShortName string
ModelPath string
OriginalModel string
AdapterPaths []string
Template string
System string
License []string
Digest string
Options map[string]interface{}
Name string `json:"name"`
Config ConfigV2
ShortName string
ModelPath string
OriginalModel string
AdapterPaths []string
ProjectorPaths []string
Template string
System string
License []string
Digest string
Options map[string]interface{}
}
func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
t := m.Template
if request.Template != "" {
t = request.Template
}
type PromptVars struct {
System string
Prompt string
Response string
First bool
}
tmpl, err := template.New("").Parse(t)
func (m *Model) Prompt(p PromptVars) (string, error) {
var prompt strings.Builder
// Use the "missingkey=zero" option to handle missing variables without panicking
tmpl, err := template.New("").Option("missingkey=zero").Parse(m.Template)
if err != nil {
return "", err
}
var vars struct {
First bool
System string
Prompt string
if p.System == "" {
// use the default system prompt for this model if one is not specified
p.System = m.System
}
vars.First = len(request.Context) == 0
vars.System = m.System
vars.Prompt = request.Prompt
if request.System != "" {
vars.System = request.System
vars := map[string]any{
"System": p.System,
"Prompt": p.Prompt,
"Response": p.Response,
"First": p.First,
}
var sb strings.Builder
if err := tmpl.Execute(&sb, vars); err != nil {
return "", err
}
prompt.WriteString(sb.String())
prompt.WriteString(p.Response)
return prompt.String(), nil
}
return sb.String(), nil
func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
// build the prompt from the list of messages
var prompt strings.Builder
currentVars := PromptVars{
First: true,
}
writePrompt := func() error {
p, err := m.Prompt(currentVars)
if err != nil {
return err
}
prompt.WriteString(p)
currentVars = PromptVars{}
return nil
}
for _, msg := range msgs {
switch strings.ToLower(msg.Role) {
case "system":
if currentVars.System != "" {
if err := writePrompt(); err != nil {
return "", err
}
}
currentVars.System = msg.Content
case "user":
if currentVars.Prompt != "" {
if err := writePrompt(); err != nil {
return "", err
}
}
currentVars.Prompt = msg.Content
case "assistant":
currentVars.Response = msg.Content
if err := writePrompt(); err != nil {
return "", err
}
default:
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
}
}
// Append the last set of vars if they are non-empty
if currentVars.Prompt != "" || currentVars.System != "" {
if err := writePrompt(); err != nil {
return "", err
}
}
return prompt.String(), nil
}
type ManifestV2 struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config Layer `json:"config"`
Config *Layer `json:"config"`
Layers []*Layer `json:"layers"`
}
type Layer struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int64 `json:"size"`
From string `json:"from,omitempty"`
}
type LayerReader struct {
Layer
io.Reader
}
type ConfigV2 struct {
ModelFormat string `json:"model_format"`
ModelFamily string `json:"model_family"`
ModelType string `json:"model_type"`
FileType string `json:"file_type"`
RootFS RootFS `json:"rootfs"`
// required by spec
Architecture string `json:"architecture"`
OS string `json:"os"`
RootFS RootFS `json:"rootfs"`
api.ModelConfiguration
}
func (c *ConfigV2) SetModelFormat(format string) {
if c.ModelFormat == "" {
c.ModelFormat = format
}
}
func (c *ConfigV2) SetModelFamily(families ...string) {
for _, family := range families {
if c.ModelFamily == "" {
c.ModelFamily = family
}
if !slices.Contains(c.ModelFamilies, family) {
c.ModelFamilies = append(c.ModelFamilies, family)
}
}
}
func (c *ConfigV2) SetModelType(modelType string) {
if c.ModelType == "" {
c.ModelType = modelType
}
}
func (c *ConfigV2) SetFileType(fileType string) {
if c.FileType == "" {
c.FileType = fileType
}
}
type RootFS struct {
@@ -168,6 +240,21 @@ func GetModel(name string) (*Model, error) {
License: []string{},
}
filename, err := GetBlobsPath(manifest.Config.Digest)
if err != nil {
return nil, err
}
configFile, err := os.Open(filename)
if err != nil {
return nil, err
}
defer configFile.Close()
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
return nil, err
}
for _, layer := range manifest.Layers {
filename, err := GetBlobsPath(layer.Digest)
if err != nil {
@@ -184,6 +271,8 @@ func GetModel(name string) (*Model, error) {
log.Print("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
case "application/vnd.ollama.image.adapter":
model.AdapterPaths = append(model.AdapterPaths, filename)
case "application/vnd.ollama.image.projector":
model.ProjectorPaths = append(model.ProjectorPaths, filename)
case "application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename)
if err != nil {
@@ -257,11 +346,14 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
config := ConfigV2{
OS: "linux",
Architecture: "amd64",
RootFS: RootFS{
Type: "layers",
},
}
deleteMap := make(map[string]struct{})
var layers []*LayerReader
var layers Layers
params := make(map[string][]string)
fromParams := make(map[string]any)
@@ -318,10 +410,10 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
return err
}
config.ModelFormat = fromConfig.ModelFormat
config.ModelFamily = fromConfig.ModelFamily
config.ModelType = fromConfig.ModelType
config.FileType = fromConfig.FileType
config.SetModelFormat(fromConfig.ModelFormat)
config.SetModelFamily(append(fromConfig.ModelFamilies, fromConfig.ModelFamily)...)
config.SetModelType(fromConfig.ModelType)
config.SetFileType(fromConfig.FileType)
for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = struct{}{}
@@ -342,13 +434,12 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
}
}
layer, err := GetLayerWithBufferFromLayer(layer)
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
if err != nil {
return err
}
layer.From = modelpath.GetShortTagname()
layers = append(layers, layer)
layers.Add(layer)
}
deleteMap[manifest.Config.Digest] = struct{}{}
@@ -356,26 +447,48 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
}
defer bin.Close()
fn(api.ProgressResponse{Status: "creating model layer"})
ggml, err := llm.DecodeGGML(bin)
if err != nil {
return err
var offset int64
for {
fn(api.ProgressResponse{Status: "creating model layer"})
bin.Seek(offset, io.SeekStart)
ggml, err := llm.DecodeGGML(bin)
if errors.Is(err, io.EOF) {
break
} else if err != nil {
return err
}
config.SetModelFormat(ggml.Name())
config.SetModelFamily(ggml.ModelFamily())
config.SetModelType(ggml.ModelType())
config.SetFileType(ggml.FileType())
mediatype := mediatype
if ggml.ModelFamily() == "clip" {
mediatype = "application/vnd.ollama.image.projector"
}
sr := io.NewSectionReader(bin, offset, ggml.Size)
layer, err := NewLayer(sr, mediatype)
if err != nil {
return err
}
layers.Add(layer)
offset += ggml.Size
}
config.ModelFormat = ggml.Name()
config.ModelFamily = ggml.ModelFamily()
config.ModelType = ggml.ModelType()
config.FileType = ggml.FileType()
bin.Seek(0, io.SeekStart)
layer, err := CreateLayer(bin)
if err != nil {
return err
}
layer.MediaType = mediatype
layers = append(layers, layer)
case "adapter":
if strings.HasPrefix(c.Args, "@") {
blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
if err != nil {
return err
}
c.Args = blobPath
}
fn(api.ProgressResponse{Status: "creating adapter layer"})
bin, err := os.Open(realpath(modelFileDir, c.Args))
if err != nil {
@@ -383,41 +496,32 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
}
defer bin.Close()
layer, err := CreateLayer(bin)
layer, err := NewLayer(bin, mediatype)
if err != nil {
return err
}
if layer.Size > 0 {
layer.MediaType = mediatype
layers = append(layers, layer)
}
layers.Add(layer)
case "license":
fn(api.ProgressResponse{Status: "creating license layer"})
layer, err := CreateLayer(strings.NewReader(c.Args))
bin := strings.NewReader(c.Args)
layer, err := NewLayer(bin, mediatype)
if err != nil {
return err
}
if layer.Size > 0 {
layer.MediaType = mediatype
layers = append(layers, layer)
}
layers.Add(layer)
case "template", "system":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
// remove duplicate layers
layers = removeLayerFromLayers(layers, mediatype)
layer, err := CreateLayer(strings.NewReader(c.Args))
bin := strings.NewReader(c.Args)
layer, err := NewLayer(bin, mediatype)
if err != nil {
return err
}
if layer.Size > 0 {
layer.MediaType = mediatype
layers = append(layers, layer)
}
layers.Replace(layer)
default:
params[c.Name] = append(params[c.Name], c.Args)
}
@@ -426,7 +530,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
if len(params) > 0 {
fn(api.ProgressResponse{Status: "creating parameters layer"})
formattedParams, err := formatParams(params)
formattedParams, err := api.FormatParams(params)
if err != nil {
return err
}
@@ -449,40 +553,51 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
}
fn(api.ProgressResponse{Status: "creating config layer"})
layer, err := CreateLayer(bytes.NewReader(b.Bytes()))
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return err
}
layer.MediaType = "application/vnd.ollama.image.params"
layers = append(layers, layer)
layers.Replace(layer)
}
digests, err := getLayerDigests(layers)
digests := make([]string, len(layers.items))
for i, layer := range layers.items {
digests[i] = layer.Digest
}
config.RootFS.DiffIDs = digests
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(config); err != nil {
return err
}
configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
return err
}
configLayer, err := createConfigLayer(config, digests)
if err != nil {
return err
}
layers = append(layers, configLayer)
delete(deleteMap, configLayer.Digest)
if err := SaveLayers(layers, fn, false); err != nil {
return err
}
for _, layer := range append(layers.items, configLayer) {
committed, err := layer.Commit()
if err != nil {
return err
}
status := "writing layer"
if !committed {
status = "using already created layer"
}
fn(api.ProgressResponse{Status: fmt.Sprintf("%s %s", status, layer.Digest)})
var contentLayers []*Layer
for _, layer := range layers {
contentLayers = append(contentLayers, &layer.Layer)
delete(deleteMap, layer.Digest)
}
fn(api.ProgressResponse{Status: "writing manifest"})
if err := CreateManifest(name, configLayer, contentLayers); err != nil {
if err := WriteManifest(name, configLayer, layers.items); err != nil {
return err
}
@@ -496,177 +611,6 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
return nil
}
func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
return slices.DeleteFunc(layers, func(layer *LayerReader) bool {
return layer.MediaType == mediaType
})
}
func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {
// Write each of the layers to disk
for _, layer := range layers {
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
_, err = os.Stat(fp)
if os.IsNotExist(err) || force {
fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})
out, err := os.Create(fp)
if err != nil {
log.Printf("couldn't create %s", fp)
return err
}
defer out.Close()
if _, err = io.Copy(out, layer.Reader); err != nil {
return err
}
} else {
fn(api.ProgressResponse{Status: fmt.Sprintf("using already created layer %s", layer.Digest)})
}
}
return nil
}
func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
mp := ParseModelPath(name)
manifest := ManifestV2{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: Layer{
MediaType: cfg.MediaType,
Size: cfg.Size,
Digest: cfg.Digest,
},
Layers: layers,
}
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
return err
}
return os.WriteFile(fp, manifestJSON, 0o644)
}
func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
file, err := os.Open(fp)
if err != nil {
return nil, fmt.Errorf("could not open blob: %w", err)
}
defer file.Close()
newLayer, err := CreateLayer(file)
if err != nil {
return nil, err
}
newLayer.MediaType = layer.MediaType
return newLayer, nil
}
// formatParams converts specified parameter options to their correct types
func formatParams(params map[string][]string) (map[string]interface{}, error) {
opts := api.Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
// build map of json struct tags to their types
jsonOpts := make(map[string]reflect.StructField)
for _, field := range reflect.VisibleFields(typeOpts) {
jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
if jsonTag != "" {
jsonOpts[jsonTag] = field
}
}
out := make(map[string]interface{})
// iterate params and set values based on json struct tags
for key, vals := range params {
if opt, ok := jsonOpts[key]; ok {
field := valueOpts.FieldByName(opt.Name)
if field.IsValid() && field.CanSet() {
switch field.Kind() {
case reflect.Float32:
floatVal, err := strconv.ParseFloat(vals[0], 32)
if err != nil {
return nil, fmt.Errorf("invalid float value %s", vals)
}
out[key] = float32(floatVal)
case reflect.Int:
intVal, err := strconv.ParseInt(vals[0], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid int value %s", vals)
}
out[key] = intVal
case reflect.Bool:
boolVal, err := strconv.ParseBool(vals[0])
if err != nil {
return nil, fmt.Errorf("invalid bool value %s", vals)
}
out[key] = boolVal
case reflect.String:
out[key] = vals[0]
case reflect.Slice:
// TODO: only string slices are supported right now
out[key] = vals
default:
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
}
}
}
}
return out, nil
}
func getLayerDigests(layers []*LayerReader) ([]string, error) {
var digests []string
for _, l := range layers {
if l.Digest == "" {
return nil, fmt.Errorf("layer is missing a digest")
}
digests = append(digests, l.Digest)
}
return digests, nil
}
// CreateLayer creates a Layer object from a given file
func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
digest, size := GetSHA256Digest(f)
f.Seek(0, io.SeekStart)
layer := &LayerReader{
Layer: Layer{
MediaType: "application/vnd.docker.image.rootfs.diff.tar",
Digest: digest,
Size: size,
},
Reader: f,
}
return layer, nil
}
func CopyModel(src, dest string) error {
srcModelPath := ParseModelPath(src)
srcPath, err := srcModelPath.GetManifestPath()
@@ -934,7 +878,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
var layers []*Layer
layers = append(layers, manifest.Layers...)
layers = append(layers, &manifest.Config)
layers = append(layers, manifest.Config)
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
@@ -1005,7 +949,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
var layers []*Layer
layers = append(layers, manifest.Layers...)
layers = append(layers, &manifest.Config)
layers = append(layers, manifest.Config)
for _, layer := range layers {
if err := downloadBlob(
@@ -1093,30 +1037,6 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptio
return m, err
}
func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
config.RootFS = RootFS{
Type: "layers",
DiffIDs: layers,
}
configJSON, err := json.Marshal(config)
if err != nil {
return nil, err
}
digest, size := GetSHA256Digest(bytes.NewBuffer(configJSON))
layer := &LayerReader{
Layer: Layer{
MediaType: "application/vnd.docker.container.image.v1+json",
Digest: digest,
Size: size,
},
Reader: bytes.NewBuffer(configJSON),
}
return layer, nil
}
// GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer
func GetSHA256Digest(r io.Reader) (string, int64) {
h := sha256.New()


@@ -1,23 +1,98 @@
package server
import (
"strings"
"testing"
"github.com/jmorganca/ollama/api"
)
func TestModelPrompt(t *testing.T) {
var m Model
req := api.GenerateRequest{
Template: "a{{ .Prompt }}b",
Prompt: "<h1>",
func TestChat(t *testing.T) {
tests := []struct {
name string
template string
msgs []api.Message
want string
wantErr string
}{
{
name: "Single Message",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
msgs: []api.Message{
{
Role: "system",
Content: "You are a Wizard.",
},
{
Role: "user",
Content: "What are the potion ingredients?",
},
},
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
},
{
name: "Message History",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
msgs: []api.Message{
{
Role: "system",
Content: "You are a Wizard.",
},
{
Role: "user",
Content: "What are the potion ingredients?",
},
{
Role: "assistant",
Content: "sugar",
},
{
Role: "user",
Content: "Anything else?",
},
},
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
},
{
name: "Assistant Only",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
msgs: []api.Message{
{
Role: "assistant",
Content: "everything nice",
},
},
want: "[INST] [/INST]everything nice",
},
{
name: "Invalid Role",
msgs: []api.Message{
{
Role: "not-a-role",
Content: "howdy",
},
},
wantErr: "invalid role: not-a-role",
},
}
s, err := m.Prompt(req)
if err != nil {
t.Fatal(err)
}
want := "a<h1>b"
if s != want {
t.Errorf("got %q, want %q", s, want)
for _, tt := range tests {
m := Model{
Template: tt.template,
}
t.Run(tt.name, func(t *testing.T) {
got, err := m.ChatPrompt(tt.msgs)
if tt.wantErr != "" {
if err == nil {
t.Errorf("ChatPrompt() expected error, got nil")
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
}
}
if got != tt.want {
t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
}
})
}
}

server/layers.go (new file)

@@ -0,0 +1,109 @@
package server
import (
"crypto/sha256"
"fmt"
"io"
"os"
"runtime"
"strings"
"golang.org/x/exp/slices"
)
type Layers struct {
items []*Layer
}
func (ls *Layers) Add(layer *Layer) {
if layer.Size > 0 {
ls.items = append(ls.items, layer)
}
}
func (ls *Layers) Replace(layer *Layer) {
if layer.Size > 0 {
mediatype := layer.MediaType
layers := slices.DeleteFunc(ls.items, func(l *Layer) bool {
return l.MediaType == mediatype
})
ls.items = append(layers, layer)
}
}
type Layer struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int64 `json:"size"`
From string `json:"from,omitempty"`
tempFileName string
}
func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
blobs, err := GetBlobsPath("")
if err != nil {
return nil, err
}
delimiter := ":"
if runtime.GOOS == "windows" {
delimiter = "-"
}
pattern := strings.Join([]string{"sha256", "*-partial"}, delimiter)
temp, err := os.CreateTemp(blobs, pattern)
if err != nil {
return nil, err
}
defer temp.Close()
sha256sum := sha256.New()
n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
if err != nil {
return nil, err
}
return &Layer{
MediaType: mediatype,
Digest: fmt.Sprintf("sha256:%x", sha256sum.Sum(nil)),
Size: n,
tempFileName: temp.Name(),
}, nil
}
func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
blob, err := GetBlobsPath(digest)
if err != nil {
return nil, err
}
fi, err := os.Stat(blob)
if err != nil {
return nil, err
}
return &Layer{
MediaType: mediatype,
Digest: digest,
Size: fi.Size(),
From: from,
}, nil
}
func (l *Layer) Commit() (bool, error) {
// always remove temp
defer os.Remove(l.tempFileName)
blob, err := GetBlobsPath(l.Digest)
if err != nil {
return false, err
}
if _, err := os.Stat(blob); err != nil {
return true, os.Rename(l.tempFileName, blob)
}
return false, nil
}

server/manifests.go (new file)

@@ -0,0 +1,34 @@
package server
import (
"bytes"
"encoding/json"
"os"
"path/filepath"
)
func WriteManifest(name string, config *Layer, layers []*Layer) error {
manifest := ManifestV2{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: config,
Layers: layers,
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(manifest); err != nil {
return err
}
modelpath := ParseModelPath(name)
manifestPath, err := modelpath.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(manifestPath), 0755); err != nil {
return err
}
return os.WriteFile(manifestPath, b.Bytes(), 0644)
}


@@ -67,6 +67,20 @@ func ParseModelPath(name string) ModelPath {
return mp
}
var errModelPathInvalid = errors.New("invalid model path")
func (mp ModelPath) Validate() error {
if mp.Repository == "" {
return fmt.Errorf("%w: model repository name is required", errModelPathInvalid)
}
if strings.Contains(mp.Tag, ":") {
return fmt.Errorf("%w: ':' (colon) is not allowed in tag names", errModelPathInvalid)
}
return nil
}
func (mp ModelPath) GetNamespaceRepository() string {
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
}


@@ -2,7 +2,6 @@ package server
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
@@ -60,17 +59,26 @@ var loaded struct {
var defaultSessionDuration = 5 * time.Minute
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error) {
model, err := GetModel(modelName)
if err != nil {
return nil, err
}
workDir := c.GetString("workDir")
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
log.Printf("could not load model options: %v", err)
return err
return nil, err
}
if err := opts.FromMap(reqOpts); err != nil {
return err
return nil, err
}
ctx := c.Request.Context()
// check if the loaded model is still running in a subprocess, in case something unexpected happened
if loaded.runner != nil {
if err := loaded.runner.Ping(ctx); err != nil {
@@ -97,7 +105,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
loaded.Options = nil
}
llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts)
llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
@@ -106,7 +114,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
}
return err
return nil, err
}
loaded.Model = model
@@ -140,7 +148,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
}
loaded.expireTimer.Reset(sessionDuration)
return nil
return model, nil
}
func GenerateHandler(c *gin.Context) {
@@ -173,88 +181,140 @@ func GenerateHandler(c *gin.Context) {
return
}
model, err := GetModel(req.Model)
sessionDuration := defaultSessionDuration
model, err := load(c, req.Model, req.Options, sessionDuration)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
switch {
case errors.As(err, &pErr):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
case errors.Is(err, api.ErrInvalidOpts):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
workDir := c.GetString("workDir")
// TODO: set this duration from the request if specified
sessionDuration := defaultSessionDuration
if err := load(c.Request.Context(), workDir, model, req.Options, sessionDuration); err != nil {
if errors.Is(err, api.ErrInvalidOpts) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
// an empty request loads the model
if req.Prompt == "" && req.Template == "" && req.System == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
ModelConfiguration: model.Config.ModelConfiguration,
Done: true})
return
}
checkpointLoaded := time.Now()
prompt := req.Prompt
if !req.Raw {
prompt, err = model.Prompt(req)
var prompt string
switch {
case req.Raw:
prompt = req.Prompt
case req.Prompt != "":
if req.Template != "" {
// override the default model template
model.Template = req.Template
}
var rebuild strings.Builder
if req.Context != nil {
// TODO: context is deprecated, at some point the context logic within this conditional should be removed
prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Remove leading spaces from prevCtx if present
prevCtx = strings.TrimPrefix(prevCtx, " ")
rebuild.WriteString(prevCtx)
}
p, err := model.Prompt(PromptVars{
System: req.System,
Prompt: req.Prompt,
First: len(req.Context) == 0,
})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
rebuild.WriteString(p)
prompt = rebuild.String()
}
ch := make(chan any)
var generated strings.Builder
go func() {
defer close(ch)
// an empty request loads the model
if req.Prompt == "" && req.Template == "" && req.System == "" {
ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}
return
}
fn := func(r api.GenerateResponse) {
fn := func(r llm.PredictResult) {
// Update model expiration
loaded.expireAt = time.Now().Add(sessionDuration)
loaded.expireTimer.Reset(sessionDuration)
r.Model = req.Model
r.CreatedAt = time.Now().UTC()
if r.Done {
r.TotalDuration = time.Since(checkpointStart)
r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
// Build up the full response
if _, err := generated.WriteString(r.Content); err != nil {
ch <- gin.H{"error": err.Error()}
return
}
if req.Raw {
// in raw mode the client must manage history on their own
r.Context = nil
resp := api.GenerateResponse{
Model: r.Model,
ModelConfiguration: model.Config.ModelConfiguration,
CreatedAt: r.CreatedAt,
Done: r.Done,
Response: r.Content,
Metrics: api.Metrics{
TotalDuration: r.TotalDuration,
LoadDuration: r.LoadDuration,
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
ch <- r
if r.Done && !req.Raw {
embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
resp.Context = embd
}
ch <- resp
}
if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, req.Format, fn); err != nil {
// Start prediction
predictReq := llm.PredictOpts{
Model: model.Name,
Prompt: prompt,
Format: req.Format,
CheckpointStart: checkpointStart,
CheckpointLoaded: checkpointLoaded,
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
var response api.GenerateResponse
generated := ""
// Wait for the channel to close
var r api.GenerateResponse
var sb strings.Builder
for resp := range ch {
if r, ok := resp.(api.GenerateResponse); ok {
generated += r.Response
response = r
} else {
var ok bool
if r, ok = resp.(api.GenerateResponse); !ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sb.WriteString(r.Response)
}
response.Response = generated
c.JSON(http.StatusOK, response)
r.Response = sb.String()
c.JSON(http.StatusOK, r)
return
}
@@ -281,15 +341,18 @@ func EmbeddingHandler(c *gin.Context) {
return
}
model, err := GetModel(req.Model)
sessionDuration := defaultSessionDuration
_, err = load(c, req.Model, req.Options, sessionDuration)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
workDir := c.GetString("workDir")
if err := load(c.Request.Context(), workDir, model, req.Options, 5*time.Minute); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
var pErr *fs.PathError
switch {
case errors.As(err, &pErr):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
case errors.Is(err, api.ErrInvalidOpts):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
@@ -416,6 +479,11 @@ func CreateModelHandler(c *gin.Context) {
return
}
if err := ParseModelPath(req.Name).Validate(); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Path == "" && req.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
return
@@ -640,6 +708,11 @@ func CopyModelHandler(c *gin.Context) {
return
}
if err := ParseModelPath(req.Destination).Validate(); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := CopyModel(req.Source, req.Destination); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})
@@ -666,37 +739,18 @@ func HeadBlobHandler(c *gin.Context) {
}
func CreateBlobHandler(c *gin.Context) {
targetPath, err := GetBlobsPath(c.Param("digest"))
layer, err := NewLayer(c.Request.Body, "")
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
hash := sha256.New()
temp, err := os.CreateTemp(filepath.Dir(targetPath), c.Param("digest")+"-")
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
defer temp.Close()
defer os.Remove(temp.Name())
if _, err := io.Copy(temp, io.TeeReader(c.Request.Body, hash)); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
if layer.Digest != c.Param("digest") {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("digest mismatch, expected %q, got %q", c.Param("digest"), layer.Digest)})
return
}
if fmt.Sprintf("sha256:%x", hash.Sum(nil)) != c.Param("digest") {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "digest does not match body"})
return
}
if err := temp.Close(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := os.Rename(temp.Name(), targetPath); err != nil {
if _, err := layer.Commit(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -757,6 +811,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
r.POST("/api/pull", PullModelHandler)
r.POST("/api/generate", GenerateHandler)
r.POST("/api/chat", ChatHandler)
r.POST("/api/embeddings", EmbeddingHandler)
r.POST("/api/create", CreateModelHandler)
r.POST("/api/push", PushModelHandler)
@@ -772,6 +827,9 @@ func Serve(ln net.Listener, allowOrigins []string) error {
})
r.Handle(method, "/api/tags", ListModelsHandler)
r.Handle(method, "/api/version", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"version": version.Version})
})
}
log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
@@ -850,3 +908,125 @@ func streamResponse(c *gin.Context, ch chan any) {
return true
})
}
func ChatHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
checkpointStart := time.Now()
var req api.ChatRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// validate the request
switch {
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
}
sessionDuration := defaultSessionDuration
model, err := load(c, req.Model, req.Options, sessionDuration)
if err != nil {
var pErr *fs.PathError
switch {
case errors.As(err, &pErr):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
case errors.Is(err, api.ErrInvalidOpts):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
// an empty request loads the model
if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true})
return
}
checkpointLoaded := time.Now()
prompt, err := model.ChatPrompt(req.Messages)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ch := make(chan any)
go func() {
defer close(ch)
fn := func(r llm.PredictResult) {
// Update model expiration
loaded.expireAt = time.Now().Add(sessionDuration)
loaded.expireTimer.Reset(sessionDuration)
resp := api.ChatResponse{
Model: r.Model,
CreatedAt: r.CreatedAt,
Done: r.Done,
Metrics: api.Metrics{
TotalDuration: r.TotalDuration,
LoadDuration: r.LoadDuration,
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
if !r.Done {
resp.Message = &api.Message{Role: "assistant", Content: r.Content}
}
ch <- resp
}
// Start prediction
predictReq := llm.PredictOpts{
Model: model.Name,
Prompt: prompt,
Format: req.Format,
CheckpointStart: checkpointStart,
CheckpointLoaded: checkpointLoaded,
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
// Wait for the channel to close
var r api.ChatResponse
var sb strings.Builder
for resp := range ch {
var ok bool
if r, ok = resp.(api.ChatResponse); !ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if r.Message != nil {
sb.WriteString(r.Message.Content)
}
}
r.Message = &api.Message{Role: "assistant", Content: sb.String()}
c.JSON(http.StatusOK, r)
return
}
streamResponse(c, ch)
}


@@ -5,6 +5,7 @@ import (
"crypto/md5"
"errors"
"fmt"
"hash"
"io"
"log"
"math"
@@ -102,7 +103,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg
}
// set part.N to the current number of parts
b.Parts = append(b.Parts, blobUploadPart{blobUpload: b, N: len(b.Parts), Offset: offset, Size: size})
b.Parts = append(b.Parts, blobUploadPart{N: len(b.Parts), Offset: offset, Size: size})
offset += size
}
@@ -147,14 +148,13 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
g.Go(func() error {
var err error
for try := 0; try < maxRetries; try++ {
err = b.uploadChunk(inner, http.MethodPatch, requestURL, part, opts)
err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts)
switch {
case errors.Is(err, context.Canceled):
return err
case errors.Is(err, errMaxRetriesExceeded):
return err
case err != nil:
part.Reset()
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
time.Sleep(sleep)
@@ -176,17 +176,10 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
requestURL := <-b.nextURL
var sb strings.Builder
// calculate md5 checksum and add it to the commit request
var sb strings.Builder
for _, part := range b.Parts {
hash := md5.New()
if _, err := io.Copy(hash, io.NewSectionReader(b.file, part.Offset, part.Size)); err != nil {
b.err = err
return
}
sb.Write(hash.Sum(nil))
sb.Write(part.Sum(nil))
}
md5sum := md5.Sum([]byte(sb.String()))
@@ -201,27 +194,25 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
headers.Set("Content-Length", "0")
for try := 0; try < maxRetries; try++ {
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
if err != nil {
b.err = err
if errors.Is(err, context.Canceled) {
return
}
var resp *http.Response
resp, err = makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
if errors.Is(err, context.Canceled) {
break
} else if err != nil {
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
log.Printf("%s complete upload attempt %d failed: %v, retrying in %s", b.Digest[7:19], try, err, sleep)
time.Sleep(sleep)
continue
}
defer resp.Body.Close()
b.err = nil
b.done = true
return
break
}
b.err = err
b.done = true
}
func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error {
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error {
headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
@@ -232,8 +223,13 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
}
sr := io.NewSectionReader(b.file, part.Offset, part.Size)
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, part), opts)
md5sum := md5.New()
w := &progressWriter{blobUpload: b}
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
if err != nil {
w.Rollback()
return err
}
defer resp.Body.Close()
@@ -245,11 +241,13 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
nextURL, err := url.Parse(location)
if err != nil {
w.Rollback()
return err
}
switch {
case resp.StatusCode == http.StatusTemporaryRedirect:
w.Rollback()
b.nextURL <- nextURL
redirectURL, err := resp.Location()
@@ -259,14 +257,13 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
// retry uploading to the redirect URL
for try := 0; try < maxRetries; try++ {
err = b.uploadChunk(ctx, http.MethodPut, redirectURL, part, nil)
err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, nil)
switch {
case errors.Is(err, context.Canceled):
return err
case errors.Is(err, errMaxRetriesExceeded):
return err
case err != nil:
part.Reset()
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
time.Sleep(sleep)
@@ -279,6 +276,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
case resp.StatusCode == http.StatusUnauthorized:
w.Rollback()
auth := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth)
token, err := getAuthToken(ctx, authRedir)
@@ -289,6 +287,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
opts.Token = token
fallthrough
case resp.StatusCode >= http.StatusBadRequest:
w.Rollback()
body, err := io.ReadAll(resp.Body)
if err != nil {
return err
@@ -301,6 +300,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
b.nextURL <- nextURL
}
part.Hash = md5sum
return nil
}
@@ -341,22 +341,26 @@ func (b *blobUpload) Wait(ctx context.Context, fn func(api.ProgressResponse)) er
type blobUploadPart struct {
// N is the part number
N int
Offset int64
Size int64
N int
Offset int64
Size int64
hash.Hash
}
type progressWriter struct {
written int64
*blobUpload
}
func (p *blobUploadPart) Write(b []byte) (n int, err error) {
func (p *progressWriter) Write(b []byte) (n int, err error) {
n = len(b)
p.written += int64(n)
p.Completed.Add(int64(n))
return n, nil
}
func (p *blobUploadPart) Reset() {
p.Completed.Add(-int64(p.written))
func (p *progressWriter) Rollback() {
p.Completed.Add(-p.written)
p.written = 0
}