Compare commits

161 Commits:

```
e3fb1fd3f1
29b897f525
85aeb42869
c5bcf32823
a71ff3f6a2
f0b365a478
df8048fecd
da2459d519
bd6d741d87
8b1e791820
03cff3a225
cc509a994e
0e79e52ddd
6fbb380076
8f8b6288ac
b98096389d
74a5f7e698
7a1c3e62dc
da52f5bfdd
50e87c6691
e4a970ece1
4ca43a694c
765994362c
40a25bf8c3
1c5a8770ee
daa0d1de7a
58daeb962a
528bafa585
81f75696e2
8bdcf894bd
fe530423a5
05e390205b
872011630a
203fdbc4b8
70e0ab6b3d
319f078dd9
9968153729
7da249fcc1
f529626c6c
36d6081ed1
aadedda486
671eec6da9
e72fe7945f
d1c098b038
90ba0b80c7
39bb25d5f6
eadee46840
2e2e624d21
ed832ce3b7
227da16909
bd58528fbd
c5e447a359
fc40a4f166
9c7f30d31c
6ed3ec0cb3
47bda0b860
c75cafdb58
f5cbcb08e6
67b6f8ba86
184ad8f057
822a0e36eb
18b6b601ad
0345070dfa
dffc8b6e09
0871083776
e5b26c3aa2
3549676678
8fa477fadb
fadf75f99d
01d155c969
5685c16d4e
db77dfe01f
ad3a7d0e2c
18ffeeec45
688661ab9b
36ad90e8e3
6fff59c637
fee7687cf3
d3bfb4889c
1ac38ec89c
1ad8266473
f5ac8ddfb4
cca61181cb
c490416189
f62a882760
3003fc03fc
32aec66e6a
35af37a2cb
dbb3174cbc
31673d26d0
8ba0f328af
d0e934b497
e751e47d70
19d0f2b4cc
c48f07f821
dc642aa07d
f1ff892fdd
3f2a100465
95397416f3
8a86aae019
24c2c77057
5614984f06
4c1caa3733
12ab8f8f5f
8ebbd12f21
07971759fa
f5f79049c2
726bc647b2
af9039a167
07ed69bc37
0deb3767fc
cb55fa9270
93bc9f17a1
536028c35a
aedf3d1f38
91d927abc5
ba8df10a43
abf614804b
a0dbbb23c4
0fd6278446
29fe07f0cc
abfc73d31e
5a5ca8e7ff
f24a6f5988
fdbef6c95e
24e43e3212
4cb42ca55e
ec5e22ac85
ed89da92b4
a3297fed41
88c55199f8
c448443813
efacd45fc5
fa522695c4
8609db77ea
65d93a86b2
e6c427ce4d
b71c67b6ba
6d6b0d3321
37324a0a00
20a5d99f77
3b43cc019a
b8421dce3d
9f6e97865c
9657314ae2
3f7d2336c7
e0a73d7fbe
b08c4ca2bd
734892f1e2
d2bfaeac63
0768b1b907
f5f0da06d9
52f04e39f2
3c8f4c03d7
7ba1308595
91cd54016c
e7a393de54
8454f298ac
a3badaf103
09dc6273e3
ebaa33ac28
```
**.gitignore** (vendored, 2 changes)

````diff
@@ -2,5 +2,7 @@
 .vscode
 .env
 .venv
 .swp
+dist
 ollama
+/ggml-metal.metal
````
**README.md** (45 changes)

````diff
@@ -31,14 +31,15 @@ ollama run llama2
 
 `ollama` includes a library of open-source models:
 
-| Model                    | Parameters | Size  | Download                    |
-| ------------------------ | ---------- | ----- | --------------------------- |
-| Llama2                   | 7B         | 3.8GB | `ollama pull llama2`        |
-| Llama2 13B               | 13B        | 7.3GB | `ollama pull llama2:13b`    |
-| Orca Mini                | 3B         | 1.9GB | `ollama pull orca`          |
-| Vicuna                   | 7B         | 3.8GB | `ollama pull vicuna`        |
-| Nous-Hermes              | 13B        | 7.3GB | `ollama pull nous-hermes`   |
-| Wizard Vicuna Uncensored | 13B        | 7.3GB | `ollama pull wizard-vicuna` |
+| Model                    | Parameters | Size  | Download                        |
+| ------------------------ | ---------- | ----- | ------------------------------- |
+| Llama2                   | 7B         | 3.8GB | `ollama pull llama2`            |
+| Llama2 Uncensored        | 7B         | 3.8GB | `ollama pull llama2-uncensored` |
+| Llama2 13B               | 13B        | 7.3GB | `ollama pull llama2:13b`        |
+| Orca Mini                | 3B         | 1.9GB | `ollama pull orca`              |
+| Vicuna                   | 7B         | 3.8GB | `ollama pull vicuna`            |
+| Nous-Hermes              | 13B        | 7.3GB | `ollama pull nous-hermes`       |
+| Wizard Vicuna Uncensored | 13B        | 7.3GB | `ollama pull wizard-vicuna`     |
 
 > Note: You should have at least 8 GB of RAM to run the 3B models, 16 GB to run the 7B models, and 32 GB to run the 13B models.
 
@@ -59,6 +60,7 @@ Pull a base model:
 ```
 ollama pull llama2
 ```
+> To update a model to the latest version, run `ollama pull llama2` again. The model will be updated (if necessary).
 
 Create a `Modelfile`:
 
@@ -85,6 +87,8 @@ Hello! It's your friend Mario.
 
 For more examples, see the [examples](./examples) directory.
 
+For more information on creating a Modelfile, see the [Modelfile](./docs/modelfile.md) documentation.
+
 ### Pull a model from the registry
 
 ```
@@ -125,3 +129,28 @@ Finally, run a model!
 ```
 ./ollama run llama2
 ```
+
+## REST API
+
+### `POST /api/generate`
+
+Generate text from a model.
+
+```
+curl -X POST http://localhost:11434/api/generate -d '{"model": "llama2", "prompt":"Why is the sky blue?"}'
+```
+
+### `POST /api/create`
+
+Create a model from a `Modelfile`.
+
+```
+curl -X POST http://localhost:11434/api/create -d '{"name": "my-model", "path": "/path/to/modelfile"}'
+```
+
+## Projects built with Ollama
+
+- [Continue](https://github.com/continuedev/continue) - embeds Ollama inside Visual Studio Code. The extension lets you highlight code to add to the prompt, ask questions in the sidebar, and generate code inline.
+- [Discord AI Bot](https://github.com/mekb-turtle/discord-ai-bot) - interact with Ollama as a chatbot on Discord.
+- [Raycast Ollama](https://github.com/MassimilianoPasquini97/raycast_ollama) - Raycast extension to use Ollama for local llama inference on Raycast.
+- [Simple HTML UI for Ollama](https://github.com/rtcfirefly/ollama-ui)
````
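For readers trying the new REST API outside of curl, here is a minimal Go sketch of consuming the streaming `/api/generate` endpoint. The server streams one JSON object per line; the `response` and `done` field names are assumptions inferred from the client code later in this comparison, not something this README excerpt specifies:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	body, _ := json.Marshal(map[string]string{
		"model":  "llama2",
		"prompt": "Why is the sky blue?",
	})

	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Decode successive JSON values from the stream until it ends.
	dec := json.NewDecoder(resp.Body)
	for {
		var chunk struct {
			Response string `json:"response"` // assumed field name
			Done     bool   `json:"done"`     // assumed field name
		}
		if err := dec.Decode(&chunk); err != nil {
			break // io.EOF when the stream ends
		}
		fmt.Print(chunk.Response)
		if chunk.Done {
			break
		}
	}
}
```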
**api/client.go**

````diff
@@ -27,7 +27,7 @@ func checkError(resp *http.Response, body []byte) error {
 	err := json.Unmarshal(body, &apiError)
 	if err != nil {
 		// Use the full body as the message if we fail to decode a response.
-		apiError.Message = string(body)
+		apiError.ErrorMessage = string(body)
 	}
 
 	return apiError
@@ -92,7 +92,6 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
 		}
 	}
 	return nil
-
 }
 
 func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
@@ -132,14 +131,14 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
 		}
 
 		if errorResponse.Error != "" {
-			return fmt.Errorf("stream: %s", errorResponse.Error)
+			return fmt.Errorf(errorResponse.Error)
 		}
 
 		if response.StatusCode >= 400 {
 			return StatusError{
-				StatusCode: response.StatusCode,
-				Status:     response.Status,
-				Message:    errorResponse.Error,
+				StatusCode:   response.StatusCode,
+				Status:       response.Status,
+				ErrorMessage: errorResponse.Error,
 			}
 		}
 
@@ -190,11 +189,11 @@ func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
 	})
 }
 
-type CreateProgressFunc func(CreateProgress) error
+type CreateProgressFunc func(ProgressResponse) error
 
 func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
 	return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
-		var resp CreateProgress
+		var resp ProgressResponse
 		if err := json.Unmarshal(bts, &resp); err != nil {
 			return err
 		}
@@ -210,3 +209,24 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
 	}
 	return &lr, nil
 }
+
+func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
+	if err := c.do(ctx, http.MethodPost, "/api/copy", req, nil); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
+	if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (c *Client) Heartbeat(ctx context.Context) error {
+	if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {
+		return err
+	}
+	return nil
+}
````
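The reworked `CreateProgressFunc` means a single callback now receives both plain status updates and per-digest download progress. A minimal usage sketch against the API above (the model name and Modelfile path are placeholders):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/jmorganca/ollama/api"
)

func main() {
	client := api.NewClient()

	// CreateProgressFunc now receives api.ProgressResponse, so one callback
	// can report both status strings and per-digest byte counts.
	req := &api.CreateRequest{Name: "my-model", Path: "/path/to/Modelfile"} // placeholder values
	fn := func(resp api.ProgressResponse) error {
		if resp.Digest != "" {
			fmt.Printf("%s: %d/%d bytes\n", resp.Digest, resp.Completed, resp.Total)
		} else {
			fmt.Println(resp.Status)
		}
		return nil
	}

	if err := client.Create(context.Background(), req, fn); err != nil {
		log.Fatal(err)
	}
}
```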
**api/types.go** (197 changes)

````diff
@@ -1,23 +1,35 @@
 package api
 
 import (
+	"encoding/json"
 	"fmt"
+	"log"
+	"math"
 	"os"
+	"reflect"
 	"runtime"
+	"strings"
 	"time"
 )
 
 type StatusError struct {
-	StatusCode int
-	Status     string
-	Message    string
+	StatusCode   int
+	Status       string
+	ErrorMessage string `json:"error"`
 }
 
 func (e StatusError) Error() string {
-	if e.Message != "" {
-		return fmt.Sprintf("%s: %s", e.Status, e.Message)
+	switch {
+	case e.Status != "" && e.ErrorMessage != "":
+		return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage)
+	case e.Status != "":
+		return e.Status
+	case e.ErrorMessage != "":
+		return e.ErrorMessage
+	default:
+		// this should not happen
+		return "something went wrong, please see the ollama server logs for details"
 	}
-	return e.Status
 }
 
 type GenerateRequest struct {
@@ -25,7 +37,7 @@ type GenerateRequest struct {
 	Prompt  string `json:"prompt"`
 	Context []int  `json:"context,omitempty"`
 
-	Options `json:"options"`
+	Options map[string]interface{} `json:"options"`
 }
 
 type CreateRequest struct {
@@ -33,25 +45,32 @@ type CreateRequest struct {
 	Path string `json:"path"`
 }
 
-type CreateProgress struct {
-	Status string `json:"status"`
+type DeleteRequest struct {
+	Name string `json:"name"`
+}
+
+type CopyRequest struct {
+	Source      string `json:"source"`
+	Destination string `json:"destination"`
 }
 
 type PullRequest struct {
 	Name     string `json:"name"`
+	Insecure bool   `json:"insecure,omitempty"`
 	Username string `json:"username"`
 	Password string `json:"password"`
 }
 
 type ProgressResponse struct {
-	Status    string `json:"status"`
-	Digest    string `json:"digest,omitempty"`
-	Total     int    `json:"total,omitempty"`
-	Completed int    `json:"completed,omitempty"`
+	Status    string `json:"status"`
+	Digest    string `json:"digest,omitempty"`
+	Total     int    `json:"total,omitempty"`
+	Completed int    `json:"completed,omitempty"`
 }
 
 type PushRequest struct {
 	Name     string `json:"name"`
+	Insecure bool   `json:"insecure,omitempty"`
 	Username string `json:"username"`
 	Password string `json:"password"`
 }
@@ -75,6 +94,9 @@ type GenerateResponse struct {
 	Context []int `json:"context,omitempty"`
 
 	TotalDuration      time.Duration `json:"total_duration,omitempty"`
+	LoadDuration       time.Duration `json:"load_duration,omitempty"`
+	SampleCount        int           `json:"sample_count,omitempty"`
+	SampleDuration     time.Duration `json:"sample_duration,omitempty"`
 	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
 	PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
 	EvalCount          int           `json:"eval_count,omitempty"`
@@ -86,6 +108,19 @@ func (r *GenerateResponse) Summary() {
 		fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration)
 	}
 
+	if r.LoadDuration > 0 {
+		fmt.Fprintf(os.Stderr, "load duration: %v\n", r.LoadDuration)
+	}
+
+	if r.SampleCount > 0 {
+		fmt.Fprintf(os.Stderr, "sample count: %d token(s)\n", r.SampleCount)
+	}
+
+	if r.SampleDuration > 0 {
+		fmt.Fprintf(os.Stderr, "sample duration: %s\n", r.SampleDuration)
+		fmt.Fprintf(os.Stderr, "sample rate: %.2f tokens/s\n", float64(r.SampleCount)/r.SampleDuration.Seconds())
+	}
+
 	if r.PromptEvalCount > 0 {
 		fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount)
 	}
@@ -113,7 +148,9 @@ type Options struct {
 
 	// Model options
 	NumCtx   int  `json:"num_ctx,omitempty"`
+	NumKeep  int  `json:"num_keep,omitempty"`
 	NumBatch int  `json:"num_batch,omitempty"`
+	NumGQA   int  `json:"num_gqa,omitempty"`
 	NumGPU   int  `json:"num_gpu,omitempty"`
 	MainGPU  int  `json:"main_gpu,omitempty"`
 	LowVRAM  bool `json:"low_vram,omitempty"`
@@ -125,22 +162,99 @@ type Options struct {
 	EmbeddingOnly bool `json:"embedding_only,omitempty"`
 
 	// Predict options
-	RepeatLastN      int     `json:"repeat_last_n,omitempty"`
-	RepeatPenalty    float32 `json:"repeat_penalty,omitempty"`
-	FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
-	PresencePenalty  float32 `json:"presence_penalty,omitempty"`
-	Temperature      float32 `json:"temperature,omitempty"`
-	TopK             int     `json:"top_k,omitempty"`
-	TopP             float32 `json:"top_p,omitempty"`
-	TFSZ             float32 `json:"tfs_z,omitempty"`
-	TypicalP         float32 `json:"typical_p,omitempty"`
-	Mirostat         int     `json:"mirostat,omitempty"`
-	MirostatTau      float32 `json:"mirostat_tau,omitempty"`
-	MirostatEta      float32 `json:"mirostat_eta,omitempty"`
+	RepeatLastN      int      `json:"repeat_last_n,omitempty"`
+	RepeatPenalty    float32  `json:"repeat_penalty,omitempty"`
+	FrequencyPenalty float32  `json:"frequency_penalty,omitempty"`
+	PresencePenalty  float32  `json:"presence_penalty,omitempty"`
+	Temperature      float32  `json:"temperature,omitempty"`
+	TopK             int      `json:"top_k,omitempty"`
+	TopP             float32  `json:"top_p,omitempty"`
+	TFSZ             float32  `json:"tfs_z,omitempty"`
+	TypicalP         float32  `json:"typical_p,omitempty"`
+	Mirostat         int      `json:"mirostat,omitempty"`
+	MirostatTau      float32  `json:"mirostat_tau,omitempty"`
+	MirostatEta      float32  `json:"mirostat_eta,omitempty"`
+	PenalizeNewline  bool     `json:"penalize_newline,omitempty"`
+	Stop             []string `json:"stop,omitempty"`
 
 	NumThread int `json:"num_thread,omitempty"`
 }
 
+func (opts *Options) FromMap(m map[string]interface{}) error {
+	valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
+	typeOpts := reflect.TypeOf(opts).Elem()   // types of the fields in the options struct
+
+	// build map of json struct tags to their types
+	jsonOpts := make(map[string]reflect.StructField)
+	for _, field := range reflect.VisibleFields(typeOpts) {
+		jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
+		if jsonTag != "" {
+			jsonOpts[jsonTag] = field
+		}
+	}
+
+	for key, val := range m {
+		if opt, ok := jsonOpts[key]; ok {
+			field := valueOpts.FieldByName(opt.Name)
+			if field.IsValid() && field.CanSet() {
+				switch field.Kind() {
+				case reflect.Int:
+					// when JSON unmarshals numbers, it uses float64 by default, not int
+					val, ok := val.(float64)
+					if !ok {
+						log.Printf("could not convert model parmeter %v to int, skipped", key)
+						continue
+					}
+					field.SetInt(int64(val))
+				case reflect.Bool:
+					val, ok := val.(bool)
+					if !ok {
+						log.Printf("could not convert model parmeter %v to bool, skipped", key)
+						continue
+					}
+					field.SetBool(val)
+				case reflect.Float32:
+					// JSON unmarshals to float64
+					val, ok := val.(float64)
+					if !ok {
+						log.Printf("could not convert model parmeter %v to float32, skipped", key)
+						continue
+					}
+					field.SetFloat(val)
+				case reflect.String:
+					val, ok := val.(string)
+					if !ok {
+						log.Printf("could not convert model parmeter %v to string, skipped", key)
+						continue
+					}
+					field.SetString(val)
+				case reflect.Slice:
+					// JSON unmarshals to []interface{}, not []string
+					val, ok := val.([]interface{})
+					if !ok {
+						log.Printf("could not convert model parmeter %v to slice, skipped", key)
+						continue
+					}
+					// convert []interface{} to []string
+					slice := make([]string, len(val))
+					for i, item := range val {
+						str, ok := item.(string)
+						if !ok {
+							log.Printf("could not convert model parmeter %v to slice of strings, skipped", key)
+							continue
+						}
+						slice[i] = str
+					}
+					field.Set(reflect.ValueOf(slice))
+				default:
+					return fmt.Errorf("unknown type loading config params: %v", field.Kind())
+				}
+			}
+		}
+	}
+	return nil
+}
+
 func DefaultOptions() Options {
 	return Options{
 		Seed: -1,
@@ -150,12 +264,13 @@ func DefaultOptions() Options {
 		NumCtx:   2048,
 		NumBatch: 512,
 		NumGPU:   1,
+		NumGQA:   1,
 		LowVRAM:  false,
 		F16KV:    true,
 		UseMMap:  true,
 		UseMLock: false,
 
-		RepeatLastN:      512,
+		RepeatLastN:      64,
 		RepeatPenalty:    1.1,
 		FrequencyPenalty: 0.0,
 		PresencePenalty:  0.0,
@@ -167,7 +282,37 @@ func DefaultOptions() Options {
 		Mirostat:        0,
 		MirostatTau:     5.0,
 		MirostatEta:     0.1,
+		PenalizeNewline: true,
 
 		NumThread: runtime.NumCPU(),
 	}
 }
+
+type Duration struct {
+	time.Duration
+}
+
+func (d *Duration) UnmarshalJSON(b []byte) (err error) {
+	var v any
+	if err := json.Unmarshal(b, &v); err != nil {
+		return err
+	}
+
+	d.Duration = 5 * time.Minute
+
+	switch t := v.(type) {
+	case float64:
+		if t < 0 {
+			t = math.MaxFloat64
+		}
+
+		d.Duration = time.Duration(t)
+	case string:
+		d.Duration, err = time.ParseDuration(t)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
````
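Because `GenerateRequest.Options` is now a loosely typed `map[string]interface{}`, options arriving as JSON are applied over the defaults with `FromMap`. A minimal sketch of that flow, assuming the `DefaultOptions`/`FromMap` API shown above:

```go
package main

import (
	"fmt"
	"log"

	"github.com/jmorganca/ollama/api"
)

func main() {
	// Start from the defaults, then overlay user-supplied options.
	opts := api.DefaultOptions()

	// JSON numbers decode to float64; FromMap converts them per field kind.
	userOpts := map[string]interface{}{
		"temperature": 0.7,                    // float64 -> float32
		"num_ctx":     float64(4096),          // float64 -> int
		"stop":        []interface{}{"User:"}, // []interface{} -> []string
	}

	if err := opts.FromMap(userOpts); err != nil {
		log.Fatal(err)
	}

	fmt.Printf("num_ctx=%d temperature=%.1f stop=%v\n", opts.NumCtx, opts.Temperature, opts.Stop)
}
```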
**app/README.md**

````diff
@@ -1,7 +1,5 @@
 # Desktop
 
-_Note: the Ollama desktop app is a work in progress and is not ready yet for general use._
-
 This app builds upon Ollama to provide a desktop experience for running models.
 
 ## Developing
@@ -9,19 +7,15 @@ This app builds upon Ollama to provide a desktop experience for running models.
 First, build the `ollama` binary:
 
 ```
-make -C ..
+cd ..
+go build .
 ```
 
 Then run the desktop app with `npm start`:
 
 ```
 cd app
 npm install
 npm start
 ```
-
-## Coming soon
-
-- Browse the latest available models on Hugging Face and other sources
-- Keep track of previous conversations with models
-- Switch quickly between models
-- Connect to remote Ollama servers to run models
````
**app/assets** (binary files)

| File | Change | Size |
| --- | --- | --- |
| app/assets/iconDarkTemplate.png | added | 402 B |
| (unnamed binary) | modified | 741 B → 741 B |
| app/assets/iconDarkUpdateTemplate.png | added | 440 B |
| app/assets/iconDarkUpdateTemplate@2x.png | added | 763 B |
| app/assets/iconTemplate.png | added | 447 B |
| (unnamed binary) | modified | 891 B → 891 B |
| app/assets/iconUpdateTemplate.png | added | 443 B |
| app/assets/iconUpdateTemplate@2x.png | added | 844 B |
| (unnamed binary) | deleted | 403 B |
| (unnamed binary) | deleted | 445 B |
**app/forge.config.ts**

````diff
@@ -18,11 +18,15 @@ const config: ForgeConfig = {
     asar: true,
     icon: './assets/icon.icns',
     extraResource: [
-      '../ollama',
-      path.join(__dirname, './assets/ollama_icon_16x16Template.png'),
-      path.join(__dirname, './assets/ollama_icon_16x16Template@2x.png'),
-      path.join(__dirname, './assets/ollama_outline_icon_16x16Template.png'),
-      path.join(__dirname, './assets/ollama_outline_icon_16x16Template@2x.png'),
+      '../dist/ollama',
+      path.join(__dirname, './assets/iconTemplate.png'),
+      path.join(__dirname, './assets/iconTemplate@2x.png'),
+      path.join(__dirname, './assets/iconUpdateTemplate.png'),
+      path.join(__dirname, './assets/iconUpdateTemplate@2x.png'),
+      path.join(__dirname, './assets/iconDarkTemplate.png'),
+      path.join(__dirname, './assets/iconDarkTemplate@2x.png'),
+      path.join(__dirname, './assets/iconDarkUpdateTemplate.png'),
+      path.join(__dirname, './assets/iconDarkUpdateTemplate@2x.png'),
       ...(process.platform === 'darwin' ? ['../llama/ggml-metal.metal'] : []),
     ],
     ...(process.env.SIGN
@@ -38,6 +42,9 @@ const config: ForgeConfig = {
       },
     }
       : {}),
+    osxUniversal: {
+      x64ArchFiles: '**/ollama',
+    },
   },
   rebuildConfig: {},
   makers: [new MakerSquirrel({}), new MakerZIP({}, ['darwin'])],
````
**app/package-lock.json** (generated, 7 changes)

````diff
@@ -32,6 +32,7 @@
         "@electron-forge/plugin-auto-unpack-natives": "^6.2.1",
         "@electron-forge/plugin-webpack": "^6.2.1",
         "@electron-forge/publisher-github": "^6.2.1",
+        "@electron/universal": "^1.4.1",
         "@svgr/webpack": "^8.0.1",
         "@types/chmodr": "^1.0.0",
         "@types/node": "^20.4.0",
@@ -3328,9 +3329,9 @@
       }
     },
     "node_modules/@electron/universal": {
-      "version": "1.3.4",
-      "resolved": "https://registry.npmjs.org/@electron/universal/-/universal-1.3.4.tgz",
-      "integrity": "sha512-BdhBgm2ZBnYyYRLRgOjM5VHkyFItsbggJ0MHycOjKWdFGYwK97ZFXH54dTvUWEfha81vfvwr5On6XBjt99uDcg==",
+      "version": "1.4.1",
+      "resolved": "https://registry.npmjs.org/@electron/universal/-/universal-1.4.1.tgz",
+      "integrity": "sha512-lE/U3UNw1YHuowNbTmKNs9UlS3En3cPgwM5MI+agIgr/B1hSze9NdOP0qn7boZaI9Lph8IDv3/24g9IxnJP7aQ==",
       "dev": true,
       "dependencies": {
         "@electron/asar": "^3.2.1",
````
**app/package.json**

````diff
@@ -6,12 +6,14 @@
   "main": ".webpack/main",
   "scripts": {
     "start": "electron-forge start",
-    "package": "electron-forge package",
-    "package:sign": "SIGN=1 electron-forge package",
-    "make": "electron-forge make",
-    "make:sign": "SIGN=1 electron-forge make",
+    "package": "electron-forge package --arch universal",
+    "package:sign": "SIGN=1 electron-forge package --arch universal",
+    "make": "electron-forge make --arch universal",
+    "make:sign": "SIGN=1 electron-forge make --arch universal",
     "publish": "SIGN=1 electron-forge publish",
-    "lint": "eslint --ext .ts,.tsx ."
+    "lint": "eslint --ext .ts,.tsx .",
+    "format": "prettier --check . --ignore-path .gitignore",
+    "format:fix": "prettier --write . --ignore-path .gitignore"
   },
   "keywords": [],
   "author": {
@@ -30,6 +32,7 @@
     "@electron-forge/plugin-auto-unpack-natives": "^6.2.1",
     "@electron-forge/plugin-webpack": "^6.2.1",
     "@electron-forge/publisher-github": "^6.2.1",
+    "@electron/universal": "^1.4.1",
     "@svgr/webpack": "^8.0.1",
     "@types/chmodr": "^1.0.0",
     "@types/node": "^20.4.0",
````
**app/src/app.tsx**

````diff
@@ -2,7 +2,7 @@ import { useState } from 'react'
 import copy from 'copy-to-clipboard'
 import { CheckIcon, DocumentDuplicateIcon } from '@heroicons/react/24/outline'
 import Store from 'electron-store'
-import { getCurrentWindow } from '@electron/remote'
+import { getCurrentWindow, app } from '@electron/remote'
 
 import { install } from './install'
 import OllamaIcon from './ollama.svg'
@@ -51,10 +51,15 @@ export default function () {
         <div className='mx-auto'>
           <button
             onClick={async () => {
-              await install()
-              getCurrentWindow().show()
-              getCurrentWindow().focus()
-              setStep(Step.FINISH)
+              try {
+                await install()
+                setStep(Step.FINISH)
+              } catch (e) {
+                console.error('could not install: ', e)
+              } finally {
+                getCurrentWindow().show()
+                getCurrentWindow().focus()
+              }
             }}
             className='no-drag rounded-dm mx-auto w-[60%] rounded-md bg-black px-4 py-2 text-sm text-white hover:brightness-110'
           >
````
**app/src/declarations.d.ts** (vendored, 6 changes)

````diff
@@ -1,4 +1,4 @@
 declare module '*.svg' {
-  const content: string;
-  export default content;
-}
+  const content: string
+  export default content
+}
````
**app/src/index.ts** (195 changes)

````diff
@@ -1,5 +1,5 @@
-import { spawn } from 'child_process'
-import { app, autoUpdater, dialog, Tray, Menu, BrowserWindow, nativeTheme } from 'electron'
+import { spawn, ChildProcess } from 'child_process'
+import { app, autoUpdater, dialog, Tray, Menu, BrowserWindow, MenuItemConstructorOptions, nativeTheme } from 'electron'
 import Store from 'electron-store'
 import winston from 'winston'
 import 'winston-daily-rotate-file'
@@ -10,8 +10,12 @@ import { installed } from './install'
 
 require('@electron/remote/main').initialize()
 
+if (require('electron-squirrel-startup')) {
+  app.quit()
+}
+
 const store = new Store()
-let tray: Tray | null = null
 
 let welcomeWindow: BrowserWindow | null = null
 
 declare const MAIN_WINDOW_WEBPACK_ENTRY: string
@@ -28,10 +32,30 @@ const logger = winston.createLogger({
   format: winston.format.printf(info => info.message),
 })
 
-const SingleInstanceLock = app.requestSingleInstanceLock()
-if (!SingleInstanceLock) {
-  app.quit()
-}
+app.on('ready', () => {
+  const gotTheLock = app.requestSingleInstanceLock()
+  if (!gotTheLock) {
+    app.exit(0)
+    return
+  }
+
+  app.on('second-instance', () => {
+    if (app.hasSingleInstanceLock()) {
+      app.releaseSingleInstanceLock()
+    }
+
+    if (proc) {
+      proc.off('exit', restart)
+      proc.kill()
+    }
+
+    app.exit(0)
+  })
+
+  app.focus({ steal: true })
+
+  init()
+})
 
 function firstRunWindow() {
   // Create the browser window.
@@ -52,60 +76,70 @@
 
   require('@electron/remote/main').enable(welcomeWindow.webContents)
 
   // and load the index.html of the app.
   welcomeWindow.loadURL(MAIN_WINDOW_WEBPACK_ENTRY)
 
   welcomeWindow.on('ready-to-show', () => welcomeWindow.show())
 
   // for debugging
   // welcomeWindow.webContents.openDevTools()
 
-  if (process.platform === 'darwin') {
-    app.dock.hide()
-  }
-}
-
-function createSystemtray() {
-  let iconPath = nativeTheme.shouldUseDarkColors
-    ? path.join(__dirname, '..', '..', 'assets', 'ollama_icon_16x16Template.png')
-    : path.join(__dirname, '..', '..', 'assets', 'ollama_outline_icon_16x16Template.png')
-
-  if (app.isPackaged) {
-    iconPath = nativeTheme.shouldUseDarkColors
-      ? path.join(process.resourcesPath, 'ollama_icon_16x16Template.png')
-      : path.join(process.resourcesPath, 'ollama_outline_icon_16x16Template.png')
-  }
-
-  tray = new Tray(iconPath)
-
-  nativeTheme.on('updated', function theThemeHasChanged () {
-    if (nativeTheme.shouldUseDarkColors) {
-      app.isPackaged
-        ? tray.setImage(path.join(process.resourcesPath, 'ollama_icon_16x16Template.png'))
-        : tray.setImage(path.join(__dirname, '..', '..', 'assets', 'ollama_icon_16x16Template.png'))
-    } else {
-      app.isPackaged
-        ? tray.setImage(path.join(process.resourcesPath, 'ollama_outline_icon_16x16Template.png'))
-        : tray.setImage(path.join(__dirname, '..', '..', 'assets', 'ollama_outline_icon_16x16Template.png'))
+  welcomeWindow.on('closed', () => {
+    if (process.platform === 'darwin') {
+      app.dock.hide()
     }
   })
-
-  const contextMenu = Menu.buildFromTemplate([{ role: 'quit', label: 'Quit Ollama', accelerator: 'Command+Q' }])
-
-  tray.setContextMenu(contextMenu)
-  tray.setToolTip('Ollama')
 }
 
-if (require('electron-squirrel-startup')) {
-  app.quit()
+let tray: Tray | null = null
+let updateAvailable = false
+const assetPath = app.isPackaged ? process.resourcesPath : path.join(__dirname, '..', '..', 'assets')
+
+function trayIconPath() {
+  return nativeTheme.shouldUseDarkColors
+    ? updateAvailable
+      ? path.join(assetPath, 'iconDarkUpdateTemplate.png')
+      : path.join(assetPath, 'iconDarkTemplate.png')
+    : updateAvailable
+    ? path.join(assetPath, 'iconUpdateTemplate.png')
+    : path.join(assetPath, 'iconTemplate.png')
+}
+
+function updateTrayIcon() {
+  if (tray) {
+    tray.setImage(trayIconPath())
+  }
+}
+
+function updateTray() {
+  const updateItems: MenuItemConstructorOptions[] = [
+    { label: 'An update is available', enabled: false },
+    {
+      label: 'Restart to update',
+      click: () => autoUpdater.quitAndInstall(),
+    },
+    { type: 'separator' },
+  ]
+
+  const menu = Menu.buildFromTemplate([
+    ...(updateAvailable ? updateItems : []),
+    { role: 'quit', label: 'Quit Ollama', accelerator: 'Command+Q' },
+  ])
+
+  if (!tray) {
+    tray = new Tray(trayIconPath())
+  }
+
+  tray.setToolTip(updateAvailable ? 'An update is available' : 'Ollama')
+  tray.setContextMenu(menu)
+  tray.setImage(trayIconPath())
+
+  nativeTheme.off('updated', updateTrayIcon)
+  nativeTheme.on('updated', updateTrayIcon)
 }
 
+let proc: ChildProcess = null
+
 function server() {
   const binary = app.isPackaged
     ? path.join(process.resourcesPath, 'ollama')
     : path.resolve(process.cwd(), '..', 'ollama')
 
-  const proc = spawn(binary, ['serve'])
+  proc = spawn(binary, ['serve'])
 
   proc.stdout.on('data', data => {
     logger.info(data.toString().trim())
@@ -115,24 +149,32 @@ function server() {
     logger.error(data.toString().trim())
   })
 
-  function restart() {
-    logger.info('Restarting the server...')
-    server()
-  }
-
   proc.on('exit', restart)
+}
 
-  app.on('before-quit', () => {
+function restart() {
+  setTimeout(server, 1000)
+}
+
+app.on('before-quit', () => {
+  if (proc) {
     proc.off('exit', restart)
     proc.kill()
-  })
-}
+  }
+})
 
-if (process.platform === 'darwin') {
-  app.dock.hide()
-}
+function init() {
+  if (app.isPackaged) {
+    heartbeat()
+    autoUpdater.checkForUpdates()
+    setInterval(() => {
+      heartbeat()
+      autoUpdater.checkForUpdates()
+    }, 60 * 60 * 1000)
+  }
 
-app.on('ready', () => {
+  updateTray()
+
   if (process.platform === 'darwin') {
     if (app.isPackaged) {
       if (!app.isInApplicationsFolder()) {
@@ -168,10 +210,13 @@ app.on('ready', () => {
     }
   }
 
-  createSystemtray()
   server()
 
   if (store.get('first-time-run') && installed()) {
+    if (process.platform === 'darwin') {
+      app.dock.hide()
+    }
+
     app.setLoginItemSettings({ openAtLogin: app.getLoginItemSettings().openAtLogin })
     return
   }
@@ -179,7 +224,7 @@ app.on('ready', () => {
   // This is the first run or the CLI is no longer installed
   app.setLoginItemSettings({ openAtLogin: true })
   firstRunWindow()
-})
+}
 
 // Quit when all windows are closed, except on macOS. There, it's common
 // for applications and their menu bar to stay active until the user quits
@@ -206,29 +251,11 @@ async function heartbeat() {
   })
 }
 
-if (app.isPackaged) {
-  heartbeat()
-  autoUpdater.checkForUpdates()
-  setInterval(() => {
-    heartbeat()
-    autoUpdater.checkForUpdates()
-  }, 60 * 60 * 1000)
-}
-
 autoUpdater.on('error', e => {
-  logger.error(`update check failed - ${e.message}`)
+  console.error(`update check failed - ${e.message}`)
 })
 
-autoUpdater.on('update-downloaded', (event, releaseNotes, releaseName) => {
-  dialog
-    .showMessageBox({
-      type: 'info',
-      buttons: ['Restart Now', 'Later'],
-      title: 'New update available',
-      message: process.platform === 'win32' ? releaseNotes : releaseName,
-      detail: 'A new version of Ollama is available. Restart to apply the update.',
-    })
-    .then(returnValue => {
-      if (returnValue.response === 0) autoUpdater.quitAndInstall()
-    })
+autoUpdater.on('update-downloaded', () => {
+  updateAvailable = true
+  updateTray()
 })
````
|
@@ -15,12 +15,7 @@ export function installed() {
|
||||
export async function install() {
|
||||
const command = `do shell script "mkdir -p ${path.dirname(
|
||||
symlinkPath
|
||||
)} && ln -F -s ${ollama} ${symlinkPath}" with administrator privileges`
|
||||
)} && ln -F -s \\"${ollama}\\" \\"${symlinkPath}\\"" with administrator privileges`
|
||||
|
||||
try {
|
||||
await exec(`osascript -e '${command}'`)
|
||||
} catch (error) {
|
||||
console.error(`cli: failed to install cli: ${error.message}`)
|
||||
return
|
||||
}
|
||||
await exec(`osascript -e '${command}'`)
|
||||
}
|
||||
|
**cmd/cmd.go** (329 changes)

````diff
@@ -10,7 +10,9 @@ import (
 	"net"
 	"net/http"
 	"os"
+	"os/exec"
 	"path/filepath"
+	"runtime"
 	"strings"
 	"time"
 
@@ -25,7 +27,7 @@ import (
 	"github.com/jmorganca/ollama/server"
 )
 
-func create(cmd *cobra.Command, args []string) error {
+func CreateHandler(cmd *cobra.Command, args []string) error {
 	filename, _ := cmd.Flags().GetString("file")
 	filename, err := filepath.Abs(filename)
 	if err != nil {
@@ -36,15 +38,32 @@ func create(cmd *cobra.Command, args []string) error {
 
 	var spinner *Spinner
 
+	var currentDigest string
+	var bar *progressbar.ProgressBar
+
 	request := api.CreateRequest{Name: args[0], Path: filename}
-	fn := func(resp api.CreateProgress) error {
-		if spinner != nil {
-			spinner.Stop()
+	fn := func(resp api.ProgressResponse) error {
+		if resp.Digest != currentDigest && resp.Digest != "" {
+			if spinner != nil {
+				spinner.Stop()
+			}
+			currentDigest = resp.Digest
+			bar = progressbar.DefaultBytes(
+				int64(resp.Total),
+				fmt.Sprintf("pulling %s...", resp.Digest[7:19]),
+			)
+
+			bar.Set(resp.Completed)
+		} else if resp.Digest == currentDigest && resp.Digest != "" {
+			bar.Set(resp.Completed)
+		} else {
+			currentDigest = ""
+			if spinner != nil {
+				spinner.Stop()
+			}
+			spinner = NewSpinner(resp.Status)
+			go spinner.Spin(100 * time.Millisecond)
 		}
 
-		spinner = NewSpinner(resp.Status)
-		go spinner.Spin(100 * time.Millisecond)
-
 		return nil
 	}
 
@@ -59,7 +78,7 @@ func create(cmd *cobra.Command, args []string) error {
 	return nil
 }
 
-func RunRun(cmd *cobra.Command, args []string) error {
+func RunHandler(cmd *cobra.Command, args []string) error {
 	mp := server.ParseModelPath(args[0])
 	fp, err := mp.GetManifestPath(false)
 	if err != nil {
@@ -69,7 +88,7 @@ func RunRun(cmd *cobra.Command, args []string) error {
 	_, err = os.Stat(fp)
 	switch {
 	case errors.Is(err, os.ErrNotExist):
-		if err := pull(args[0]); err != nil {
+		if err := pull(args[0], false); err != nil {
 			var apiStatusError api.StatusError
 			if !errors.As(err, &apiStatusError) {
 				return err
@@ -86,12 +105,33 @@ func RunRun(cmd *cobra.Command, args []string) error {
 	return RunGenerate(cmd, args)
 }
 
-func push(cmd *cobra.Command, args []string) error {
+func PushHandler(cmd *cobra.Command, args []string) error {
 	client := api.NewClient()
 
-	request := api.PushRequest{Name: args[0]}
+	insecure, err := cmd.Flags().GetBool("insecure")
+	if err != nil {
+		return err
+	}
+
+	var currentDigest string
+	var bar *progressbar.ProgressBar
+
+	request := api.PushRequest{Name: args[0], Insecure: insecure}
 	fn := func(resp api.ProgressResponse) error {
-		fmt.Println(resp.Status)
+		if resp.Digest != currentDigest && resp.Digest != "" {
+			currentDigest = resp.Digest
+			bar = progressbar.DefaultBytes(
+				int64(resp.Total),
+				fmt.Sprintf("pushing %s...", resp.Digest[7:19]),
+			)
+
+			bar.Set(resp.Completed)
+		} else if resp.Digest == currentDigest && resp.Digest != "" {
+			bar.Set(resp.Completed)
+		} else {
+			currentDigest = ""
+			fmt.Println(resp.Status)
+		}
 		return nil
 	}
 
@@ -101,7 +141,7 @@ func push(cmd *cobra.Command, args []string) error {
 	return nil
 }
 
-func list(cmd *cobra.Command, args []string) error {
+func ListHandler(cmd *cobra.Command, args []string) error {
 	client := api.NewClient()
 
 	models, err := client.List(context.Background())
@@ -131,17 +171,44 @@ func list(cmd *cobra.Command, args []string) error {
 	return nil
 }
 
-func RunPull(cmd *cobra.Command, args []string) error {
-	return pull(args[0])
+func DeleteHandler(cmd *cobra.Command, args []string) error {
+	client := api.NewClient()
+
+	req := api.DeleteRequest{Name: args[0]}
+	if err := client.Delete(context.Background(), &req); err != nil {
+		return err
+	}
+	fmt.Printf("deleted '%s'\n", args[0])
+	return nil
 }
 
-func pull(model string) error {
+func CopyHandler(cmd *cobra.Command, args []string) error {
+	client := api.NewClient()
+
+	req := api.CopyRequest{Source: args[0], Destination: args[1]}
+	if err := client.Copy(context.Background(), &req); err != nil {
+		return err
+	}
+	fmt.Printf("copied '%s' to '%s'\n", args[0], args[1])
+	return nil
+}
+
+func PullHandler(cmd *cobra.Command, args []string) error {
+	insecure, err := cmd.Flags().GetBool("insecure")
+	if err != nil {
+		return err
+	}
+
+	return pull(args[0], insecure)
+}
+
+func pull(model string, insecure bool) error {
 	client := api.NewClient()
 
 	var currentDigest string
 	var bar *progressbar.ProgressBar
 
-	request := api.PullRequest{Name: model}
+	request := api.PullRequest{Name: model, Insecure: insecure}
 	fn := func(resp api.ProgressResponse) error {
 		if resp.Digest != currentDigest && resp.Digest != "" {
 			currentDigest = resp.Digest
@@ -179,7 +246,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 	return generateBatch(cmd, args[0])
 }
 
-var generateContextKey struct{}
+type generateContextKey string
 
 func generate(cmd *cobra.Command, model, prompt string) error {
 	if len(strings.TrimSpace(prompt)) > 0 {
@@ -190,26 +257,36 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 
 	var latest api.GenerateResponse
 
-	generateContext, ok := cmd.Context().Value(generateContextKey).([]int)
+	generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
 	if !ok {
 		generateContext = []int{}
 	}
 
 	request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
-	fn := func(resp api.GenerateResponse) error {
+	fn := func(response api.GenerateResponse) error {
 		if !spinner.IsFinished() {
 			spinner.Finish()
 		}
 
-		latest = resp
-
-		fmt.Print(resp.Response)
-
-		cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey, resp.Context))
+		latest = response
 
+		fmt.Print(response.Response)
 		return nil
 	}
 
 	if err := client.Generate(context.Background(), &request, fn); err != nil {
+		if strings.Contains(err.Error(), "failed to load model") {
+			// tell the user to check the server log, if it exists locally
+			home, nestedErr := os.UserHomeDir()
+			if nestedErr != nil {
+				// return the original error
+				return err
+			}
+			logPath := filepath.Join(home, ".ollama", "logs", "server.log")
+			if _, nestedErr := os.Stat(logPath); nestedErr == nil {
+				err = fmt.Errorf("%w\nFor more details, check the error logs at %s", err, logPath)
+			}
+		}
 		return err
 	}
 
@@ -224,11 +301,25 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 		if verbose {
 			latest.Summary()
 		}
+
+		ctx := cmd.Context()
+		ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
+		cmd.SetContext(ctx)
 	}
 
 	return nil
 }
 
+func showLayer(l *server.Layer) {
+	filename, err := server.GetBlobsPath(l.Digest)
+	bts, err := os.ReadFile(filename)
+	if err != nil {
+		fmt.Printf("Couldn't read layer")
+		return
+	}
+	fmt.Printf(string(bts) + "\n")
+}
+
 func generateInteractive(cmd *cobra.Command, model string) error {
 	home, err := os.UserHomeDir()
 	if err != nil {
@@ -249,6 +340,11 @@ func generateInteractive(cmd *cobra.Command, model string) error {
 			readline.PcItem("default"),
 		),
 	),
+	readline.PcItem("/show",
+		readline.PcItem("license"),
+		readline.PcItem("system"),
+		readline.PcItem("template"),
+	),
 	readline.PcItem("/exit"),
 	readline.PcItem("/bye"),
 )
@@ -270,6 +366,9 @@ func generateInteractive(cmd *cobra.Command, model string) error {
 	}
 	defer scanner.Close()
 
+	var multiLineBuffer string
+	var isMultiLine bool
+
 	for {
 		line, err := scanner.Readline()
 		switch {
@@ -288,9 +387,25 @@ func generateInteractive(cmd *cobra.Command, model string) error {
 		line = strings.TrimSpace(line)
 
 		switch {
+		case isMultiLine:
+			if strings.HasSuffix(line, `"""`) {
+				isMultiLine = false
+				multiLineBuffer += strings.TrimSuffix(line, `"""`)
+				line = multiLineBuffer
+				multiLineBuffer = ""
+				scanner.SetPrompt(">>> ")
+			} else {
+				multiLineBuffer += line + " "
+				continue
+			}
+		case strings.HasPrefix(line, `"""`):
+			isMultiLine = true
+			multiLineBuffer = strings.TrimPrefix(line, `"""`) + " "
+			scanner.SetPrompt("... ")
+			continue
 		case strings.HasPrefix(line, "/list"):
 			args := strings.Fields(line)
-			if err := list(cmd, args[1:]); err != nil {
+			if err := ListHandler(cmd, args[1:]); err != nil {
 				return err
 			}
 
@@ -320,9 +435,57 @@ func generateInteractive(cmd *cobra.Command, model string) error {
 				case "emacs", "default":
 					scanner.SetVimMode(false)
 					continue
+				default:
+					usage()
+					continue
+				}
+			} else {
+				usage()
+				continue
+			}
+		case strings.HasPrefix(line, "/show"):
+			args := strings.Fields(line)
+			if len(args) > 1 {
+				mp := server.ParseModelPath(model)
+				manifest, err := server.GetManifest(mp)
+				if err != nil {
+					fmt.Printf("error: couldn't get a manifest for this model")
+					continue
+				}
+				switch args[1] {
+				case "license":
+					for _, l := range manifest.Layers {
+						if l.MediaType == "application/vnd.ollama.image.license" {
+							showLayer(l)
+						}
+					}
+					continue
+				case "system":
+					for _, l := range manifest.Layers {
+						if l.MediaType == "application/vnd.ollama.image.system" {
+							showLayer(l)
+						}
+					}
+					continue
+				case "template":
+					for _, l := range manifest.Layers {
+						if l.MediaType == "application/vnd.ollama.image.template" {
+							showLayer(l)
+						}
+					}
+					continue
 				default:
 					usage()
 					continue
 				}
 			} else {
 				usage()
 				continue
 			}
 		case line == "/help", line == "/?":
 			usage()
@@ -369,6 +532,54 @@ func RunServer(_ *cobra.Command, _ []string) error {
 	return server.Serve(ln)
 }
 
+func startMacApp(client *api.Client) error {
+	exe, err := os.Executable()
+	if err != nil {
+		return err
+	}
+	link, err := os.Readlink(exe)
+	if err != nil {
+		return err
+	}
+	if !strings.Contains(link, "Ollama.app") {
+		return fmt.Errorf("could not find ollama app")
+	}
+	path := strings.Split(link, "Ollama.app")
+	if err := exec.Command("/usr/bin/open", "-a", path[0]+"Ollama.app").Run(); err != nil {
+		return err
+	}
+	// wait for the server to start
+	timeout := time.After(5 * time.Second)
+	tick := time.Tick(500 * time.Millisecond)
+	for {
+		select {
+		case <-timeout:
+			return errors.New("timed out waiting for server to start")
+		case <-tick:
+			if err := client.Heartbeat(context.Background()); err == nil {
+				return nil // server has started
+			}
+		}
+	}
+}
+
+func checkServerHeartbeat(_ *cobra.Command, _ []string) error {
+	client := api.NewClient()
+	if err := client.Heartbeat(context.Background()); err != nil {
+		if !strings.Contains(err.Error(), "connection refused") {
+			return err
+		}
+		if runtime.GOOS == "darwin" {
+			if err := startMacApp(client); err != nil {
+				return fmt.Errorf("could not connect to ollama app, is it running?")
+			}
+		} else {
+			return fmt.Errorf("could not connect to ollama server, run 'ollama serve' to start it")
+		}
+	}
+	return nil
+}
+
 func NewCLI() *cobra.Command {
 	log.SetFlags(log.LstdFlags | log.Lshortfile)
 
@@ -384,19 +595,21 @@ func NewCLI() *cobra.Command {
 	cobra.EnableCommandSorting = false
 
 	createCmd := &cobra.Command{
-		Use:   "create MODEL",
-		Short: "Create a model from a Modelfile",
-		Args:  cobra.MinimumNArgs(1),
-		RunE:  create,
+		Use:     "create MODEL",
+		Short:   "Create a model from a Modelfile",
+		Args:    cobra.MinimumNArgs(1),
+		PreRunE: checkServerHeartbeat,
+		RunE:    CreateHandler,
 	}
 
 	createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")")
 
 	runCmd := &cobra.Command{
-		Use:   "run MODEL [PROMPT]",
-		Short: "Run a model",
-		Args:  cobra.MinimumNArgs(1),
-		RunE:  RunRun,
+		Use:     "run MODEL [PROMPT]",
+		Short:   "Run a model",
+		Args:    cobra.MinimumNArgs(1),
+		PreRunE: checkServerHeartbeat,
+		RunE:    RunHandler,
 	}
 
 	runCmd.Flags().Bool("verbose", false, "Show timings for response")
@@ -409,23 +622,47 @@ func NewCLI() *cobra.Command {
 	}
 
 	pullCmd := &cobra.Command{
-		Use:   "pull MODEL",
-		Short: "Pull a model from a registry",
-		Args:  cobra.MinimumNArgs(1),
-		RunE:  RunPull,
+		Use:     "pull MODEL",
+		Short:   "Pull a model from a registry",
+		Args:    cobra.MinimumNArgs(1),
+		PreRunE: checkServerHeartbeat,
+		RunE:    PullHandler,
 	}
 
+	pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
+
 	pushCmd := &cobra.Command{
-		Use:   "push MODEL",
-		Short: "Push a model to a registry",
-		Args:  cobra.MinimumNArgs(1),
-		RunE:  push,
+		Use:     "push MODEL",
+		Short:   "Push a model to a registry",
+		Args:    cobra.MinimumNArgs(1),
+		PreRunE: checkServerHeartbeat,
+		RunE:    PushHandler,
 	}
 
+	pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
+
 	listCmd := &cobra.Command{
-		Use:   "list",
-		Short: "List models",
-		RunE:  list,
+		Use:     "list",
+		Aliases: []string{"ls"},
+		Short:   "List models",
+		PreRunE: checkServerHeartbeat,
+		RunE:    ListHandler,
 	}
+
+	copyCmd := &cobra.Command{
+		Use:     "cp",
+		Short:   "Copy a model",
+		Args:    cobra.MinimumNArgs(2),
+		PreRunE: checkServerHeartbeat,
+		RunE:    CopyHandler,
+	}
+
+	deleteCmd := &cobra.Command{
+		Use:     "rm",
+		Short:   "Remove a model",
+		Args:    cobra.MinimumNArgs(1),
+		PreRunE: checkServerHeartbeat,
+		RunE:    DeleteHandler,
+	}
 
 	rootCmd.AddCommand(
@@ -435,6 +672,8 @@ func NewCLI() *cobra.Command {
 		pullCmd,
 		pushCmd,
 		listCmd,
+		copyCmd,
+		deleteCmd,
 	)
 
 	return rootCmd
````
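One small change above that is easy to miss: `var generateContextKey struct{}` became the defined type `generateContextKey string`. A defined key type keeps `context.Context` values collision-free even though the key is built from a plain string; a standalone sketch (not part of the diff):

```go
package main

import (
	"context"
	"fmt"
)

// A defined string type: values of this type never compare equal to a bare
// string key, so other packages cannot accidentally collide with it.
type generateContextKey string

func main() {
	ctx := context.WithValue(context.Background(), generateContextKey("context"), []int{1, 2, 3})

	// Retrieval must use the same defined type...
	tokens, ok := ctx.Value(generateContextKey("context")).([]int)
	fmt.Println(tokens, ok) // [1 2 3] true

	// ...a bare string with the same contents does not match.
	_, ok = ctx.Value("context").([]int)
	fmt.Println(ok) // false
}
```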
**docs/development.md**

````diff
@@ -6,6 +6,14 @@ Install required tools:
 brew install go
 ```
 
+Enable CGO:
+
+```
+export CGO_ENABLED=1
+```
+
+You will also need a C/C++ compiler such as GCC for MacOS and Linux or Mingw-w64 GCC for Windows.
+
 Then build ollama:
 
 ```
````
@@ -4,6 +4,22 @@
|
||||
|
||||
A model file is the blueprint to create and share models with Ollama.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Format](#format)
|
||||
- [Examples](#examples)
|
||||
- [Instructions](#instructions)
|
||||
- [FROM (Required)](#from-required)
|
||||
- [Build from llama2](#build-from-llama2)
|
||||
- [Build from a bin file](#build-from-a-bin-file)
|
||||
- [PARAMETER](#parameter)
|
||||
- [Valid Parameters and Values](#valid-parameters-and-values)
|
||||
- [TEMPLATE](#template)
|
||||
- [Template Variables](#template-variables)
|
||||
- [SYSTEM](#system)
|
||||
- [LICENSE](#license)
|
||||
- [Notes](#notes)
|
||||
|
||||
## Format
|
||||
|
||||
The format of the Modelfile:
|
||||
@@ -13,13 +29,13 @@ The format of the Modelfile:
|
||||
INSTRUCTION arguments
|
||||
```
|
||||
|
||||
| Instruction | Description |
|
||||
| ----------------- | ----------------------------------------------------- |
|
||||
| `FROM` (required) | Defines the base model to use |
|
||||
| `PARAMETER` | Sets the parameters for how Ollama will run the model |
|
||||
| `SYSTEM` | Specifies the system prompt that will set the context |
|
||||
| `TEMPLATE` | The full prompt template to be sent to the model |
|
||||
| `LICENSE` | Specifies the legal license |
|
||||
| Instruction | Description |
|
||||
| ----------------------------------- | ------------------------------------------------------------- |
|
||||
| [`FROM`](#from-required) (required) | Defines the base model to use. |
|
||||
| [`PARAMETER`](#parameter) | Sets the parameters for how Ollama will run the model. |
|
||||
| [`TEMPLATE`](#template) | The full prompt template to be sent to the model. |
|
||||
| [`SYSTEM`](#system) | Specifies the system prompt that will be set in the template. |
|
||||
| [`LICENSE`](#license) | Specifies the legal license. |
|
||||
|
||||
## Examples
|
||||
|
||||
@@ -28,22 +44,26 @@ An example of a model file creating a mario blueprint:
|
||||
```
|
||||
FROM llama2
|
||||
# sets the temperature to 1 [higher is more creative, lower is more coherent]
|
||||
# sets the context size to 4096
|
||||
PARAMETER temperature 1
|
||||
# sets the context window size to 4096, this controls how many tokens the LLM can use as context to generate the next token
|
||||
PARAMETER num_ctx 4096
|
||||
|
||||
# Overriding the system prompt
|
||||
# sets a custom system prompt to specify the behavior of the chat assistant
|
||||
SYSTEM You are Mario from super mario bros, acting as an assistant.
|
||||
```
|
||||
|
||||
To use this:
|
||||
|
||||
1. Save it as a file (eg. `Modelfile``)
|
||||
1. Save it as a file (eg. `Modelfile`)
|
||||
2. `ollama create NAME -f <location of the file eg. ./Modelfile>'`
|
||||
3. `ollama run NAME`
|
||||
4. Start using the model!
|
||||
|
||||
## FROM (Required)
|
||||
More examples are available in the [examples directory](../examples).
|
||||
|
||||
## Instructions
|
||||
|
||||
### FROM (Required)
|
||||
|
||||
The FROM instruction defines the base model to use when creating a model.
|
||||
|
||||
@@ -51,7 +71,7 @@ The FROM instruction defines the base model to use when creating a model.
|
||||
FROM <model name>:<tag>
|
||||
```
|
||||
|
||||
### Build from llama2
|
||||
#### Build from llama2
|
||||
|
||||
```
|
||||
FROM llama2
|
||||
@@ -60,13 +80,15 @@ FROM llama2
|
||||
A list of available base models:
|
||||
<https://github.com/jmorganca/ollama#model-library>
|
||||
|
||||
### Build from a bin file
|
||||
#### Build from a bin file
|
||||
|
||||
```
|
||||
FROM ./ollama-model.bin
|
||||
```
|
||||
|
||||
## PARAMETER (Optional)
|
||||
This bin file location should be specified as an absolute path or relative to the Modelfile location.
|
||||
|
||||
### PARAMETER
|
||||
|
||||
The `PARAMETER` instruction defines a parameter that can be set when the model is run.
|
||||
|
||||
@@ -76,28 +98,67 @@ PARAMETER <parameter> <parametervalue>

### Valid Parameters and Values

-| Parameter | Description | Value Type | Example Usage |
-| --------- | ----------- | ---------- | ------------- |
-| num_ctx | Sets the size of the prompt context size length model. (Default: 2048) | int | num_ctx 4096 |
-| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
-| top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |
-| top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 |
-| num_gpu | The number of GPUs to use. On macOS it defaults to 1 to enable metal support, 0 to disable. | int | num_gpu 1 |
-| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = ctx-size) | int | repeat_last_n 64 |
-| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
-| tfs_z | Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1) | float | tfs_z 1 |
-| mirostat | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | mirostat 0 |
-| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
-| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
-| num_thread | Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance. It is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). | int | num_thread 8 |
+| Parameter | Description | Value Type | Example Usage |
+| --------- | ----------- | ---------- | ------------- |
+| mirostat | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | mirostat 0 |
+| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
+| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
+| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
+| num_gpu | The number of GPUs to use. On macOS it defaults to 1 to enable metal support, 0 to disable. | int | num_gpu 1 |
+| num_thread | Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance. It is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). | int | num_thread 8 |
+| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
+| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
+| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
+| stop | Sets the stop tokens to use. | string | stop "AI assistant:" |
+| tfs_z | Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1) | float | tfs_z 1 |
+| top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |
+| top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 |
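
Several of these parameters can be combined in a single Modelfile. For instance (the values below are illustrative, not recommendations):

```
FROM llama2
# keep answers focused and less random
PARAMETER temperature 0.3
PARAMETER top_k 20
PARAMETER top_p 0.5
PARAMETER repeat_penalty 1.2
# stop generating once the model starts a new speaker turn
PARAMETER stop "User:"
```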

-## Prompt
+### TEMPLATE

-When building on top of the base models supplied by Ollama, the prompt template comes predefined. To override the supplied system prompt, simply add `SYSTEM insert system prompt`.
+`TEMPLATE` of the full prompt template to be passed into the model. It may include (optionally) a system prompt and a user's prompt. This is used to create a full custom prompt, and syntax may be model specific.

-### Prompt Template
+#### Template Variables

-`TEMPLATE` is the full prompt template to be passed into the model. It may include (optionally) a system prompt, a user prompt, and an assistant prompt. This is used to create a full custom prompt, and syntax may be model specific.
+| Variable | Description |
+| --------------- | ----------- |
+| `{{ .System }}` | The system prompt used to specify custom behavior; it must also be set in the Modelfile as an instruction. |
+| `{{ .Prompt }}` | The incoming prompt; this is not specified in the model file and will be set based on input. |
+| `{{ .First }}` | A boolean value used to render specific template information for the first generation of a session. |

```
TEMPLATE """
{{- if .First }}
### System:
{{ .System }}
{{- end }}

### User:
{{ .Prompt }}

### Response:
"""

SYSTEM """<system message>"""
```
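
With this template, the first exchange of a session would render roughly as follows (the user text is illustrative; on later turns `.First` is false, so the `### System:` block is omitted):

```
### System:
<system message>

### User:
Who are you?

### Response:
```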

### SYSTEM

The `SYSTEM` instruction specifies the system prompt to be used in the template, if applicable.

```
SYSTEM """<system message>"""
```
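
As a concrete instance (the message itself is only illustrative):

```
SYSTEM """You are a concise assistant. Answer in at most three sentences."""
```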

### LICENSE

The `LICENSE` instruction allows you to specify the legal license under which the model used with this Modelfile is shared or distributed.

```
LICENSE """
<license text>
"""
```

## Notes

examples/devops-engineer/Modelfile (new file, 8 lines)
@@ -0,0 +1,8 @@
# Modelfile for creating a devops engineer assistant
# Run `ollama create devops-engineer -f ./Modelfile` and then `ollama run devops-engineer` and enter a topic

FROM llama2:13b
PARAMETER temperature 1
SYSTEM """
You are a senior devops engineer, acting as an assistant. You offer help with cloud technologies like Terraform, AWS, Kubernetes, and Python. You answer with code examples when possible.
"""

@@ -20,14 +20,8 @@ What the model file looks like:

```
FROM llama2
PARAMETER temperature 1
-PROMPT """
-{{- if not .Context }}
-<<SYS>>
-You are Mario from super mario bros, acting as an assistant.
-<</SYS>>
-
-{{- end }}
-[INST] {{ .Prompt }} [/INST]
+SYSTEM """
+You are Mario from Super Mario Bros, acting as an assistant.
"""
```

@@ -2,6 +2,6 @@
# Run `ollama create tweetwriter -f ./Modelfile` and then `ollama run tweetwriter` and enter a topic

FROM nous-hermes
-PROMPT """
+SYSTEM """
You are a content marketer who needs to come up with a short but succinct tweet. Make sure to include the appropriate hashtags and links. Sometimes, when appropriate, describe a meme that can be included as well. All answers should be in the form of a tweet, which has a max size of 280 characters. Every instruction will be the topic to create a tweet about.
"""

@@ -1 +0,0 @@
-llama/ggml-metal.metal

go.mod (2 changes)
@@ -14,11 +14,11 @@ require (
require github.com/rivo/uniseg v0.2.0 // indirect

require (
+	dario.cat/mergo v1.0.0
	github.com/bytedance/sonic v1.9.1 // indirect
	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
	github.com/chzyer/readline v1.5.1
	github.com/gabriel-vasile/mimetype v1.4.2 // indirect
	github.com/gin-contrib/cors v1.4.0
	github.com/gin-contrib/sse v0.1.0 // indirect
	github.com/go-playground/locales v0.14.1 // indirect
	github.com/go-playground/universal-translator v0.18.1 // indirect

go.sum (47 changes)
@@ -1,5 +1,3 @@
dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
@@ -13,6 +11,7 @@ github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObk
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -20,17 +19,25 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
github.com/gin-contrib/cors v1.4.0/go.mod h1:bs9pNM0x/UsmHPBWT2xZz9ROh8xYjYkiURUfmBoMlcs=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
@@ -45,8 +52,18 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
@@ -61,12 +78,17 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
@@ -76,6 +98,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
@@ -85,31 +108,49 @@ github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gt
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M=
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM=
golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c=
golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

llama/ggml-alloc.c (new file, 567 lines)
@@ -0,0 +1,567 @@
/**
 * llama.cpp - git 8183159cf3def112f6d1fe94815fce70e1bffa12
 *
 * MIT License
 *
 * Copyright (c) 2023 Georgi Gerganov
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "ggml-alloc.h"
#include "ggml.h"
#include <assert.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define UNUSED(x) (void)(x)
#define MAX(a, b) ((a) > (b) ? (a) : (b))

//#define GGML_ALLOCATOR_DEBUG

//#define AT_PRINTF printf
#define AT_PRINTF(...) ((void)0)

struct hash_node {
    struct ggml_tensor * t;
    int n_children;
    int n_views;
};

static size_t hash(void * p) {
    return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
}

static struct hash_node * hash_get(struct hash_node hash_table[], struct ggml_tensor * t) {
    size_t h = hash(t);

    // linear probing
    size_t i = h;
    while (hash_table[i].t != NULL) {
        if (hash_table[i].t == t) {
            return &hash_table[i];
        }
        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
        if (i == h) {
            // hash table is full
            GGML_ASSERT(false);
        }
    }

    hash_table[i].t = t;
    return &hash_table[i];
}

// TODO: GGML_PAD ?
static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
    assert(alignment && !(alignment & (alignment - 1))); // power of 2
    size_t align = (alignment - (((uintptr_t)buffer + offset) % alignment)) % alignment;
    return offset + align;
}
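
// illustrative worked example (editorial, not from the original source):
// buffer = (void *)0x1003, offset = 0, alignment = 32
// -> (0x1003 + 0) % 32 == 3, so align == (32 - 3) % 32 == 29 and the returned
//    offset 29 places the data at 0x1020, the next 32-byte boundary.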

struct free_block {
    void * addr;
    size_t size;
};

#define MAX_FREE_BLOCKS 128

struct ggml_allocr {
    void * data;
    size_t size;
    size_t alignment;
    int n_free_blocks;
    struct free_block free_blocks[MAX_FREE_BLOCKS];
    struct hash_node hash_table[GGML_GRAPH_HASHTABLE_SIZE];
    size_t max_size;
    bool measure;

#ifdef GGML_ALLOCATOR_DEBUG
    struct ggml_tensor * allocated_tensors[1024];
#endif
};

#ifdef GGML_ALLOCATOR_DEBUG
static void add_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor) {
    for (int i = 0; i < 1024; i++) {
        if (alloc->allocated_tensors[i] == NULL) {
            alloc->allocated_tensors[i] = tensor;
            return;
        }
    }
    GGML_ASSERT(!"out of allocated_tensors");
}
static void remove_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor) {
    for (int i = 0; i < 1024; i++) {
        if (alloc->allocated_tensors[i] == tensor ||
            (alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {
            alloc->allocated_tensors[i] = NULL;
            return;
        }
    }
    printf("tried to free tensor %s not found\n", tensor->name);
    GGML_ASSERT(!"tensor not found");
}
#endif

static size_t ggml_allocator_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
    return ggml_nbytes(tensor);

    UNUSED(alloc);
}

void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
    size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
    size = aligned_offset(NULL, size, alloc->alignment);

    AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);

    size_t max_avail = 0;

    // find the best fitting free block
    int best_fit_block = -1;
    size_t best_fit_size = SIZE_MAX;
    for (int i = 0; i < alloc->n_free_blocks; i++) {
        struct free_block * block = &alloc->free_blocks[i];
        max_avail = MAX(max_avail, block->size);
        if (block->size >= size && block->size <= best_fit_size) {
            best_fit_block = i;
            best_fit_size = block->size;
        }
    }
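    // illustrative example: with free blocks of 96, 64 and 200 bytes and a 60-byte
    // request, the 64-byte block wins because it is the smallest block that still fits.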

    AT_PRINTF("block %d\n", best_fit_block);

    if (best_fit_block == -1) {
        fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n",
                __func__, size, max_avail);
        GGML_ASSERT(!"not enough space in the buffer");
        return;
    }
    struct free_block * block = &alloc->free_blocks[best_fit_block];
    void * addr = block->addr;
    block->addr = (char*)block->addr + size;
    block->size -= size;
    if (block->size == 0) {
        // remove block if empty
        alloc->n_free_blocks--;
        for (int j = best_fit_block; j < alloc->n_free_blocks; j++) {
            alloc->free_blocks[j] = alloc->free_blocks[j+1];
        }
    }

    tensor->data = addr;

#ifdef GGML_ALLOCATOR_DEBUG
    add_allocated_tensor(alloc, tensor);
    size_t cur_max = (char*)addr - (char*)alloc->data + size;
    if (cur_max > alloc->max_size) {
        printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
        for (int i = 0; i < 1024; i++) {
            if (alloc->allocated_tensors[i]) {
                printf("%s (%.2f MB) ", alloc->allocated_tensors[i]->name, ggml_nbytes(alloc->allocated_tensors[i]) / 1024.0 / 1024.0);
            }
        }
        printf("\n");
    }
#endif

    alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->data + size);
}

// this is a very naive implementation, but for our case the number of free blocks should be very small
static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
    void * ptr = tensor->data;

    if (ptr < alloc->data || (char*)ptr >= (char*)alloc->data + alloc->max_size) {
        // the tensor was not allocated in this buffer
        // this can happen because the graph allocator will try to free weights and other tensors from different buffers
        // the easiest way to deal with this is just to ignore it
        return;
    }

    size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
    size = aligned_offset(NULL, size, alloc->alignment);
    AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);

#ifdef GGML_ALLOCATOR_DEBUG
    remove_allocated_tensor(alloc, tensor);
#endif

    // see if we can merge with an existing block
    for (int i = 0; i < alloc->n_free_blocks; i++) {
        struct free_block * block = &alloc->free_blocks[i];
        // check if ptr is at the end of the block
        if ((char*)block->addr + block->size == ptr) {
            block->size += size;
            // check if we can merge with the next block
            if (i < alloc->n_free_blocks - 1 && (char*)block->addr + block->size == alloc->free_blocks[i+1].addr) {
                block->size += alloc->free_blocks[i+1].size;
                alloc->n_free_blocks--;
                for (int j = i+1; j < alloc->n_free_blocks; j++) {
                    alloc->free_blocks[j] = alloc->free_blocks[j+1];
                }
            }
            return;
        }
        // check if ptr is at the beginning of the block
        if ((char*)ptr + size == block->addr) {
            block->addr = ptr;
            block->size += size;
            // check if we can merge with the previous block
            if (i > 0 && (char*)alloc->free_blocks[i-1].addr + alloc->free_blocks[i-1].size == block->addr) {
                alloc->free_blocks[i-1].size += block->size;
                alloc->n_free_blocks--;
                for (int j = i; j < alloc->n_free_blocks; j++) {
                    alloc->free_blocks[j] = alloc->free_blocks[j+1];
                }
            }
            return;
        }
    }
    // otherwise, add a new block
    GGML_ASSERT(alloc->n_free_blocks < MAX_FREE_BLOCKS && "out of free blocks");
    // insert the new block in the correct position to keep the array sorted by address (to make merging blocks faster)
    int insert_pos = 0;
    while (insert_pos < alloc->n_free_blocks && alloc->free_blocks[insert_pos].addr < ptr) {
        insert_pos++;
    }
    // shift all blocks from insert_pos onward to make room for the new block
    for (int i = alloc->n_free_blocks; i > insert_pos; i--) {
        alloc->free_blocks[i] = alloc->free_blocks[i-1];
    }
    // insert the new block
    alloc->free_blocks[insert_pos].addr = ptr;
    alloc->free_blocks[insert_pos].size = size;
    alloc->n_free_blocks++;
}

void ggml_allocr_reset(struct ggml_allocr * alloc) {
    alloc->n_free_blocks = 1;
    size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment);
    alloc->free_blocks[0].addr = (char *)alloc->data + align_offset;
    alloc->free_blocks[0].size = alloc->size - align_offset;
}

struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment) {
    struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);

    *alloc = (struct ggml_allocr){
        /*.data          = */ data,
        /*.size          = */ size,
        /*.alignment     = */ alignment,
        /*.n_free_blocks = */ 0,
        /*.free_blocks   = */ {{0}},
        /*.hash_table    = */ {{0}},
        /*.max_size      = */ 0,
        /*.measure       = */ false,
#ifdef GGML_ALLOCATOR_DEBUG
        /*.allocated_tensors = */ {0},
#endif
    };

    ggml_allocr_reset(alloc);

    return alloc;
}

// address and size of the buffer when measuring
// it needs to be large enough to fit all the tensors, but it cannot overlap with other existing buffers
static void * const MEASURE_BASE_ADDR = (void *) 0x1000;
static const size_t MEASURE_MAX_SIZE = 1ULL<<40; // 1 TB
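// measuring runs the normal allocation algorithm against this fake address range;
// the resulting max_size tells the caller how large the real buffer must be
// (see ggml_allocr_new_measure below and ggml_allocr_alloc_graph at the end of the file).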

struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
    struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);

    *alloc = (struct ggml_allocr){
        /*.data          = */ MEASURE_BASE_ADDR,
        /*.size          = */ MEASURE_MAX_SIZE,
        /*.alignment     = */ alignment,
        /*.n_free_blocks = */ 0,
        /*.free_blocks   = */ {{0}},
        /*.hash_table    = */ {{0}},
        /*.max_size      = */ 0,
        /*.measure       = */ true,
#ifdef GGML_ALLOCATOR_DEBUG
        /*.allocated_tensors = */ {0},
#endif
    };

    ggml_allocr_reset(alloc);

    return alloc;
}

void ggml_allocr_free(struct ggml_allocr * alloc) {
    free(alloc);
}

bool ggml_allocr_is_measure(struct ggml_allocr * alloc) {
    return alloc->measure;
}

//////////// compute graph allocator

static bool ggml_is_view(struct ggml_tensor * t) {
    return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
           t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
}

static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
    if (a->type != b->type) {
        return false;
    }
    for (int i = 0; i < GGML_MAX_DIMS; i++) {
        if (a->ne[i] != b->ne[i]) {
            return false;
        }
        if (a->nb[i] != b->nb[i]) {
            return false;
        }
    }
    return true;
}

static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
    switch (t->op) {
        case GGML_OP_PERMUTE:
        case GGML_OP_RESHAPE:
        case GGML_OP_TRANSPOSE:
        case GGML_OP_VIEW:
            return t->src[0];
        case GGML_OP_CPY:
            return t->src[1];
        default:
            return NULL;
    }
}

static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
    struct ggml_tensor * parent = t;
    do {
        parent = get_view_parent(parent);
    } while (ggml_is_view(parent));
    return parent;
}

static bool ggml_op_can_inplace(enum ggml_op op) {
    switch (op) {
        case GGML_OP_SCALE:
        case GGML_OP_DIAG_MASK_ZERO:
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_ADD:
        case GGML_OP_ADD1:
        case GGML_OP_ACC:
        case GGML_OP_SUB:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_SQR:
        case GGML_OP_SQRT:
        case GGML_OP_LOG:
        case GGML_OP_UNARY:
        case GGML_OP_ROPE:
        case GGML_OP_RMS_NORM:
        case GGML_OP_SET:
        case GGML_OP_SOFT_MAX:
        case GGML_OP_CONT:
            return true;

        default:
            return false;
    }
}
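
// e.g. GGML_OP_ADD qualifies: allocate_node (below) may let it write its result
// directly over one of its inputs, provided that input has the same layout and
// no remaining consumers.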

static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node) {
    struct hash_node * ht = alloc->hash_table;
    if (node->data == NULL) {
        if (ggml_is_view(node)) {
            size_t offset;
            switch(node->op) {
                case GGML_OP_VIEW:
                    memcpy(&offset, node->op_params, sizeof(size_t));
                    node->data = (char *) node->src[0]->data + offset;
                    break;
                case GGML_OP_PERMUTE:
                case GGML_OP_RESHAPE:
                case GGML_OP_TRANSPOSE:
                    node->data = node->src[0]->data;
                    break;
                case GGML_OP_CPY:
                    node->data = node->src[1]->data;
                    break;
                default:
                    GGML_ASSERT(!"unknown view op");
                    break;
            }
        } else {
            // see if we can reuse a parent's buffer (inplace)
            if (ggml_op_can_inplace(node->op)) {
                for (int i = 0; i < GGML_MAX_SRC; i++) {
                    struct ggml_tensor * parent = node->src[i];
                    if (parent == NULL) {
                        break;
                    }
                    struct hash_node * p_hn = hash_get(ht, parent);
                    if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
                        if (ggml_is_view(parent)) {
                            struct ggml_tensor * view_src = get_view_source(parent);
                            struct hash_node * view_src_hn = hash_get(ht, view_src);
                            if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
                                // TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
                                // the parent's data that it will need later (same layout requirement). the problem is that then
                                // we cannot free the tensor because the original address of the allocation is lost.
                                // adding a view_src pointer to the tensor would solve this and simplify the code dealing with views
                                // for now, we only reuse the parent's data if the offset is zero (view_src->data == parent->data)
                                AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
                                node->data = parent->data;
                                return;
                            }
                        }
                        else {
                            AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
                            node->data = parent->data;
                        }
                        return;
                    }
                }
            }
            ggml_allocr_alloc(alloc, node);
        }
    }
}

static size_t ggml_allocator_alloc_graph_tensors_n(
    struct ggml_allocr * alloc,
    struct ggml_cgraph ** graphs, int n_graphs,
    struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) {

    // reset hash table
    struct hash_node * ht = alloc->hash_table;
    memset(ht, 0, sizeof(struct hash_node) * GGML_GRAPH_HASHTABLE_SIZE);

    // count number of children and views
    for (int g = 0; g < n_graphs; g++) {
        struct ggml_cgraph * gf = graphs[g];
        for (int i = 0; i < gf->n_nodes; i++) {
            struct ggml_tensor * node = gf->nodes[i];

            if (ggml_is_view(node)) {
                struct ggml_tensor * view_src = get_view_source(node);
                hash_get(ht, view_src)->n_views += 1;
            }

            for (int j = 0; j < GGML_MAX_SRC; j++) {
                struct ggml_tensor * parent = node->src[j];
                if (parent == NULL) {
                    break;
                }
                hash_get(ht, parent)->n_children += 1;
            }
        }
    }

    // allocate tensors
    for (int g = 0; g < n_graphs; g++) {
        struct ggml_cgraph * gf = graphs[g];
        AT_PRINTF("####### graph %d/%d\n", g, n_graphs);
        // graph inputs are allocated first to ensure that they are not overwritten by each other
        if (inputs != NULL && inputs[g] != NULL) {
            for (int i = 0; inputs[g][i] != NULL; i++) {
                struct ggml_tensor * input = inputs[g][i];
                AT_PRINTF("input: %s\n", input->name);
                allocate_node(alloc, input);
            }
        }
        for (int i = 0; i < gf->n_nodes; i++) {
            struct ggml_tensor * node = gf->nodes[i];

            // allocate parents (leafs)
            for (int j = 0; j < GGML_MAX_SRC; j++) {
                struct ggml_tensor * parent = node->src[j];
                if (parent == NULL) {
                    break;
                }
                allocate_node(alloc, parent);
            }

            // allocate node
            allocate_node(alloc, node);

            AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name);
            for (int j = 0; j < GGML_MAX_SRC; j++) {
                struct ggml_tensor * parent = node->src[j];
                if (parent == NULL) {
                    break;
                }
                AT_PRINTF("%s", parent->name);
                if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
                    AT_PRINTF(", ");
                }
            }
            AT_PRINTF("\n");

            // update parents
            for (int j = 0; j < GGML_MAX_SRC; j++) {
                struct ggml_tensor * parent = node->src[j];
                if (parent == NULL) {
                    break;
                }
                struct hash_node * p_hn = hash_get(ht, parent);
                p_hn->n_children -= 1;

                //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);

                if (p_hn->n_children == 0 && p_hn->n_views == 0) {
                    if (ggml_is_view(parent)) {
                        struct ggml_tensor * view_src = get_view_source(parent);
                        struct hash_node * view_src_hn = hash_get(ht, view_src);
                        view_src_hn->n_views -= 1;
                        AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src->n_children, view_src->n_views);
                        if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
                            ggml_allocator_free_tensor(alloc, view_src);
                        }
                    }
                    else {
                        if (parent->data != node->data) {
                            ggml_allocator_free_tensor(alloc, parent);
                        }
                    }
                }
            }
            AT_PRINTF("\n");
        }
        // free graph outputs here that wouldn't be freed otherwise because they have no children
        if (outputs != NULL && outputs[g] != NULL) {
            for (int i = 0; outputs[g][i] != NULL; i++) {
                struct ggml_tensor * output = outputs[g][i];
                AT_PRINTF("output: %s\n", output->name);
                ggml_allocator_free_tensor(alloc, output);
            }
        }
    }

    return alloc->max_size;
}

size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
    return ggml_allocator_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
}

llama/ggml-alloc.h (new file, 48 lines)
@@ -0,0 +1,48 @@
/**
 * llama.cpp - git 8183159cf3def112f6d1fe94815fce70e1bffa12
 *
 * MIT License
 *
 * Copyright (c) 2023 Georgi Gerganov
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#pragma once

#include "ggml.h"

#ifdef __cplusplus
extern "C" {
#endif

GGML_API struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment);
GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);

GGML_API void   ggml_allocr_free(struct ggml_allocr * alloc);
GGML_API bool   ggml_allocr_is_measure(struct ggml_allocr * alloc);
GGML_API void   ggml_allocr_reset(struct ggml_allocr * alloc);
GGML_API void   ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor);
GGML_API size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph);

#ifdef __cplusplus
}
#endif
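
Pieced together from these declarations, the intended usage appears to be a two-pass, measure-then-allocate pattern. The sketch below is an assumption, not code from the diff; it presumes a hypothetical `build_graph()` helper that recreates the `ggml_cgraph` (so tensor data pointers start out NULL in each pass) and picks 32-byte alignment arbitrarily:

```
#include <stdlib.h>
#include "ggml.h"
#include "ggml-alloc.h"

// pass 1: measure against a fake address range to learn the peak size
struct ggml_allocr * measure = ggml_allocr_new_measure(32);
struct ggml_cgraph * gf = build_graph();            // hypothetical helper
size_t needed = ggml_allocr_alloc_graph(measure, gf);
ggml_allocr_free(measure);

// pass 2: rebuild the graph and allocate into a real buffer of that size
void * buf = malloc(needed);
struct ggml_allocr * alloc = ggml_allocr_new(buf, needed, 32);
gf = build_graph();                                 // fresh tensors, data == NULL
ggml_allocr_alloc_graph(alloc, gf);                 // tensor->data now points into buf
```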

llama/ggml-cuda.cu (2680 changes)
@@ -1,5 +1,5 @@
/**
- * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
+ * llama.cpp - git 8183159cf3def112f6d1fe94815fce70e1bffa12
 *
 * MIT License
 *
@@ -53,6 +53,7 @@ void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
void ggml_cuda_set_main_device(int main_device);
+void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
void ggml_cuda_set_scratch_size(size_t scratch_size);
void ggml_cuda_free_scratch(void);
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);

@@ -1,7 +1,7 @@
//go:build darwin

/**
- * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
+ * llama.cpp - git 8183159cf3def112f6d1fe94815fce70e1bffa12
 *
 * MIT License
 *
@@ -89,6 +89,13 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
// get data from the device into host memory
void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);

+// try to find operations that can be run concurrently in the graph
+// you should run it again if the topology of your graph changes
+void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+
+// whether the graph has been optimized for concurrent dispatch
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);

// same as ggml_graph_compute but uses Metal
// creates gf->n_threads command buffers in parallel
void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
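
Taken together, these declarations suggest an analyze-then-compute call order. A sketch under that assumption (`ctx` and `gf` are presumed to already exist; this is not code from the diff):

```
// run once after building the graph; re-run only if its topology changes
ggml_metal_graph_find_concurrency(ctx, gf);

// when the analysis produced a non-empty concurrency list,
// ggml_metal_graph_compute dispatches independent nodes concurrently
if (ggml_metal_if_optimized(ctx)) {
    // concurrent dispatch will be used
}
ggml_metal_graph_compute(ctx, gf);
```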

@@ -1,7 +1,7 @@
//go:build darwin

/**
- * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
+ * llama.cpp - git 8183159cf3def112f6d1fe94815fce70e1bffa12
 *
 * MIT License
 *
@@ -64,12 +64,16 @@ struct ggml_metal_context {
    int n_buffers;
    struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];

+   int concur_list[GGML_MAX_NODES];
+   int concur_list_len;
+
    // custom kernels
#define GGML_METAL_DECL_KERNEL(name) \
    id<MTLFunction> function_##name; \
    id<MTLComputePipelineState> pipeline_##name

    GGML_METAL_DECL_KERNEL(add);
    GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
    GGML_METAL_DECL_KERNEL(mul);
    GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
    GGML_METAL_DECL_KERNEL(scale);

@@ -125,6 +129,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
    ctx->device = MTLCreateSystemDefaultDevice();
    ctx->queue = [ctx->device newCommandQueue];
    ctx->n_buffers = 0;
+   ctx->concur_list_len = 0;

    // determine if we can use MPS
    if (MPSSupportsMTLDevice(ctx->device)) {

@@ -185,6 +190,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
        fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);

        GGML_METAL_ADD_KERNEL(add);
+       GGML_METAL_ADD_KERNEL(add_row);
        GGML_METAL_ADD_KERNEL(mul);
        GGML_METAL_ADD_KERNEL(mul_row);
        GGML_METAL_ADD_KERNEL(scale);

@@ -243,6 +249,13 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
    ctx->n_cb = n_cb;
}

+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+    if (ctx->concur_list_len) {
+        return true;
+    }
+    return false;
+}
+
// finds the Metal buffer that contains the tensor data on the GPU device
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
// Metal buffer based on the host memory pointer

@@ -381,11 +394,98 @@ void ggml_metal_get_tensor(
    memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
}

+void ggml_metal_graph_find_concurrency(
+        struct ggml_metal_context * ctx,
+        struct ggml_cgraph * gf) {
+    int search_depth = gf->n_nodes; // we only find concurrency in this range to avoid wasting too much time
+    int nodes_unused[GGML_MAX_NODES];
+
+    for (int i = 0; i < GGML_MAX_NODES; i++) { ctx->concur_list[i] = 0; }
+    for (int i = 0; i < gf->n_nodes;    i++) { nodes_unused[i] = 1; }
+    ctx->concur_list_len = 0;
+
+    int n_left    = gf->n_nodes;
+    int n_start   = 0; // all nodes before n_start in the nodes_unused array have been sorted and stored back to ctx->concur_list
+    int level_pos = 0; // in ctx->concur_list, the last layer (level) ends at level_pos
+
+    while (n_left > 0) {
+        // number of nodes at a layer (that can be issued concurrently)
+        int concurrency = 0;
+        for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
+            if (nodes_unused[i]) {
+                // if the requirements for gf->nodes[i] are satisfied
+                int exe_flag = 1;
+                // scan all srcs
+                for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
+                    struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
+                    if (src_cur) {
+                        // if it is a leaf node, it is satisfied.
+                        if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) { continue; }
+
+                        // otherwise this src should be the output from previous nodes.
+                        int is_found = 0;
+                        // scan 2*search_depth back because we inserted barriers.
+                        for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+                            if (gf->nodes[ctx->concur_list[j]] == src_cur) { is_found = 1; break; }
+                        }
+                        if (is_found == 0) { exe_flag = 0; break; }
+                    }
+                }
+                if (exe_flag) {
+                    // check if nodes[i]'s data will be overwritten by a node before nodes[i].
+                    // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
+                    int64_t data_start = (int64_t) gf->nodes[i]->data;
+                    int64_t length     = (int64_t) ggml_nbytes(gf->nodes[i]);
+                    for (int j = n_start; j < i; j++) {
+                        if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
+                                            && gf->nodes[j]->op != GGML_OP_VIEW \
+                                            && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
+                                            && gf->nodes[j]->op != GGML_OP_PERMUTE) {
+                            if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
+                                ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
+                                continue;
+                            } else {
+                                exe_flag = 0;
+                            }
+                        }
+                    }
+                }
+                if (exe_flag) {
+                    ctx->concur_list[level_pos + concurrency] = i;
+                    nodes_unused[i] = 0;
+                    concurrency++;
+                    ctx->concur_list_len++;
+                }
+            }
+        }
+        n_left -= concurrency;
+        // add a barrier between layers
+        ctx->concur_list[level_pos + concurrency] = -1;
+        ctx->concur_list_len++;
+        // skip all already-sorted nodes
+        while (!nodes_unused[n_start]) { n_start++; }
+        level_pos += concurrency + 1;
+    }
+
+    if (ctx->concur_list_len > GGML_MAX_NODES) {
+        fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+    }
+}
+
void ggml_metal_graph_compute(
        struct ggml_metal_context * ctx,
        struct ggml_cgraph * gf) {
    metal_printf("%s: evaluating graph\n", __func__);

+    // if there is a ctx->concur_list, dispatch concurrently;
+    // else fall back to serial dispatch
+    MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
+
+    const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+
+    const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
+    edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+
    // create multiple command buffers and enqueue them
    // then, we encode the graph into the command buffers in parallel

@@ -404,7 +504,7 @@ void ggml_metal_graph_compute(
    dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);

    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-       const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+       const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;

@@ -415,10 +515,21 @@ void ggml_metal_graph_compute(
        dispatch_async(queue, ^{
            size_t offs_src0 = 0;

            id<MTLComputeCommandEncoder> encoder = nil;

            const int node_start = (cb_idx + 0) * n_nodes_per_cb;
-           const int node_end   = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+           const int node_end   = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+           for (int ind = node_start; ind < node_end; ++ind) {
+               const int i = has_concur ? ctx->concur_list[ind] : ind;
+
+               if (i == -1) {
+                   if (encoder == nil) {
+                       encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                       continue;
+                   }
+                   [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+                   continue;
+               }
-           for (int i = node_start; i < node_end; ++i) {
                metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));

                struct ggml_tensor * src0 = gf->nodes[i]->src[0];

@@ -489,13 +600,19 @@ void ggml_metal_graph_compute(
                case GGML_OP_ADD:
                    {
                        if (encoder == nil) {
-                           encoder = [command_buffer computeCommandEncoder];
+                           encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                        }

-                       [encoder setComputePipelineState:ctx->pipeline_add];
+                       if (ggml_nelements(src1) == ne10) {
+                           // src1 is a row
+                           [encoder setComputePipelineState:ctx->pipeline_add_row];
+                       } else {
+                           [encoder setComputePipelineState:ctx->pipeline_add];
+                       }
                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                        [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                       [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];

                        const int64_t n = ggml_nelements(dst);

@@ -504,7 +621,7 @@ void ggml_metal_graph_compute(
                case GGML_OP_MUL:
                    {
                        if (encoder == nil) {
-                           encoder = [command_buffer computeCommandEncoder];
+                           encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                        }

                        if (ggml_nelements(src1) == ne10) {

@@ -525,7 +642,7 @@ void ggml_metal_graph_compute(
                case GGML_OP_SCALE:
                    {
                        if (encoder == nil) {
-                           encoder = [command_buffer computeCommandEncoder];
+                           encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                        }

                        const float scale = *(const float *) src1->data;

@@ -539,52 +656,60 @@ void ggml_metal_graph_compute(

                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                    } break;
-               case GGML_OP_SILU:
-                   {
-                       if (encoder == nil) {
-                           encoder = [command_buffer computeCommandEncoder];
-                       }
-
-                       [encoder setComputePipelineState:ctx->pipeline_silu];
-                       [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                       [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
-                       const int64_t n = ggml_nelements(dst);
-
-                       [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                   } break;
-               case GGML_OP_RELU:
-                   {
-                       if (encoder == nil) {
-                           encoder = [command_buffer computeCommandEncoder];
-                       }
-
-                       [encoder setComputePipelineState:ctx->pipeline_relu];
-                       [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                       [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
-                       const int64_t n = ggml_nelements(dst);
-
-                       [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                   } break;
-               case GGML_OP_GELU:
-                   {
-                       if (encoder == nil) {
-                           encoder = [command_buffer computeCommandEncoder];
-                       }
-
-                       [encoder setComputePipelineState:ctx->pipeline_gelu];
-                       [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                       [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
-                       const int64_t n = ggml_nelements(dst);
-
-                       [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                   } break;
+               case GGML_OP_UNARY:
+                   switch (ggml_get_unary_op(gf->nodes[i])) {
+                       case GGML_UNARY_OP_SILU:
+                           {
+                               if (encoder == nil) {
+                                   encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                               }
+
+                               [encoder setComputePipelineState:ctx->pipeline_silu];
+                               [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                               [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                               const int64_t n = ggml_nelements(dst);
+
+                               [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                           } break;
+                       case GGML_UNARY_OP_RELU:
+                           {
+                               if (encoder == nil) {
+                                   encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                               }
+
+                               [encoder setComputePipelineState:ctx->pipeline_relu];
+                               [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                               [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                               const int64_t n = ggml_nelements(dst);
+
+                               [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                           } break;
+                       case GGML_UNARY_OP_GELU:
+                           {
+                               if (encoder == nil) {
+                                   encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                               }
+
+                               [encoder setComputePipelineState:ctx->pipeline_gelu];
+                               [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                               [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                               const int64_t n = ggml_nelements(dst);
+
+                               [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                           } break;
+                       default:
+                           {
+                               fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                               GGML_ASSERT(false);
+                           }
+                   } break;
                case GGML_OP_SOFT_MAX:
                    {
                        if (encoder == nil) {
-                           encoder = [command_buffer computeCommandEncoder];
+                           encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                        }

                        const int nth = 32;

@@ -602,10 +727,10 @@ void ggml_metal_graph_compute(
                case GGML_OP_DIAG_MASK_INF:
                    {
                        if (encoder == nil) {
-                           encoder = [command_buffer computeCommandEncoder];
+                           encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                        }

-                       const int n_past = ((int32_t *)(src1->data))[0];
+                       const int n_past = ((int32_t *)(dst->op_params))[0];

                        [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];

@@ -621,7 +746,8 @@ void ggml_metal_graph_compute(
                        // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224

                        GGML_ASSERT(ne00 == ne10);
-                       GGML_ASSERT(ne02 == ne12);
+                       // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
                        GGML_ASSERT(ne03 == ne13);

                        if (ggml_is_contiguous(src0) &&
                            ggml_is_contiguous(src1) &&

@@ -649,11 +775,11 @@ void ggml_metal_graph_compute(
                            initWithDevice:ctx->device transposeLeft:false transposeRight:true
                            resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];

-                       // we need to do ne02 multiplications
+                       // we need to do ne12 multiplications
|
||||
// TODO: is there a way to do this in parallel - currently very slow ..
|
||||
// TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
|
||||
for (int64_t i02 = 0; i02 < ne02; ++i02) {
|
||||
size_t offs_src0_cur = offs_src0 + i02*nb02;
|
||||
for (int64_t i02 = 0; i02 < ne12; ++i02) {
|
||||
size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
|
||||
size_t offs_src1_cur = offs_src1 + i02*nb12;
|
||||
size_t offs_dst_cur = offs_dst + i02*nb2;
|
||||
|
||||
@@ -665,7 +791,7 @@ void ggml_metal_graph_compute(
|
||||
}
|
||||
} else {
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoder];
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||
}
|
||||
|
||||
int nth0 = 32;
|
||||
@@ -675,8 +801,6 @@ void ggml_metal_graph_compute(
|
||||
switch (src0t) {
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
GGML_ASSERT(ne02 == ne12);
|
||||
|
||||
nth0 = 64;
|
||||
nth1 = 1;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
||||
@@ -704,8 +828,8 @@ void ggml_metal_graph_compute(
|
||||
GGML_ASSERT(ne02 == 1);
|
||||
GGML_ASSERT(ne12 == 1);
|
||||
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
@@ -713,8 +837,8 @@ void ggml_metal_graph_compute(
|
||||
GGML_ASSERT(ne02 == 1);
|
||||
GGML_ASSERT(ne12 == 1);
|
||||
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
@@ -756,31 +880,35 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
|
||||
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
|
||||
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
||||
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
|
||||
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
||||
|
||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
||||
src0t == GGML_TYPE_Q4_K) {
|
||||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q3_K) {
|
||||
#ifdef GGML_QKK_64
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
#else
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
#endif
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q5_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q6_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q2_K ||
|
||||
src0t == GGML_TYPE_Q3_K) {
|
||||
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
} else {
|
||||
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
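The dispatch geometry above follows from how the new SIMD-group kernels split the work: each threadgroup runs N_SIMDGROUP = 2 SIMD groups and each SIMD group produces N_DST = 4 output rows, so one threadgroup covers 8 rows and the host rounds ne01 up accordingly. A minimal C sketch of that host-side arithmetic (the helper name is hypothetical, not part of ggml):

#include <stdint.h>

// Hypothetical helper: number of threadgroups along the row dimension for a
// kernel whose threadgroup covers nsg SIMD groups x nr rows each.
static int64_t n_row_threadgroups(int64_t ne01, int nr, int nsg) {
    const int rows_per_tg = nr * nsg;              // 4 * 2 = 8 for q4_0/q4_1/q2_K/q4_K
    return (ne01 + rows_per_tg - 1) / rows_per_tg; // matches (ne01 + 7) / 8 above
}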
@@ -790,7 +918,7 @@ void ggml_metal_graph_compute(
case GGML_OP_GET_ROWS:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}

switch (src0->type) {
@@ -819,10 +947,11 @@ void ggml_metal_graph_compute(
case GGML_OP_RMS_NORM:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}

const float eps = 1e-6f;
float eps;
memcpy(&eps, dst->op_params, sizeof(float));

const int nth = 512;

@@ -841,7 +970,7 @@ void ggml_metal_graph_compute(
case GGML_OP_NORM:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}

const float eps = 1e-5f;
@@ -863,14 +992,15 @@ void ggml_metal_graph_compute(
case GGML_OP_ALIBI:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}

GGML_ASSERT((src0t == GGML_TYPE_F32));

const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past);
const int n_head = ((int32_t *) src1->data)[1];
const float max_bias = ((float *) src1->data)[2];
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

if (__builtin_popcount(n_head) != 1) {
GGML_ASSERT(false && "only power-of-two n_head implemented");
@@ -905,18 +1035,17 @@ void ggml_metal_graph_compute(
case GGML_OP_ROPE:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}

const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];

const int n_past = ((int32_t *)(src1->data))[0];
const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];

float freq_base;
float freq_scale;
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));

[encoder setComputePipelineState:ctx->pipeline_rope];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -945,10 +1074,12 @@ void ggml_metal_graph_compute(

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_DUP:
case GGML_OP_CPY:
case GGML_OP_CONT:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}

const int nth = 32;
@@ -995,8 +1126,10 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
default:
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
GGML_ASSERT(false);
{
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
GGML_ASSERT(false);
}
}
}

@@ -1,7 +1,7 @@
//go:build darwin

/**
* llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
* llama.cpp - git 8183159cf3def112f6d1fe94815fce70e1bffa12
*
* MIT License
*
@@ -95,6 +95,17 @@ kernel void kernel_add(
dst[tpig] = src0[tpig] + src1[tpig];
}

// assumption: src1 is a row
// broadcast src1 into src0
kernel void kernel_add_row(
device const float * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] + src1[tpig % ne00];
}

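kernel_add_row broadcasts a single row of length ne00 across src0 by indexing src1 modulo ne00. The same semantics as a scalar C reference (a sketch, not ggml code):

#include <stdint.h>

// Reference semantics of kernel_add_row: src1 is one row of length ne00,
// repeated across every row of src0.
static void add_row_ref(const float *src0, const float *src1, float *dst,
                        int64_t n, int64_t ne00) {
    for (int64_t i = 0; i < n; ++i) {
        dst[i] = src0[i] + src1[i % ne00];
    }
}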
kernel void kernel_mul(
device const float * src0,
device const float * src1,
@@ -379,7 +390,7 @@ kernel void kernel_rms_norm(

threadgroup_barrier(mem_flags::mem_threadgroup);
// broadcast, simd group number is ntg / 32
for (int i = ntg / 32 / 2; i > 0; i /= 2) {
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
if (tpitg < i) {
sum[tpitg] += sum[tpitg + i];
}
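The loop above is a standard threadgroup tree reduction: the stride starts at half the number of SIMD groups (ntg / 32 / 2) and halves each pass, so after log2 steps sum[0] holds the total. A sequential C stand-in for the same pattern, assuming n is a power of two (illustrative only; on the GPU the inner loop runs as parallel threads):

// Tree reduction: after each pass the first `stride` slots hold pairwise sums;
// sum[0] ends up with the total.
static float tree_reduce(float *sum, int n) {
    for (int stride = n / 2; stride > 0; stride /= 2) {
        for (int t = 0; t < stride; ++t) { // stand-in for the parallel threads
            sum[t] += sum[t + stride];
        }
    }
    return sum[0];
}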
@@ -404,87 +415,90 @@ kernel void kernel_rms_norm(
}
}

// function for calculating the inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i])
float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
// function for calculating the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
// il indicates where the q4 quants begin (0 or QK4_0/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;
float4 acc = 0.f;
device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
for (int i = 0; i < 16; i+=2) {
acc[0] += yl[i] * (qs[i / 2] & 0x000F);
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
float2 acc = 0.f;
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+ yl[i + 9] * (qs[i / 2] & 0xF000);
}
return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
return d * (sumy * -8.f + acc[0] + acc[1]);
}

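The trick in block_q_n_dot_y is to skip the nibble shifts entirely: the caller pre-divides the y values by 1, 256, 16, and 4096 so that multiplying by the raw masked bits lands each product back on the right scale, and the uniform -8 offset of q4_0 is folded into sumy * -8. A scalar C reference of what is being computed, assuming the standard block_q4_0 layout (fp16 scale d followed by 32 packed 4-bit quants, dequantized value = (q - 8) * d); the helper name is illustrative:

#include <stdint.h>

// Scalar reference for the q4_0 dot product the Metal kernel optimizes:
// 32 quants per block; low nibbles hold the first 16 values, high nibbles the rest.
static float q4_0_dot_ref(const uint8_t qs[16], float d, const float y[32]) {
    float acc = 0.f;
    for (int i = 0; i < 16; ++i) {
        acc += y[i]      * (float)((qs[i] & 0x0F) - 8); // low nibble
        acc += y[i + 16] * (float)((qs[i] >> 4)   - 8); // high nibble
    }
    return d * acc; // equals d * (sum_q_y - 8 * sumy), matching the kernel's form
}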
// function for calculating the inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i])
float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
// function for calculating the inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
// il indicates where the q4 quants begin (0 or QK4_0/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;
float m = qb_curr->m;
float4 acc = 0.f;
device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
for (int i = 0; i < 16; i+=2) {
acc[0] += yl[i] * (qs[i / 2] & 0x000F);
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
float2 acc = 0.f;
for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+ yl[i + 9] * (qs[i / 2] & 0xF000);
}
return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
return d * (acc[0] + acc[1]) + sumy * m;
}

// putting them in the kernel causes a significant performance penalty
#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
template<typename block_q_type>
//Note: This is a template, but strictly speaking it only applies to
// quantizations where the block size is 32. It also does not
// guard against the number of rows not being divisible by
// N_DST, so this is another explicit assumption of the implementation.
template<typename block_q_type, int nr, int nsg, int nw>
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
uint2 tgpig, uint tiisg, uint sgitg) {
const int nb = ne00/QK4_0;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
const int first_row = (r0 * nsg + sgitg) * nr;
device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
device const float * y = (device const float *) src1 + r1*ne10;
float4 y_curr[8]; // src1 vector cache
float sumf[N_DST]={0.f}, all_sum;
thread float * yl=(thread float *)y_curr;
float yl[16]; // src1 vector cache
float sumf[nr]={0.f};

// each thread in a SIMD group deals with 1 block.
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
const int ix = tiisg/2;
const int il = 8*(tiisg%2);

device const float * yb = y + ix * QK4_0 + il;

// each thread in a SIMD group deals with half a block.
for (int ib = ix; ib < nb; ib += nw/2) {
float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) {
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
for (int i = 0; i < 8; i += 2) {
sumy += yb[i] + yb[i+1];
yl[i+0] = yb[i+ 0];
yl[i+1] = yb[i+ 1]/256.f;
sumy += yb[i+16] + yb[i+17];
yl[i+8] = yb[i+16]/16.f;
yl[i+9] = yb[i+17]/4096.f;
}

for (int row = 0; row < N_DST; row++) {
sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
for (int row = 0; row < nr; row++) {
sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
}

yb += QK4_0 * 16;
}

// from now loads two rows every time and 16 blocks per row
int ir = tiisg / (N_SIMDWIDTH / 2);
int ib = tiisg % (N_SIMDWIDTH / 2);
for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) {
y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
}

for (int row = 0; row < N_DST; row+=2) {
if (nb_start + ib < nb) {
sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
}
}
}

for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
for (int row = 0; row < nr; ++row) {
const float tot = simd_sum(sumf[row]);
if (tiisg == 0 && first_row + row < ne01) {
dst[r1*ne0 + first_row + row] = tot;
}
}
}
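In the rewritten loop each of the 32 lanes of a SIMD group owns half a block: ix = tiisg/2 picks the starting block, il = 8*(tiisg%2) picks which half of it, and the lane then strides forward by nw/2 = 16 blocks. A small C model of that lane-to-work mapping (purely illustrative):

// Which (block, half) a given SIMD lane visits in the mul_vec_q_n_f32 inner
// loop; nw = 32 lanes per SIMD group.
static void lane_work(int tiisg, int nb, int nw) {
    const int ix = tiisg / 2;       // starting block for this lane
    const int il = 8 * (tiisg % 2); // 0 -> first half of block, 8 -> second half
    for (int ib = ix; ib < nb; ib += nw / 2) {
        // lane covers the half of block ib selected by il
        (void)il;
    }
}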
@@ -500,7 +514,7 @@ kernel void kernel_mul_mat_q4_0_f32(
uint2 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
}

kernel void kernel_mul_mat_q4_1_f32(
@@ -514,7 +528,7 @@ kernel void kernel_mul_mat_q4_1_f32(
uint2 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
}

kernel void kernel_mul_mat_f16_f32(
@@ -523,11 +537,13 @@ kernel void kernel_mul_mat_f16_f32(
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
@@ -543,7 +559,7 @@ kernel void kernel_mul_mat_f16_f32(
const int64_t r1 = tgpig.y;
const int64_t im = tgpig.z;

device const half * x = (device const half *) (src0 + r0*nb01 + im*nb02);
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);

sum[tpitg.x] = 0.0f;
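The im/(ne12/ne02) indexing implements broadcast over the batch dimension for grouped-query-attention style shapes: when src1 has ne12 batches but src0 only ne02, every ne12/ne02 consecutive src1 batches map onto one src0 batch. A C sketch of the mapping, assuming ne12 is a multiple of ne02 (which the surrounding asserts imply):

#include <stdint.h>

// Batch broadcast used by the f16 mat-mul and the MPS path above:
// dst/src1 batch i maps onto src0 batch i / (ne12 / ne02).
static int64_t src0_batch_index(int64_t i, int64_t ne02, int64_t ne12) {
    return i / (ne12 / ne02); // assumes ne12 % ne02 == 0
}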
@@ -566,6 +582,7 @@ kernel void kernel_mul_mat_f16_f32(
}
}


kernel void kernel_alibi_f32(
device const float * src0,
device float * dst,
@@ -1237,111 +1254,137 @@ kernel void kernel_mul_mat_q2_K_f32(
constant int64_t & ne00,
constant int64_t & ne10,
constant int64_t & ne0,
threadgroup float * sum [[threadgroup(0)]],
constant int64_t & ne01[[buffer(4)]],
uint2 tgpig[[threadgroup_position_in_grid]],
uint2 tpitg[[thread_position_in_threadgroup]],
uint2 tptg[[threads_per_threadgroup]]) {
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {

const int nb = ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;

const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb;
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
device const float * y = (device const float *) src1 + r1*ne10;
float yl[32];
float sumf[N_DST]={0.f}, all_sum;

device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
device const float * yy = (device const float *) src1 + r1*ne10;

const int nth = tptg.x*tptg.y;
const int ith = tptg.y*tpitg.x + tpitg.y;

float sumf = 0;
const int step = sizeof(block_q2_K) * nb;

#if QK_K == 256
const int tid = tpitg.y; // 0...16
const int il = tid/4; // 0...3
const int ir = tid%4; // 0...3
const int ip = il/2; // 0 or 1
const int shift1 = 4*(il%2);// 0 or 4
const int shift2 = shift1+2;// 2 or 6
const int n = 8;
const int is = 4*il + (n*ir)/16;
const int ix = tiisg/8; // 0...3
const int it = tiisg%8; // 0...7
const int im = it/4; // 0 or 1
const int ir = it%4; // 0...3
const int is = (8*ir)/16;// 0 or 1

const int y_offset = 64*il + n*ir;
const int q_offset = 32*ip + n*ir;
device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;

for (int i = tpitg.x; i < nb; i += tptg.x) {
for (int ib = ix; ib < nb; ib += 4) {

device const uint8_t * q = x[i].qs + q_offset;
device const uint8_t * scales = x[i].scales + is;

uint8_t d1 = scales[0] & 0xF;
uint8_t d2 = scales[2] & 0xF;
uint8_t m1 = scales[0] >> 4;
uint8_t m2 = scales[2] >> 4;

device const float * y = yy + i*QK_K + y_offset;

float2 s = {0.f, 0.f};
float smin = 0;
for (int l = 0; l < n; ++l) {
s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
s[1] += y[l+32] * ((q[l] >> shift2) & 3);
smin += y[l+ 0] * m1 + y[l+32] * m2;
float4 sumy = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; ++i) {
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
}

const float dall = (float)x[i].d;
const float dmin = (float)x[i].dmin;
device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
device const half * dh = &x[ib].d;

sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
for (int row = 0; row < N_DST; row++) {

float4 acc1 = {0.f, 0.f, 0.f, 0.f};
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; i += 2) {
acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
}
float dall = dh[0];
float dmin = dh[1] * 1.f/16.f;
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
(acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
(acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));

qs += step/2;
sc += step;
dh += step/2;
}

y4 += 4 * QK_K;
}
#else
const int il = 4 * tpitg.x;
const int ix = tiisg/2; // 0...15
const int it = tiisg%2; // 0...1

uint32_t aux[2];
thread const uint8_t * d = (thread const uint8_t *)aux;
thread const uint8_t * m = (thread const uint8_t *)aux + 4;
device const float * y4 = y + ix * QK_K + 8 * it;

for (int i = tpitg.y; i < nb; i += tptg.y) {
for (int ib = ix; ib < nb; ib += 16) {

device const uint8_t * q = x[i].qs + il;
device const float * y = yy + i*QK_K + il;

const float dall = (float)x[i].d;
const float dmin = (float)x[i].dmin;

device const uint32_t * a = (device const uint32_t *)x[i].scales;
aux[0] = a[0] & 0x0f0f0f0f;
aux[1] = (a[0] >> 4) & 0x0f0f0f0f;

for (int l = 0; l < 4; ++l) {
sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
+ y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
+ y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
+ y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
float4 sumy = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; ++i) {
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
}

device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
device const half * dh = &x[ib].d;

for (int row = 0; row < N_DST; row++) {

float4 acc1 = {0.f, 0.f, 0.f, 0.f};
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; i += 2) {
acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
}

float dall = dh[0];
float dmin = dh[1];
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
(acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
(acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));

qs += step/2;
sc += step;
dh += step/2;
}

y4 += 16 * QK_K;
}
#endif

sum[ith] = sumf;

//
// Accumulate the sum from all threads in the threadgroup
//
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith%4 == 0) {
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith%16 == 0) {
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith == 0) {
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
dst[r1*ne0 + r0] = sum[0];
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*ne0 + first_row + row] = all_sum;
}
}
}

#if QK_K == 256
kernel void kernel_mul_mat_q3_K_f32(
device const void * src0,
device const float * src1,
@@ -1350,40 +1393,41 @@ kernel void kernel_mul_mat_q3_K_f32(
constant int64_t & ne10,
constant int64_t & ne0,
constant int64_t & ne1,
threadgroup float * sum [[threadgroup(0)]],
uint2 tgpig[[threadgroup_position_in_grid]],
uint2 tpitg[[thread_position_in_threadgroup]],
uint2 tptg[[threads_per_threadgroup]]) {
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {

const int nb = ne00/QK_K;

const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;

device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;

device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
device const float * yy = (device const float *) src1 + r1*ne10;

const int nth = tptg.x*tptg.y;
const int ith = tptg.y*tpitg.x + tpitg.y;

#if QK_K == 256

const uint8_t m3 = 3;
const int8_t m4 = 4;
float yl[16];

const uint16_t kmask1 = 0x0303;
const uint16_t kmask2 = 0x0f0f;

const int tid = tpitg.y; // expecting 16
const int tid = tiisg/2;
const int ix = tiisg%2;
const int ip = tid/8; // 0 or 1
const int il = tid/2 - 4*ip; // 0...3
const int ir = tid%2;
const int n = 8;
const int l0 = n*ir;

const uint8_t m = 1 << (4*ip + il);
const uint16_t m1 = 1 << (4*ip + il);
const uint16_t m2 = m1 << 8;

const int shift = 2*il;
const uint16_t qm1 = 0x0003 << shift;
const uint16_t qm2 = 0x0300 << shift;
const int32_t v1 = 4 << shift;
const int32_t v2 = 1024 << shift;

const uint16_t s_shift1 = 4*ip;
const uint16_t s_shift2 = s_shift1 + 2*(il/2);
@@ -1392,93 +1436,132 @@ kernel void kernel_mul_mat_q3_K_f32(
const int q_offset = 32*ip + l0;
const int y_offset = 128*ip + 32*il + l0;

//float sumf = 0;
float sumf1 = 0, sumf2 = 0;
for (int i = tpitg.x; i < nb; i += tptg.x) {
const int step = sizeof(block_q3_K) * nb / 2;

const float d_all = (float)(x[i].d);
device const float * y1 = yy + ix*QK_K + y_offset;

device const uint8_t * q = x[i].qs + q_offset;
device const uint8_t * h = x[i].hmask + l0;
device const float * y = yy + i * QK_K + y_offset;
float sumf1[2] = {0.f}, sumf2[2] = {0.f};
for (int i = ix; i < nb; i += 2) {

device const uint16_t * a = (device const uint16_t *)x[i].scales;
const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));

float s = 0;
for (int l = 0; l < n; ++l) {
s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
for (int l = 0; l < 8; ++l) {
yl[l+0] = y1[l+ 0];
yl[l+8] = y1[l+16];
}
float d = d_all * s;
sumf1 += d * scales[0];
sumf2 += d;
//sumf += d_all * s * (scales[0] - 32);

s = 0;
for (int l = 0; l < n; ++l) {
s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
device const uint16_t * a = (device const uint16_t *)(x[i].scales);
device const half * dh = &x[i].d;

for (int row = 0; row < 2; ++row) {

const float d_all = (float)dh[0];
const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));

float s1 = 0, s2 = 0;
for (int l = 0; l < n; l += 2) {
const uint16_t qs = q[l/2];
s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
}
float d = d_all * (s1 + 1.f/256.f * s2);
sumf1[row] += d * scales[0];
sumf2[row] += d;

s1 = s2 = 0;
for (int l = 0; l < n; l += 2) {
const uint16_t qs = q[l/2+8];
s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
}
d = d_all * (s1 + 1.f/256.f * s2);
sumf1[row] += d * scales[1];
sumf2[row] += d;

q += step;
h += step;
a += step;
dh += step;

}
d = d_all * s;
sumf1 += d * scales[1];
sumf2 += d;
//sumf += d_all * s * (scales[1] - 32);

y1 += 2 * QK_K;

}

//sum[ith] = sumf;
sum[ith] = sumf1 - 32.f*sumf2;
for (int row = 0; row < 2; ++row) {
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
const float tot = simd_sum(sumf);
if (tiisg == 0) {
dst[r1*ne0 + first_row + row] = tot;
}
}
}
#else
const int il = 4 * tpitg.x; // 0, 4, 8, 12
kernel void kernel_mul_mat_q3_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne10,
constant int64_t & ne0,
constant int64_t & ne1,
uint2 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {

const int nb = ne00/QK_K;

const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;

const int row = 2 * r0 + sgitg;

device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
device const float * yy = (device const float *) src1 + r1*ne10;
const int ix = tiisg/4;
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
const int im = il/8; // 0, 0, 1, 1
const int in = il%8; // 0, 4, 0, 4

float sumf = 0;
float2 sum = {0.f, 0.f};

for (int i = tpitg.y; i < nb; i += tptg.y) {
for (int i = ix; i < nb; i += 8) {

const float d_all = (float)(x[i].d);

device const uint8_t * q = x[i].qs + il;
device const uint8_t * h = x[i].hmask + in;
device const float * y = yy + i * QK_K + il;
device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
device const uint16_t * s = (device const uint16_t *)(x[i].scales);
device const float * y = yy + i * QK_K + il;

const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;

for (int l = 0; l < 4; ++l) {
const uint8_t hm = h[l] >> im;
sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
+ y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
+ y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
+ y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
for (int l = 0; l < 4; l += 2) {
const uint16_t hm = h[l/2] >> im;
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
+ y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
+ y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
+ y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
+ y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
}

}
const float sumf = sum[0] + sum[1] * 1.f/256.f;

sum[ith] = sumf;

#endif

//
// Accumulate the sum from all threads in the threadgroup
//
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith%4 == 0) {
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith%16 == 0) {
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith == 0) {
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
dst[r1*ne0 + r0] = sum[0];
const float tot = simd_sum(sumf);
if (tiisg == 0) {
dst[r1*ne0 + row] = tot;
}

}
#endif

#if QK_K == 256
kernel void kernel_mul_mat_q4_K_f32(
@@ -1776,7 +1859,6 @@ kernel void kernel_mul_mat_q5_K_f32(

for (int i = ix; i < nb; i += 8) {

float4 sumy = {0.f, 0.f, 0.f, 0.f};
for (int l = 0; l < 4; ++l) {
yl[l+0] = y[l+ 0];
yl[l+4] = y[l+16];

@@ -1,7 +1,7 @@
//go:build mpi

/**
* llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
* llama.cpp - git 8183159cf3def112f6d1fe94815fce70e1bffa12
*
* MIT License
*

@@ -1,7 +1,7 @@
//go:build mpi

/**
* llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
* llama.cpp - git 8183159cf3def112f6d1fe94815fce70e1bffa12
*
* MIT License
*

@@ -1,7 +1,7 @@
//go:build opencl

/**
* llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
* llama.cpp - git 8183159cf3def112f6d1fe94815fce70e1bffa12
*
* MIT License
*

@@ -1,7 +1,7 @@
//go:build opencl

/**
* llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
* llama.cpp - git 8183159cf3def112f6d1fe94815fce70e1bffa12
*
* MIT License
*

llama/ggml.c — 1776 changes
llama/ggml.h — 121 changes
@@ -1,5 +1,5 @@
/**
* llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
* llama.cpp - git 8183159cf3def112f6d1fe94815fce70e1bffa12
*
* MIT License
*
@@ -225,6 +225,7 @@
#define GGML_MAX_CONTEXTS 64
#define GGML_MAX_SRC 6
#define GGML_MAX_NAME 48
#define GGML_MAX_OP_PARAMS 32
#define GGML_DEFAULT_N_THREADS 4


@@ -233,6 +234,7 @@

#define GGML_UNUSED(x) (void)(x)

#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

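GGML_PAD rounds x up to the next multiple of n for power-of-two n, by adding n-1 and masking off the low bits. A quick compile-time check of that identity in C:

// Same definition as the header; duplicated here only so the check is self-contained.
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

_Static_assert(GGML_PAD(13, 8) == 16, "13 rounds up to 16");
_Static_assert(GGML_PAD(16, 8) == 16, "exact multiples are unchanged");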
#define GGML_ASSERT(x) \
do { \
@@ -355,16 +357,6 @@ extern "C" {
GGML_OP_ARGMAX,
GGML_OP_REPEAT,
GGML_OP_REPEAT_BACK,
GGML_OP_ABS,
GGML_OP_SGN,
GGML_OP_NEG,
GGML_OP_STEP,
GGML_OP_TANH,
GGML_OP_ELU,
GGML_OP_RELU,
GGML_OP_GELU,
GGML_OP_GELU_QUICK,
GGML_OP_SILU,
GGML_OP_SILU_BACK,
GGML_OP_NORM, // normalize
GGML_OP_RMS_NORM,
@@ -403,6 +395,8 @@ extern "C" {
GGML_OP_WIN_PART,
GGML_OP_WIN_UNPART,

GGML_OP_UNARY,

GGML_OP_MAP_UNARY,
GGML_OP_MAP_BINARY,

@@ -416,6 +410,24 @@ extern "C" {
GGML_OP_COUNT,
};

enum ggml_unary_op {
GGML_UNARY_OP_ABS,
GGML_UNARY_OP_SGN,
GGML_UNARY_OP_NEG,
GGML_UNARY_OP_STEP,
GGML_UNARY_OP_TANH,
GGML_UNARY_OP_ELU,
GGML_UNARY_OP_RELU,
GGML_UNARY_OP_GELU,
GGML_UNARY_OP_GELU_QUICK,
GGML_UNARY_OP_SILU,
};

enum ggml_object_type {
GGML_OBJECT_TENSOR,
GGML_OBJECT_GRAPH,
GGML_OBJECT_WORK_BUFFER
};

// ggml object
struct ggml_object {
@@ -424,7 +436,9 @@ extern "C" {

struct ggml_object * next;

char padding[8];
enum ggml_object_type type;

char padding[4];
};

static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
@@ -444,6 +458,9 @@ extern "C" {
// compute data
enum ggml_op op;

// op params - allocated as int32_t for alignment
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];

bool is_param;

struct ggml_tensor * grad;
@@ -460,7 +477,7 @@ extern "C" {

void * extra; // extra things e.g. for ggml-cuda.cu

char padding[8];
char padding[4];
};

static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
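op_params is a small int32-aligned scratch area on the tensor itself; non-integer parameters are stored by copying their bytes in and out, which is exactly what the Metal path above does for eps, max_bias, freq_base and freq_scale. A C sketch of the pattern (a standalone illustration, not one of ggml's helper functions):

#include <stdint.h>
#include <string.h>

// Store a float in 32-bit slot `slot` of an op_params array, then read it back.
// Mirrors: memcpy(&eps, dst->op_params, sizeof(float)).
static float op_params_roundtrip(int32_t *op_params, int slot, float v) {
    memcpy(op_params + slot, &v, sizeof(float)); // type-safe bit copy in
    float out;
    memcpy(&out, op_params + slot, sizeof(float)); // and back out
    return out;
}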
@@ -481,6 +498,11 @@ extern "C" {
void * abort_callback_data;
};

// next prime after GGML_MAX_NODES
// #define GGML_GRAPH_HASHTABLE_SIZE 4099
// next prime after GGML_MAX_NODES * 2 (nodes + leafs)
#define GGML_GRAPH_HASHTABLE_SIZE 8273

// computation graph
struct ggml_cgraph {
int n_nodes;
@@ -490,12 +512,16 @@ extern "C" {
struct ggml_tensor * grads[GGML_MAX_NODES];
struct ggml_tensor * leafs[GGML_MAX_NODES];

void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE];

// performance
int perf_runs;
int64_t perf_cycles;
int64_t perf_time_us;
};

static const size_t GGML_GRAPH_SIZE = sizeof(struct ggml_cgraph);

// scratch buffer
struct ggml_scratch {
size_t offs;
@@ -557,6 +583,7 @@ extern "C" {

GGML_API const char * ggml_type_name(enum ggml_type type);
GGML_API const char * ggml_op_name (enum ggml_op op);
GGML_API const char * ggml_op_symbol(enum ggml_op op);

GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);

@@ -580,6 +607,7 @@ extern "C" {
GGML_API size_t ggml_used_mem(const struct ggml_context * ctx);

GGML_API size_t ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
GGML_API bool ggml_get_no_alloc(struct ggml_context * ctx);
GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);

GGML_API void * ggml_get_mem_buffer (const struct ggml_context * ctx);
@@ -639,9 +667,11 @@ extern "C" {
GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);

GGML_API const char * ggml_get_name(const struct ggml_tensor * tensor);
GGML_API struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name);
GGML_API struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...);
GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);

GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor);
GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name);
GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...);

//
// operations on tensors with backpropagation
@@ -651,6 +681,11 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);

// in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_dup_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_add(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -875,14 +910,17 @@ extern "C" {

GGML_API struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * a,
float eps);

GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * a,
float eps);

// a - x
// b - dy
// TODO: update with configurable eps
GGML_API struct ggml_tensor * ggml_rms_norm_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -974,11 +1012,22 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);

// a -> b, in-place, return view(b)
GGML_API struct ggml_tensor * ggml_cpy_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);

// make contiguous
GGML_API struct ggml_tensor * ggml_cont(
struct ggml_context * ctx,
struct ggml_tensor * a);

// make contiguous, in-place
GGML_API struct ggml_tensor * ggml_cont_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);

// return view(a), b specifies the new shape
// TODO: when we start computing gradient, make a copy instead of view
GGML_API struct ggml_tensor * ggml_reshape(
@@ -1147,16 +1196,27 @@ extern "C" {
int mode,
int n_ctx);

// custom RoPE, in-place, returns view(a)
// custom RoPE
GGML_API struct ggml_tensor * ggml_rope_custom(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode,
int n_ctx,
float freq_base,
float freq_scale);

// in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode,
int n_ctx,
float freq_base,
float freq_scale,
int n_ctx);
float freq_scale);

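ggml_rope_custom exposes the rotary-embedding frequency knobs directly rather than hard-coding them. A usage sketch against the signature declared above (the tensor names and the 10000/1.0 values are illustrative defaults, not mandated by the header):

// Apply RoPE with explicit base and scale; with freq_base = 10000.0f and
// freq_scale = 1.0f this matches the conventional defaults (illustrative values).
struct ggml_tensor * qr = ggml_rope_custom(
    ctx, q,                       // hypothetical context and query tensor
    n_past, n_dims,
    /*mode =*/ 0, /*n_ctx =*/ 0,
    10000.0f,                     // freq_base
    1.0f);                        // freq_scale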
// rotary position embedding backward, i.e compute dx from dy
// a - dy
@@ -1165,7 +1225,8 @@ extern "C" {
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode);
int mode,
int n_ctx);

// alibi position embedding
// in-place, returns view(a)
@@ -1289,6 +1350,16 @@ extern "C" {
typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);

GGML_API struct ggml_tensor * ggml_unary(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_unary_op op);

GGML_API struct ggml_tensor * ggml_unary_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_unary_op op);

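With the element-wise activations folded into a single GGML_OP_UNARY, callers build them through ggml_unary and backends recover the specific op with ggml_get_unary_op, as the Metal dispatcher earlier in this diff does. A usage sketch (ctx and x are hypothetical):

// Build a SiLU node; previously this would have been a dedicated GGML_OP_SILU op.
struct ggml_tensor * y = ggml_unary(ctx, x, GGML_UNARY_OP_SILU);

// Backends then switch on the sub-op:
//   switch (ggml_get_unary_op(y)) { case GGML_UNARY_OP_SILU: /* ... */ }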
GGML_API struct ggml_tensor * ggml_map_unary_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -1368,11 +1439,17 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * tensor);


GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);

GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);

// graph allocation in a context
GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx);
GGML_API struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor);
GGML_API size_t ggml_graph_overhead(void);

// ggml_graph_plan() has to be called before ggml_graph_compute()
// when plan.work_size > 0, caller must allocate memory for plan.work_data
GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);

llama/k_quants.c — 376 changes
@@ -1,5 +1,5 @@
/**
* llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
* llama.cpp - git 8183159cf3def112f6d1fe94815fce70e1bffa12
*
* MIT License
*
@@ -65,6 +65,8 @@
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))

#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

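MM256_SET_M128I builds a 256-bit vector from two 128-bit halves, with b in the low lane and a in the high lane; it stands in for _mm256_set_m128i, which is missing from some older compilers. The equivalent written as a C function, for reference (requires AVX):

#include <immintrin.h>

// Same effect as the macro: lo fills bits 0..127, hi fills bits 128..255.
static inline __m256i mm256_set_m128i_compat(__m128i hi, __m128i lo) {
    return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1);
}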
//
// 2-6 bit quantization in super-blocks
//
@@ -1379,7 +1381,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)};
const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};

__m256i sumi = _mm256_setzero_si256();

@@ -1447,7 +1449,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));

// sumf += -dmin * summs in 32bits*8
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(_mm256_set_m128i(summs_1, summs_0))), acc);
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc);

const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
@@ -1519,7 +1521,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
}

// sumf += dall * isum - dmin * summs in 32bits
__m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
}

@@ -1670,8 +1672,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
summs += dmin * smin;

const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
const __m256i q2_0 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 2), q2bits), m3);
const __m256i q2_1 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
const __m256i q2_0 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits), m3);
const __m256i q2_1 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);

const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
@@ -1692,6 +1694,62 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri

*s = hsum_float_8(acc) + summs;

#elif defined __AVX__

const __m128i m3 = _mm_set1_epi8(3);

__m256 acc = _mm256_setzero_ps();

uint32_t ud, um;
const uint8_t * restrict db = (const uint8_t *)&ud;
const uint8_t * restrict mb = (const uint8_t *)&um;

float summs = 0;

// TODO: optimize this

for (int i = 0; i < nb; ++i) {

const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);

const uint8_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;

const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
ud = (sc[0] >> 0) & 0x0f0f0f0f;
um = (sc[0] >> 4) & 0x0f0f0f0f;

int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
summs += dmin * smin;

const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
const __m128i q2_0 = _mm_and_si128(q2bits, m3);
const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);

const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));

const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0));
const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1));
const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0));
const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1));

const __m256i p_0 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0));
const __m256i p_1 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1));
const __m256i p_2 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2));
const __m256i p_3 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3));

acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc);
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc);
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc);
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc);
}

*s = hsum_float_8(acc) + summs;

#else
|
||||
|
||||
float sumf = 0;
|
||||
@@ -1887,7 +1945,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
||||
const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
|
||||
const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
|
||||
const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
|
||||
const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)};
|
||||
const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
|
||||
|
||||
// high bit
|
||||
const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
|
||||
@@ -2098,7 +2156,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
||||
}
|
||||
|
||||
// multiply with block scale and accumulate
|
||||
__m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
|
||||
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
|
||||
|
||||
}
|
||||
@@ -2273,13 +2331,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
||||
aux16[0] = a & 0x0f0f;
|
||||
aux16[1] = (a >> 4) & 0x0f0f;
|
||||
|
||||
const __m256i scale_0 = _mm256_set_m128i(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
|
||||
const __m256i scale_1 = _mm256_set_m128i(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
|
||||
const __m256i scale_0 = MM256_SET_M128I(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
|
||||
const __m256i scale_1 = MM256_SET_M128I(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
|
||||
|
||||
memcpy(&aux64, x[i].hmask, 8);
|
||||
|
||||
const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
|
||||
__m256i q3h_0 = _mm256_set_m128i(_mm_srli_epi16(haux, 2), haux);
|
||||
__m256i q3h_0 = MM256_SET_M128I(_mm_srli_epi16(haux, 2), haux);
|
||||
__m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4);
|
||||
q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
|
||||
q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);
|
||||
@@ -2288,7 +2346,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
||||
const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
|
||||
|
||||
// prepare low and high bits
|
||||
const __m256i q3aux = _mm256_set_m128i(_mm_srli_epi16(q3bits, 2), q3bits);
|
||||
const __m256i q3aux = MM256_SET_M128I(_mm_srli_epi16(q3bits, 2), q3bits);
|
||||
const __m256i q3l_0 = _mm256_and_si256(q3aux, m3);
|
||||
const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3);
|
||||
|
||||
@@ -2321,6 +2379,93 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
||||
|
||||
*s = hsum_float_8(acc);
|
||||
|
||||
#elif defined __AVX__
|
||||
|
||||
const __m128i m3 = _mm_set1_epi8(3);
|
||||
const __m128i m1 = _mm_set1_epi8(1);
|
||||
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
|
||||
uint64_t aux64;
|
||||
|
||||
uint16_t aux16[2];
|
||||
const int8_t * aux8 = (const int8_t *)aux16;
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
||||
|
||||
const uint8_t * restrict q3 = x[i].qs;
|
||||
const int8_t * restrict q8 = y[i].qs;
|
||||
|
||||
const uint16_t a = *(const uint16_t *)x[i].scales;
|
||||
aux16[0] = a & 0x0f0f;
|
||||
aux16[1] = (a >> 4) & 0x0f0f;
|
||||
|
||||
const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8);
|
||||
const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8);
|
||||
const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8);
|
||||
const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8);
|
||||
|
||||
memcpy(&aux64, x[i].hmask, 8);
|
||||
|
||||
__m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
|
||||
__m128i q3h_1 = _mm_srli_epi16(q3h_0, 2);
|
||||
__m128i q3h_2 = _mm_srli_epi16(q3h_0, 4);
|
||||
__m128i q3h_3 = _mm_srli_epi16(q3h_0, 6);
|
||||
q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2);
|
||||
q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2);
|
||||
q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2);
|
||||
q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2);
|
||||
|
||||
// load low 2 bits
|
||||
const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
|
||||
|
||||
// prepare low and high bits
|
||||
const __m128i q3l_0 = _mm_and_si128(q3bits, m3);
|
||||
const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3);
|
||||
const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3);
|
||||
const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3);
|
||||
|
||||
// load Q8 quants
|
||||
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
||||
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
||||
|
||||
// Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16,
|
||||
// and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
|
||||
// and 2 if the high bit was set)
|
||||
const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0));
|
||||
const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1));
|
||||
const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0));
|
||||
const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1));
|
||||
|
||||
__m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0));
|
||||
__m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1));
|
||||
__m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0));
|
||||
__m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1));
|
||||
|
||||
p16_0 = _mm_sub_epi16(p16_0, q8s_0);
|
||||
p16_1 = _mm_sub_epi16(p16_1, q8s_1);
|
||||
p16_2 = _mm_sub_epi16(p16_2, q8s_2);
|
||||
p16_3 = _mm_sub_epi16(p16_3, q8s_3);
|
||||
|
||||
// multiply with scales
|
||||
p16_0 = _mm_madd_epi16(scale_0, p16_0);
|
||||
p16_1 = _mm_madd_epi16(scale_1, p16_1);
|
||||
p16_2 = _mm_madd_epi16(scale_2, p16_2);
|
||||
p16_3 = _mm_madd_epi16(scale_3, p16_3);
|
||||
|
||||
p16_0 = _mm_add_epi32(p16_0, p16_2);
|
||||
p16_1 = _mm_add_epi32(p16_1, p16_3);
|
||||
__m256i p16 = MM256_SET_M128I(p16_1, p16_0);
|
||||
|
||||
// multiply with block scale and accumulate
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc);
|
||||
|
||||
}
|
||||
|
||||
*s = hsum_float_8(acc);
|
||||
|
||||
#else
|
||||
|
||||
int8_t aux8[QK_K];
|
||||
@@ -2503,7 +2648,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
||||
acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
|
||||
|
||||
const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
|
||||
const __m256i scales = _mm256_set_m128i(sc128, sc128);
|
||||
const __m256i scales = MM256_SET_M128I(sc128, sc128);
|
||||
|
||||
__m256i sumi = _mm256_setzero_si256();
|
||||
|
||||
@@ -2610,7 +2755,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
||||
}
|
||||
|
||||
__m256 vd = _mm256_set1_ps(d);
|
||||
__m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
|
||||
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
|
||||
|
||||
}
|
||||
@@ -2807,6 +2952,60 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
||||
|
||||
*s = hsum_float_8(acc) - summs;
|
||||
|
||||
#elif defined __AVX__
|
||||
|
||||
const __m128i m4 = _mm_set1_epi8(0xF);
|
||||
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
|
||||
float summs = 0;
|
||||
|
||||
uint16_t aux16[2];
|
||||
const uint8_t * scales = (const uint8_t *)aux16;
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
|
||||
const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
|
||||
const __m256 vd = _mm256_set1_ps(d);
|
||||
|
||||
const uint16_t * a = (const uint16_t *)x[i].scales;
|
||||
aux16[0] = a[0] & 0x0f0f;
|
||||
aux16[1] = (a[0] >> 4) & 0x0f0f;
|
||||
|
||||
summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
|
||||
|
||||
const uint8_t * restrict q4 = x[i].qs;
|
||||
const int8_t * restrict q8 = y[i].qs;
|
||||
|
||||
const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
|
||||
const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0);
|
||||
const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1);
|
||||
const __m128i q4_0 = _mm_and_si128(q4bits_0, m4);
|
||||
const __m128i q4_1 = _mm_and_si128(q4bits_1, m4);
|
||||
const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4);
|
||||
const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4);
|
||||
|
||||
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
||||
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
||||
|
||||
const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
|
||||
const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
|
||||
const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
|
||||
const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
|
||||
|
||||
const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0);
|
||||
const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_1, p32_0))), acc);
|
||||
|
||||
const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2);
|
||||
const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_3, p32_2))), acc);
|
||||
|
||||
}
|
||||
|
||||
*s = hsum_float_8(acc) - summs;
|
||||
|
||||
#else
|
||||
|
||||
uint8_t aux8[QK_K];
|
||||
@@ -2989,7 +3188,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
||||
summs += dmin * _mm_extract_epi32(hsum, 0);
|
||||
|
||||
const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
|
||||
const __m256i scales = _mm256_set_m128i(sc128, sc128);
|
||||
const __m256i scales = MM256_SET_M128I(sc128, sc128);
|
||||
|
||||
const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
|
||||
__m256i hmask = mone;
|
||||
@@ -3128,7 +3327,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
||||
}
|
||||
|
||||
__m256 vd = _mm256_set1_ps(d);
|
||||
__m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
|
||||
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
|
||||
|
||||
}
|
||||
@@ -3291,13 +3490,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
||||
|
||||
const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
|
||||
|
||||
const __m256i scale_l = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
|
||||
const __m256i scale_h = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
|
||||
const __m256i scale_l = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
|
||||
const __m256i scale_h = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
|
||||
|
||||
int64_t aux64;
|
||||
memcpy(&aux64, x[i].qh, 8);
|
||||
const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64);
|
||||
const __m256i haux256 = _mm256_set_m128i(_mm_srli_epi16(haux128, 2), haux128);
|
||||
const __m256i haux256 = MM256_SET_M128I(_mm_srli_epi16(haux128, 2), haux128);
|
||||
|
||||
const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4);
|
||||
const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4);
|
||||
@@ -3321,10 +3520,66 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
||||
|
||||
*s = hsum_float_8(acc);
|
||||
|
||||
#elif defined __AVX__
|
||||
|
||||
const __m128i m4 = _mm_set1_epi8(0xF);
|
||||
const __m128i mone = _mm_set1_epi8(1);
|
||||
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
const uint8_t * restrict q5 = x[i].qs;
|
||||
const int8_t * restrict q8 = y[i].qs;
|
||||
|
||||
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
||||
|
||||
const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
|
||||
|
||||
const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]);
|
||||
const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]);
|
||||
const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]);
|
||||
const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]);
|
||||
|
||||
int64_t aux64;
|
||||
memcpy(&aux64, x[i].qh, 8);
|
||||
const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64);
|
||||
const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2);
|
||||
|
||||
const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4);
|
||||
const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4);
|
||||
const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4);
|
||||
const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4);
|
||||
|
||||
const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4);
|
||||
const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4);
|
||||
const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4);
|
||||
const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4);
|
||||
|
||||
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
||||
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
||||
|
||||
const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0)));
|
||||
const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1)));
|
||||
const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0)));
|
||||
const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1)));
|
||||
const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0)));
|
||||
const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1)));
|
||||
const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0)));
|
||||
const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1)));
|
||||
|
||||
const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2));
|
||||
const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3));
|
||||
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(dot_1, dot_0))), acc);
|
||||
|
||||
}
|
||||
|
||||
*s = hsum_float_8(acc);
|
||||
|
||||
#else
|
||||
|
||||
|
||||
uint8_t aux8[QK_K];
|
||||
int8_t aux8[QK_K];
|
||||
int16_t aux16[16];
|
||||
float sums [8];
|
||||
memset(sums, 0, 8*sizeof(float));
|
||||
@@ -3334,7 +3589,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
||||
const uint8_t * restrict q4 = x[i].qs;
|
||||
const uint8_t * restrict hm = x[i].qh;
|
||||
const int8_t * restrict q8 = y[i].qs;
|
||||
uint8_t * restrict a = aux8;
|
||||
int8_t * restrict a = aux8;
|
||||
for (int l = 0; l < 32; ++l) {
|
||||
a[l+ 0] = q4[l] & 0xF;
|
||||
a[l+32] = q4[l] >> 4;
|
||||
@@ -3698,7 +3953,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
||||
|
||||
}
|
||||
|
||||
__m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
|
||||
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
|
||||
}
|
||||
|
||||
@@ -3856,8 +4111,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
||||
const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
|
||||
const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
|
||||
|
||||
const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
|
||||
const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
|
||||
const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
|
||||
const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
|
||||
|
||||
const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
|
||||
const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1);
|
||||
@@ -3884,6 +4139,77 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
||||
|
||||
*s = hsum_float_8(acc);
|
||||
|
||||
#elif defined __AVX__
|
||||
|
||||
const __m128i m4 = _mm_set1_epi8(0xF);
|
||||
const __m128i m2 = _mm_set1_epi8(3);
|
||||
const __m128i m32s = _mm_set1_epi8(32);
|
||||
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
||||
|
||||
const uint8_t * restrict q4 = x[i].ql;
|
||||
const uint8_t * restrict qh = x[i].qh;
|
||||
const int8_t * restrict q8 = y[i].qs;
|
||||
|
||||
const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
|
||||
const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
|
||||
const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
|
||||
const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
|
||||
|
||||
__m128i sumi_0 = _mm_setzero_si128();
|
||||
__m128i sumi_1 = _mm_setzero_si128();
|
||||
|
||||
const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
|
||||
const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
|
||||
|
||||
const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
|
||||
const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
|
||||
|
||||
const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4);
|
||||
const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4);
|
||||
const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4);
|
||||
const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4);
|
||||
|
||||
const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0);
|
||||
const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1);
|
||||
const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2);
|
||||
const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3);
|
||||
|
||||
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
||||
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
||||
|
||||
__m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0));
|
||||
__m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1));
|
||||
__m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0));
|
||||
__m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1));
|
||||
|
||||
__m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
|
||||
__m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
|
||||
__m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
|
||||
__m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
|
||||
|
||||
p16_0 = _mm_sub_epi16(p16_0, q8s_0);
|
||||
p16_1 = _mm_sub_epi16(p16_1, q8s_1);
|
||||
p16_2 = _mm_sub_epi16(p16_2, q8s_2);
|
||||
p16_3 = _mm_sub_epi16(p16_3, q8s_3);
|
||||
|
||||
p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
|
||||
p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
|
||||
p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
|
||||
p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
|
||||
|
||||
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
|
||||
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
|
||||
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi_1, sumi_0))), acc);
|
||||
}
|
||||
|
||||
*s = hsum_float_8(acc);
|
||||
|
||||
#else
|
||||
|
||||
int8_t aux8[QK_K];
|
||||
|
@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
+ * llama.cpp - git 8183159cf3def112f6d1fe94815fce70e1bffa12
  *
  * MIT License
  *
@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
+ * llama.cpp - git 8183159cf3def112f6d1fe94815fce70e1bffa12
  *
  * MIT License
  *

llama/llama.cpp (912 changed lines)
llama/llama.go (308 changed lines)
@@ -1,9 +1,10 @@
 package llama

 /*
-#cgo CPPFLAGS: -O3 -DNDEBUG=1
-#cgo CXXFLAGS: -std=c++11
-#cgo darwin CPPFLAGS: -DGGML_USE_METAL=1 -DGGML_METAL_NDEBUG=1
+#cgo CPPFLAGS: -O3 -Wall -Wextra -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS
+#cgo CXXFLAGS: -std=gnu++11
+#cgo darwin CPPFLAGS: -DGGML_USE_ACCELERATE
+#cgo darwin,arm64 CPPFLAGS: -DGGML_USE_METAL -DGGML_METAL_NDEBUG
 #cgo darwin LDFLAGS: -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
 #include <stdlib.h>
 #include "llama.h"
@@ -21,6 +22,7 @@ struct llama_sample_options
     int mirostat;
     float mirostat_tau;
     float mirostat_eta;
+    bool penalize_newline;
 };

 llama_token llama_sample(
@@ -37,6 +39,8 @@ llama_token llama_sample(
         false,
     };

+    struct llama_token_data newline = candidates_p.data[llama_token_nl()];
+
     llama_sample_repetition_penalty(
         ctx, &candidates_p,
         last_tokens, n_last_tokens,
@@ -47,6 +51,10 @@ llama_token llama_sample(
         last_tokens, n_last_tokens,
         opts->frequency_penalty, opts->presence_penalty);

+    if (!opts->penalize_newline) {
+        candidates_p.data[llama_token_nl()] = newline;
+    }
+
     if (opts->temperature <= 0) {
         return llama_sample_token_greedy(ctx, &candidates_p);
     }
@@ -79,32 +87,44 @@ llama_token llama_sample(
 import "C"
 import (
     "bytes"
+    "embed"
     "errors"
     "fmt"
     "io"
     "log"
     "os"
     "strings"
-    "time"
+    "sync"
+    "unicode/utf8"
     "unsafe"

     "github.com/jmorganca/ollama/api"
 )

-type llama struct {
+//go:embed ggml-metal.metal
+var fs embed.FS
+
+type LLM struct {
     params *C.struct_llama_context_params
     model  *C.struct_llama_model
     ctx    *C.struct_llama_context

+    last   []C.llama_token
+    embd   []C.llama_token
+    cursor int
+
+    mu sync.Mutex
+    gc bool
+
     api.Options
 }

-func New(model string, opts api.Options) (*llama, error) {
+func New(model string, opts api.Options) (*LLM, error) {
     if _, err := os.Stat(model); err != nil {
         return nil, err
     }

-    llm := llama{Options: opts}
+    llm := LLM{Options: opts}

     C.llama_backend_init(C.bool(llm.UseNUMA))

@@ -112,6 +132,7 @@ func New(model string, opts api.Options) (*llama, error) {
     params.seed = C.uint(llm.Seed)
     params.n_ctx = C.int(llm.NumCtx)
     params.n_batch = C.int(llm.NumBatch)
+    params.n_gqa = C.int(llm.NumGQA)
     params.n_gpu_layers = C.int(llm.NumGPU)
     params.main_gpu = C.int(llm.MainGPU)
     params.low_vram = C.bool(llm.LowVRAM)
@@ -144,31 +165,147 @@ func New(model string, opts api.Options) (*llama, error) {
     return &llm, nil
 }

-func (llm *llama) Close() {
+func (llm *LLM) Close() {
+    llm.gc = true
+
+    llm.mu.Lock()
+    defer llm.mu.Unlock()
+
     defer C.llama_free_model(llm.model)
     defer C.llama_free(llm.ctx)

     C.llama_print_timings(llm.ctx)
 }

-func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
-    if input := llm.tokenize(prompt); input != nil {
-        embd := make([]C.llama_token, len(ctx))
-        for i := range ctx {
-            embd[i] = C.llama_token(ctx[i])
-        }
+var errNeedMoreData = errors.New("need more data")

-        return llm.generate(append(embd, input...), fn)
+func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
+    C.llama_reset_timings(llm.ctx)
+
+    tokens := make([]C.llama_token, len(ctx))
+    for i := range tokens {
+        tokens[i] = C.llama_token(ctx[i])
     }

-    return errors.New("llama: tokenize")
+    if len(tokens) == 0 {
+        tokens = llm.tokenize(" ")
+    }
+
+    llm.marshalPrompt(tokens, prompt)
+
+    C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed))
+
+    var b bytes.Buffer
+    for {
+        token, err := llm.next()
+        if llm.gc {
+            return nil
+        } else if errors.Is(err, io.EOF) {
+            break
+        } else if err != nil {
+            return err
+        }
+
+        b.WriteString(llm.detokenize(token))
+
+        if err := llm.checkStopConditions(b); err != nil {
+            if errors.Is(err, io.EOF) {
+                break
+            } else if errors.Is(err, errNeedMoreData) {
+                continue
+            }
+
+            return err
+        }
+
+        if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
+            fn(api.GenerateResponse{Response: b.String()})
+            b.Reset()
+        }
+    }
+
+    last := make([]int, 0, len(llm.last))
+    for _, i := range llm.last {
+        if i != 0 {
+            last = append(last, int(i))
+        }
+    }
+
+    timings := C.llama_get_timings(llm.ctx)
+    fn(api.GenerateResponse{
+        Done:               true,
+        Context:            last,
+        SampleCount:        int(timings.n_sample),
+        SampleDuration:     parseDurationMs(float64(timings.t_sample_ms)),
+        PromptEvalCount:    int(timings.n_p_eval),
+        PromptEvalDuration: parseDurationMs(float64(timings.t_p_eval_ms)),
+        EvalCount:          int(timings.n_eval),
+        EvalDuration:       parseDurationMs(float64(timings.t_eval_ms)),
+    })
+
+    return nil
 }
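
An aside on the byte buffering in the new Predict loop: a sampled token can end in the middle of a multi-byte UTF-8 sequence, so bytes are held back until they form valid UTF-8 or reach utf8.UTFMax. A minimal, self-contained Go sketch of that rule (the helper name flushWhenValid is invented for illustration and is not part of the diff):

    package main

    import (
        "fmt"
        "unicode/utf8"
    )

    // flushWhenValid mimics the buffering rule above: emit the buffered
    // bytes only once they form valid UTF-8, or once enough bytes have
    // accumulated that a rune boundary can no longer be pending.
    func flushWhenValid(buf []byte, emit func(string)) []byte {
        if utf8.Valid(buf) || len(buf) >= utf8.UTFMax {
            emit(string(buf))
            return buf[:0]
        }
        return buf // hold the partial rune until the next token arrives
    }

    func main() {
        emoji := []byte("é") // two bytes: 0xC3 0xA9
        var buf []byte
        buf = append(buf, emoji[0])
        buf = flushWhenValid(buf, func(s string) { fmt.Println("emit:", s) }) // held back
        buf = append(buf, emoji[1])
        buf = flushWhenValid(buf, func(s string) { fmt.Println("emit:", s) }) // emits "é"
    }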
-func (llm *llama) tokenize(prompt string) []C.llama_token {
+func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
+    for _, stopCondition := range llm.Stop {
+        if stopCondition == b.String() {
+            return io.EOF
+        } else if strings.HasPrefix(stopCondition, b.String()) {
+            return errNeedMoreData
+        }
+    }
+
+    return nil
+}
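
The stop handling above distinguishes three cases: an exact match halts generation, a partial match holds the buffer in case it grows into a stop sequence, and anything else is safe to emit. A sketch of that decision lifted out of the method (stopState and its string results are invented names, not part of the diff):

    package main

    import (
        "fmt"
        "strings"
    )

    // stopState mirrors checkStopConditions above: "stop" halts generation,
    // "wait" means the buffer could still grow into a stop sequence,
    // "flush" means the buffer can be emitted to the caller.
    func stopState(buf string, stops []string) string {
        for _, stop := range stops {
            if stop == buf {
                return "stop"
            }
            if strings.HasPrefix(stop, buf) {
                return "wait"
            }
        }
        return "flush"
    }

    func main() {
        stops := []string{"###"}
        fmt.Println(stopState("#", stops))   // wait: could become "###"
        fmt.Println(stopState("###", stops)) // stop
        fmt.Println(stopState("#x", stops))  // flush: can no longer match
    }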
+func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
+    tokens := append(ctx, llm.tokenize(prompt)...)
+    if llm.NumKeep < 0 {
+        llm.NumKeep = len(tokens)
+    }
+
+    // min(llm.NumCtx - 4, llm.NumKeep)
+    if llm.NumCtx-4 < llm.NumKeep {
+        llm.NumKeep = llm.NumCtx - 4
+    }
+
+    if len(tokens) >= llm.NumCtx {
+        // truncate input
+        numLeft := (llm.NumCtx - llm.NumKeep) / 2
+        truncated := tokens[:llm.NumKeep]
+        erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
+        truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
+        copy(llm.last, tokens[len(tokens)-llm.NumCtx:])
+
+        tokens = truncated
+        log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
+    } else {
+        llm.last = make([]C.llama_token, llm.NumCtx-len(tokens))
+        llm.last = append(llm.last, tokens...)
+    }
+
+    var i int
+    for i = 0; i < len(llm.embd) && i < len(tokens) && llm.embd[i] == tokens[i]; i++ {
+        // noop
+    }
+
+    llm.embd = tokens
+    if i == len(tokens) {
+        // evaluate at least one token to generate logits
+        i--
+    }
+
+    llm.cursor = i
+
+    log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:]))
+    return tokens
+}
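
The truncation arithmetic in marshalPrompt is easier to follow with concrete numbers. A standalone sketch of the same formula, under the behavior shown above (truncate is an invented name used only here):

    package main

    import "fmt"

    // truncate reproduces the arithmetic in marshalPrompt: keep the first
    // numKeep tokens, then drop whole blocks of numLeft tokens from the
    // middle until the remainder fits inside the context window.
    func truncate(tokens []int, numCtx, numKeep int) []int {
        if len(tokens) < numCtx {
            return tokens
        }
        numLeft := (numCtx - numKeep) / 2
        erased := (len(tokens) - numKeep - numLeft - 1) / numLeft
        out := append([]int{}, tokens[:numKeep]...)
        return append(out, tokens[numKeep+erased*numLeft:]...)
    }

    func main() {
        tokens := make([]int, 12)
        for i := range tokens {
            tokens[i] = i
        }
        // numCtx=8, numKeep=2 -> numLeft=3, erased=(12-2-3-1)/3=2,
        // so tokens 2..7 are dropped and 0,1,8,9,10,11 remain.
        fmt.Println(truncate(tokens, 8, 2)) // [0 1 8 9 10 11]
    }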
+func (llm *LLM) tokenize(prompt string) []C.llama_token {
     cPrompt := C.CString(prompt)
     defer C.free(unsafe.Pointer(cPrompt))

-    tokens := make([]C.llama_token, llm.NumCtx)
+    tokens := make([]C.llama_token, len(prompt)+1)
     if n := C.llama_tokenize(llm.ctx, cPrompt, unsafe.SliceData(tokens), C.int(len(tokens)), true); n > 0 {
         return tokens[:n]
     }
@@ -176,7 +313,7 @@ func (llm *llama) tokenize(prompt string) []C.llama_token {
     return nil
 }

-func (llm *llama) detokenize(tokens ...C.llama_token) string {
+func (llm *LLM) detokenize(tokens ...C.llama_token) string {
     var sb strings.Builder
     for _, token := range tokens {
         sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))
@@ -185,98 +322,93 @@ func (llm *llama) detokenize(tokens ...C.llama_token) string {
     return sb.String()
 }

-func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
-    var opts C.struct_llama_sample_options
-    opts.repeat_penalty = C.float(llm.RepeatPenalty)
-    opts.frequency_penalty = C.float(llm.FrequencyPenalty)
-    opts.presence_penalty = C.float(llm.PresencePenalty)
-    opts.temperature = C.float(llm.Temperature)
-    opts.top_k = C.int(llm.TopK)
-    opts.top_p = C.float(llm.TopP)
-    opts.tfs_z = C.float(llm.TFSZ)
-    opts.typical_p = C.float(llm.TypicalP)
-    opts.mirostat = C.int(llm.Mirostat)
-    opts.mirostat_tau = C.float(llm.MirostatTau)
-    opts.mirostat_eta = C.float(llm.MirostatEta)
+func (llm *LLM) next() (C.llama_token, error) {
+    llm.mu.Lock()
+    defer llm.mu.Unlock()

-    output := deque[C.llama_token]{capacity: llm.NumCtx}
+    if len(llm.embd) >= llm.NumCtx {
+        numLeft := (llm.NumCtx - llm.NumKeep) / 2
+        truncated := llm.embd[:llm.NumKeep]
+        truncated = append(truncated, llm.embd[len(llm.embd)-numLeft:]...)

-    context := deque[int]{capacity: llm.NumCtx / 2}
-    for _, in := range input {
-        context.PushLeft(int(in))
+        llm.embd = truncated
+        llm.cursor = llm.NumKeep
+        log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d cursor=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated), llm.cursor)
     }

-    var b bytes.Buffer
-    for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
-        if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
-            return errors.New("llama: eval")
+    for {
+        if llm.gc {
+            return 0, io.EOF
         }

-        token, err := llm.sample(output, &opts)
-        if errors.Is(err, io.EOF) {
+        if llm.cursor >= len(llm.embd) {
             break
-        } else if err != nil {
-            return err
         }

-        b.WriteString(llm.detokenize(token))
-        if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
-            // call the callback
-            fn(api.GenerateResponse{
-                Response: b.String(),
-            })
-
-            output.PushLeft(token)
-            context.PushLeft(int(token))
-            b.Reset()
+        numEval := len(llm.embd) - llm.cursor
+        if numEval > llm.NumBatch {
+            numEval = llm.NumBatch
         }

-        input = []C.llama_token{token}
+        if retval := C.llama_eval(llm.ctx, unsafe.SliceData(llm.embd[llm.cursor:]), C.int(numEval), C.int(llm.cursor), C.int(llm.NumThread)); retval != 0 {
+            return 0, fmt.Errorf("llama_eval: %d", retval)
+        }
+
+        llm.cursor += numEval
     }

-    dur := func(ms float64) time.Duration {
-        d, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
-        if err != nil {
-            panic(err)
-        }
+    var sampleOpts C.struct_llama_sample_options
+    sampleOpts.repeat_penalty = C.float(llm.RepeatPenalty)
+    sampleOpts.frequency_penalty = C.float(llm.FrequencyPenalty)
+    sampleOpts.presence_penalty = C.float(llm.PresencePenalty)
+    sampleOpts.temperature = C.float(llm.Temperature)
+    sampleOpts.top_k = C.int(llm.TopK)
+    sampleOpts.top_p = C.float(llm.TopP)
+    sampleOpts.tfs_z = C.float(llm.TFSZ)
+    sampleOpts.typical_p = C.float(llm.TypicalP)
+    sampleOpts.mirostat = C.int(llm.Mirostat)
+    sampleOpts.mirostat_tau = C.float(llm.MirostatTau)
+    sampleOpts.mirostat_eta = C.float(llm.MirostatEta)
+    sampleOpts.penalize_newline = C.bool(llm.PenalizeNewline)

-        return d
-    }
-
-    timings := C.llama_get_timings(llm.ctx)
-    fn(api.GenerateResponse{
-        Done:               true,
-        Context:            context.Data(),
-        PromptEvalCount:    int(timings.n_p_eval),
-        PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
-        EvalCount:          int(timings.n_eval),
-        EvalDuration:       dur(float64(timings.t_eval_ms)),
-    })
-
-    return nil
-}
-
-func (llm *llama) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
-    numVocab := int(C.llama_n_vocab(llm.ctx))
+    numVocab := C.llama_n_vocab(llm.ctx)
     logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)

-    candidates := deque[C.struct_llama_token_data]{capacity: numVocab}
-    for i := 0; i < candidates.Cap(); i++ {
-        candidates.PushLeft(C.struct_llama_token_data{
+    // TODO: logit bias
+
+    candidates := make([]C.llama_token_data, numVocab)
+    for i := range logits {
+        candidates[i] = C.llama_token_data{
             id:    C.int(i),
             logit: logits[i],
             p:     0,
-        })
+        }
     }

+    repeatLastN := llm.RepeatLastN
+    if len(llm.last) < repeatLastN {
+        repeatLastN = len(llm.last)
+    }
+
+    if llm.NumCtx < repeatLastN {
+        repeatLastN = llm.NumCtx
+    }
+
+    lastN := llm.last[len(llm.last)-repeatLastN:]
+
     token := C.llama_sample(
         llm.ctx,
-        unsafe.SliceData(candidates.Data()), C.size_t(candidates.Len()),
-        unsafe.SliceData(output.Data()), C.size_t(output.Len()),
-        opts)
-    if token != C.llama_token_eos() {
-        return token, nil
+        unsafe.SliceData(candidates), C.size_t(len(candidates)),
+        unsafe.SliceData(lastN), C.size_t(len(lastN)),
+        &sampleOpts,
+    )
+
+    llm.last = append(llm.last, token)
+    llm.embd = append(llm.embd, token)
+
+    if token == C.llama_token_eos() {
+        return 0, io.EOF
     }

-    return 0, io.EOF
+    return token, nil
 }
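
One detail in next worth isolating: the repeat-penalty window is clamped twice, first to the number of tokens actually recorded so far, then to the context size. A small sketch of that clamping (clampRepeatLastN is an invented name for illustration):

    package main

    import "fmt"

    // clampRepeatLastN mirrors the two bounds applied above: the window can
    // never exceed the tokens generated so far, nor the context length.
    func clampRepeatLastN(repeatLastN, lastLen, numCtx int) int {
        if lastLen < repeatLastN {
            repeatLastN = lastLen
        }
        if numCtx < repeatLastN {
            repeatLastN = numCtx
        }
        return repeatLastN
    }

    func main() {
        fmt.Println(clampRepeatLastN(64, 10, 2048))   // 10: only 10 tokens so far
        fmt.Println(clampRepeatLastN(4096, 5000, 2048)) // 2048: capped by context
    }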
@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc
+ * llama.cpp - git 8183159cf3def112f6d1fe94815fce70e1bffa12
  *
  * MIT License
  *
@@ -79,6 +79,10 @@
 #define LLAMA_SUPPORTS_GPU_OFFLOAD
 #endif

+#ifndef LLAMA_DEFAULT_RMS_EPS
+#define LLAMA_DEFAULT_RMS_EPS 5e-6f
+#endif
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -109,12 +113,15 @@ extern "C" {
     typedef void (*llama_progress_callback)(float progress, void *ctx);

     struct llama_context_params {
-        uint32_t seed;         // RNG seed, -1 for random
-        int32_t  n_ctx;        // text context
-        int32_t  n_batch;      // prompt processing batch size
-        int32_t  n_gpu_layers; // number of layers to store in VRAM
-        int32_t  main_gpu;     // the GPU that is used for scratch and small tensors
-        float    tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
+        uint32_t seed;         // RNG seed, -1 for random
+        int32_t  n_ctx;        // text context
+        int32_t  n_batch;      // prompt processing batch size
+        int32_t  n_gqa;        // grouped-query attention (TEMP - will be moved to model hparams)
+        float    rms_norm_eps; // rms norm epsilon (TEMP - will be moved to model hparams)
+        int32_t  n_gpu_layers; // number of layers to store in VRAM
+        int32_t  main_gpu;     // the GPU that is used for scratch and small tensors
+
+        const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)

         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float rope_freq_base; // RoPE base frequency
@@ -127,6 +134,7 @@ extern "C" {

         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool low_vram;   // if true, reduce VRAM usage at the cost of performance
+        bool mul_mat_q;  // if true, use experimental mul_mat_q kernels
         bool f16_kv;     // use fp16 for KV cache
         bool logits_all; // the llama_eval() call computes all logits, not just the last one
         bool vocab_only; // only load the vocabulary, no weights
@@ -165,6 +173,40 @@ extern "C" {
         bool quantize_output_tensor; // quantize output.weight
     } llama_model_quantize_params;

+    // grammar types
+    struct llama_grammar;
+
+    // grammar element type
+    enum llama_gretype {
+        // end of rule definition
+        LLAMA_GRETYPE_END            = 0,
+
+        // start of alternate definition for rule
+        LLAMA_GRETYPE_ALT            = 1,
+
+        // non-terminal element: reference to rule
+        LLAMA_GRETYPE_RULE_REF       = 2,
+
+        // terminal element: character (code point)
+        LLAMA_GRETYPE_CHAR           = 3,
+
+        // inverse char(s) ([^a], [^a-b] [^abc])
+        LLAMA_GRETYPE_CHAR_NOT       = 4,
+
+        // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
+        // be an inclusive range ([a-z])
+        LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
+
+        // modifies a preceding LLAMA_GRETYPE_CHAR or
+        // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
+        LLAMA_GRETYPE_CHAR_ALT       = 6,
+    };
+
+    typedef struct llama_grammar_element {
+        enum llama_gretype type;
+        uint32_t           value; // Unicode code point or rule ID
+    } llama_grammar_element;
+
     // performance timing information
     struct llama_timings {
         double t_start_ms;
@@ -357,6 +399,15 @@ extern "C" {
     LLAMA_API llama_token llama_token_eos(); // end-of-sentence
     LLAMA_API llama_token llama_token_nl();  // next-line

+    // Grammar
+    //
+    LLAMA_API struct llama_grammar * llama_grammar_init(
+            const llama_grammar_element ** rules,
+            size_t n_rules,
+            size_t start_rule_index);
+
+    LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
+
     // Sampling functions

     /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
@@ -369,13 +420,11 @@ extern "C" {
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
    /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
    /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
-    /// @params smooth_factor Smooth factor between guidance logits and original logits. 1.0f means only use guidance logits. 0.0f means only original logits.
     LLAMA_API void llama_sample_classifier_free_guidance(
               struct llama_context * ctx,
             llama_token_data_array * candidates,
               struct llama_context * guidance_ctx,
-                             float   scale,
-                             float   smooth_factor);
+                             float   scale);

     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
     LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
@@ -393,6 +442,9 @@ extern "C" {
     LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
     LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);

+    /// @details Apply constraints from grammar
+    LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);
+
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
     /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@@ -414,6 +466,9 @@ extern "C" {
     /// @details Randomly selects a token from the candidates based on their probabilities.
     LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);

+    /// @details Accepts the sampled token into the grammar
+    LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
+
     // Performance information
     LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
     LLAMA_API void llama_print_timings(struct llama_context * ctx);
llama/llama_darwin.go (80 lines, new file)
@@ -0,0 +1,80 @@
+package llama
+
+import (
+    "bytes"
+    "crypto/sha256"
+    "errors"
+    "io"
+    "log"
+    "os"
+    "path/filepath"
+)
+
+func init() {
+    if err := initBackend(); err != nil {
+        log.Printf("WARNING: GPU could not be initialized correctly: %v", err)
+        log.Printf("WARNING: falling back to CPU")
+    }
+}
+
+func initBackend() error {
+    exec, err := os.Executable()
+    if err != nil {
+        return err
+    }
+
+    exec, err = filepath.EvalSymlinks(exec)
+    if err != nil {
+        return err
+    }
+
+    metal := filepath.Join(filepath.Dir(exec), "ggml-metal.metal")
+    fi, err := os.Stat(metal)
+    if err != nil && !errors.Is(err, os.ErrNotExist) {
+        return err
+    }
+
+    if fi != nil {
+        actual, err := os.Open(metal)
+        if err != nil {
+            return err
+        }
+
+        actualSum := sha256.New()
+        if _, err := io.Copy(actualSum, actual); err != nil {
+            return err
+        }
+
+        expect, err := fs.Open("ggml-metal.metal")
+        if err != nil {
+            return err
+        }
+
+        expectSum := sha256.New()
+        if _, err := io.Copy(expectSum, expect); err != nil {
+            return err
+        }
+
+        if bytes.Equal(actualSum.Sum(nil), expectSum.Sum(nil)) {
+            return nil
+        }
+    }
+
+    dst, err := os.Create(filepath.Join(filepath.Dir(exec), "ggml-metal.metal"))
+    if err != nil {
+        return err
+    }
+    defer dst.Close()
+
+    src, err := fs.Open("ggml-metal.metal")
+    if err != nil {
+        return err
+    }
+    defer src.Close()
+
+    if _, err := io.Copy(dst, src); err != nil {
+        return err
+    }
+
+    return nil
+}
llama/update-llama-cpp.sh (0 changed lines, mode changed: normal file → executable file)
llama/utils.go (107 changed lines)
@@ -1,104 +1,15 @@
 package llama

-type node[T any] struct {
-    t    T
-    next *node[T]
-    prev *node[T]
-}
+import (
+    "fmt"
+    "time"
+)

-type deque[T any] struct {
-    head     *node[T]
-    tail     *node[T]
-    size     int
-    capacity int
-}
-
-func (d *deque[T]) Empty() bool {
-    return d.size == 0
-}
-
-func (d *deque[T]) Len() int {
-    return d.size
-}
-
-func (d *deque[T]) Cap() int {
-    return d.capacity
-}
-
-func (d *deque[T]) Push(t T) {
-    if d.capacity > 0 && d.size >= d.capacity {
-        d.PopLeft()
+func parseDurationMs(ms float64) time.Duration {
+    dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
+    if err != nil {
+        panic(err)
     }

-    n := node[T]{t: t}
-    if d.head != nil {
-        n.next = d.head
-        d.head.prev = &n
-        d.head = &n
-    } else {
-        d.head = &n
-        d.tail = &n
-    }
-
-    d.size++
-}
-
-func (d *deque[T]) PushLeft(t T) {
-    if d.capacity > 0 && d.size >= d.capacity {
-        d.Pop()
-    }
-
-    n := node[T]{t: t}
-    if d.tail != nil {
-        n.prev = d.tail
-        d.tail.next = &n
-        d.tail = &n
-    } else {
-        d.head = &n
-        d.tail = &n
-    }
-
-    d.size++
-}
-
-func (d *deque[T]) Pop() *T {
-    if d.Empty() {
-        return nil
-    }
-
-    head := d.head
-    d.head = head.next
-    if d.head != nil {
-        d.head.prev = nil
-    } else {
-        d.tail = nil
-    }
-
-    d.size--
-    return &head.t
-}
-
-func (d *deque[T]) PopLeft() *T {
-    if d.Empty() {
-        return nil
-    }
-
-    tail := d.tail
-    d.tail = tail.prev
-    if d.tail != nil {
-        d.tail.next = nil
-    } else {
-        d.head = nil
-    }
-
-    d.size--
-    return &tail.t
-}
-
-func (d *deque[T]) Data() (data []T) {
-    for n := d.head; n != nil; n = n.next {
-        data = append(data, n.t)
-    }
-
-    return data
+    return dur
 }
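
For reference, the helper that replaces the deque here converts llama.cpp's float millisecond timings into a time.Duration. A usage sketch with the same shape as the function above (the main wrapper exists only for illustration):

    package main

    import (
        "fmt"
        "time"
    )

    // parseDurationMs converts a float millisecond count, as reported by
    // llama_get_timings, into a time.Duration.
    func parseDurationMs(ms float64) time.Duration {
        dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
        if err != nil {
            panic(err)
        }
        return dur
    }

    func main() {
        fmt.Println(parseDurationMs(1234.5)) // prints 1.2345s
    }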
@@ -6,6 +6,7 @@ import (
     "errors"
     "fmt"
     "io"
+    "log"
 )

 type Command struct {
@@ -20,16 +21,16 @@ func (c *Command) Reset() {

 func Parse(reader io.Reader) ([]Command, error) {
     var commands []Command

     var command, modelCommand Command

     scanner := bufio.NewScanner(reader)
+    scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize)
     scanner.Split(scanModelfile)
     for scanner.Scan() {
         line := scanner.Bytes()

         fields := bytes.SplitN(line, []byte(" "), 2)
-        if len(fields) == 0 {
+        if len(fields) == 0 || len(fields[0]) == 0 {
             continue
         }

@@ -47,6 +48,8 @@ func Parse(reader io.Reader) ([]Command, error) {
         command.Name = string(fields[0])
         command.Args = string(fields[1])
     default:
+        // log a warning for unknown commands
+        log.Printf("WARNING: Unknown command: %s", fields[0])
         continue
     }

@@ -55,28 +58,53 @@ func Parse(reader io.Reader) ([]Command, error) {
     }

     if modelCommand.Args == "" {
-        return nil, fmt.Errorf("no FROM line for the model was specified")
+        return nil, errors.New("no FROM line for the model was specified")
     }

     return commands, scanner.Err()
 }

 func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) {
+    advance, token, err = scan([]byte(`"""`), []byte(`"""`), data, atEOF)
+    if err != nil {
+        return 0, nil, err
+    }
+
+    if advance > 0 && token != nil {
+        return advance, token, nil
+    }
+
+    advance, token, err = scan([]byte(`"`), []byte(`"`), data, atEOF)
+    if err != nil {
+        return 0, nil, err
+    }
+
+    if advance > 0 && token != nil {
+        return advance, token, nil
+    }
+
+    return bufio.ScanLines(data, atEOF)
+}
+
+func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) {
     newline := bytes.IndexByte(data, '\n')

-    if start := bytes.Index(data, []byte(`"""`)); start >= 0 && start < newline {
-        end := bytes.Index(data[start+3:], []byte(`"""`))
+    if start := bytes.Index(data, openBytes); start >= 0 && start < newline {
+        end := bytes.Index(data[start+len(openBytes):], closeBytes)
         if end < 0 {
             if atEOF {
-                return 0, nil, errors.New(`unterminated multiline string: """`)
+                return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes)
             } else {
                 return 0, nil, nil
             }
         }

-        n := start + 3 + end + 3
-        return n, bytes.Replace(data[:n], []byte(`"""`), []byte(""), 2), nil
+        n := start + len(openBytes) + end + len(closeBytes)
+
+        newData := data[:start]
+        newData = append(newData, data[start+len(openBytes):n-len(closeBytes)]...)
+        return n, newData, nil
     }

-    return bufio.ScanLines(data, atEOF)
+    return 0, nil, nil
 }
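
To see what the generalized two-pass split function buys, here is a sketch of scanning a Modelfile that contains a triple-quoted template. It assumes it compiles in the same package as scanModelfile above; the input and expected output are illustrative only:

    package parser

    import (
        "bufio"
        "fmt"
        "strings"
    )

    // ExampleScanModelfile shows how the split function tokenizes a
    // Modelfile: a """-delimited block comes back as a single token that
    // spans several lines, with the delimiters stripped.
    func ExampleScanModelfile() {
        input := "FROM llama2\nTEMPLATE \"\"\"{{ .System }}\n{{ .Prompt }}\"\"\""

        scanner := bufio.NewScanner(strings.NewReader(input))
        scanner.Split(scanModelfile)
        for scanner.Scan() {
            fmt.Printf("%q\n", scanner.Text())
        }
        // Expected, under the logic above:
        // "FROM llama2"
        // "TEMPLATE {{ .System }}\n{{ .Prompt }}"
    }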
scripts/build_darwin.sh (22 lines, new executable file)
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+mkdir -p dist
+
+# build universal binary
+CGO_ENABLED=1 GOARCH=arm64 go build -o dist/ollama-darwin-arm64
+CGO_ENABLED=1 GOARCH=amd64 go build -o dist/ollama-darwin-amd64
+lipo -create -output dist/ollama dist/ollama-darwin-arm64 dist/ollama-darwin-amd64
+rm dist/ollama-darwin-amd64 dist/ollama-darwin-arm64
+codesign --deep --force --options=runtime --sign "$APPLE_IDENTITY" --timestamp dist/ollama
+
+# build and sign the mac app
+npm install --prefix app
+npm run --prefix app make:sign
+cp app/out/make/zip/darwin/universal/Ollama-darwin-universal-${VERSION:-0.0.0}.zip dist/Ollama-darwin.zip
+
+# sign the binary and rename it
+codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/ollama
+ditto -c -k --keepParent dist/ollama dist/temp.zip
+xcrun notarytool submit dist/temp.zip --wait --timeout 10m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
+mv dist/ollama dist/ollama-darwin
+rm dist/temp.zip
@@ -8,28 +8,18 @@ if [[ -z "${VERSION}" ]]; then
 fi

 OS=$(go env GOOS)
 ARCH=$(go env GOARCH)

-go build .
-
-npm --prefix app run make:sign
+./script/build_${OS}.sh

 # Create a new tag if it doesn't exist.
 if ! git rev-parse v$VERSION >/dev/null 2>&1; then
     git tag v$VERSION
-    git push origin v$VERSION
 fi

-mkdir -p dist
-cp app/out/make/zip/${OS}/${ARCH}/Ollama-${OS}-${ARCH}-${VERSION}.zip dist/Ollama-${OS}-${ARCH}.zip
-cp ./ollama dist/ollama-${OS}-${ARCH}
+git push origin v$VERSION

 # Create a new release.
-gh release create v$VERSION
+gh release create -p v$VERSION -t v$VERSION

 # Upload the zip file.
-gh release upload v$VERSION ./dist/Ollama-${OS}-${ARCH}.zip
-
-# Upload the binary.
-gh release upload v$VERSION ./dist/ollama-${OS}-${ARCH}
+gh release upload v$VERSION ./dist/* --clobber

server/images.go (460 changed lines)
@@ -7,7 +7,6 @@ import (
     "errors"
     "fmt"
     "io"
-    "io/ioutil"
     "log"
     "net/http"
     "os"
@@ -22,12 +21,19 @@ import (
     "github.com/jmorganca/ollama/parser"
 )

+type RegistryOptions struct {
+    Insecure bool
+    Username string
+    Password string
+}
+
 type Model struct {
     Name      string `json:"name"`
     ModelPath string
     Template  string
     System    string
-    Options   api.Options
+    Digest    string
+    Options   map[string]interface{}
 }
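
A note on Options changing from api.Options to map[string]interface{}: decoding into a map preserves which keys were actually present in the params layer, which a typed struct cannot distinguish from zero values. A minimal sketch of that idea (the JSON payload here is invented):

    package main

    import (
        "encoding/json"
        "fmt"
    )

    func main() {
        raw := []byte(`{"temperature": 0.2, "stop": ["###"]}`)

        // Decoding into a map records exactly the keys the layer contained.
        var opts map[string]interface{}
        if err := json.Unmarshal(raw, &opts); err != nil {
            panic(err)
        }

        if _, ok := opts["temperature"]; ok {
            fmt.Println("temperature was set explicitly:", opts["temperature"])
        }
        if _, ok := opts["num_ctx"]; !ok {
            fmt.Println("num_ctx not set; a default would apply")
        }
    }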
|
||||
|
||||
func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
|
||||
@@ -45,7 +51,7 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
|
||||
Context []int
|
||||
}
|
||||
|
||||
vars.First = len(vars.Context) == 0
|
||||
vars.First = len(request.Context) == 0
|
||||
vars.System = m.System
|
||||
vars.Prompt = request.Prompt
|
||||
vars.Context = request.Context
@@ -102,8 +108,8 @@ func GetManifest(mp ModelPath) (*ManifestV2, error) {
return nil, err
}

if _, err = os.Stat(fp); err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("couldn't find model '%s'", mp.GetShortTagname())
if _, err = os.Stat(fp); err != nil {
return nil, err
}

var manifest *ManifestV2
@@ -129,7 +135,8 @@ func GetModel(name string) (*Model, error) {
}

model := &Model{
Name: mp.GetFullTagname(),
Name: mp.GetFullTagname(),
Digest: manifest.Config.Digest,
}

for _, layer := range manifest.Layers {
@@ -169,40 +176,38 @@ func GetModel(name string) (*Model, error) {
}
defer params.Close()

var opts api.Options
if err = json.NewDecoder(params).Decode(&opts); err != nil {
// parse model options parameters into a map so that we can see which fields have been specified explicitly
if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
return nil, err
}

model.Options = opts
}
}

return model, nil
}

func CreateModel(name string, path string, fn func(status string)) error {
func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error {
mf, err := os.Open(path)
if err != nil {
fn(fmt.Sprintf("couldn't open modelfile '%s'", path))
fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)})
return fmt.Errorf("failed to open file: %w", err)
}
defer mf.Close()

fn("parsing modelfile")
fn(api.ProgressResponse{Status: "parsing modelfile"})
commands, err := parser.Parse(mf)
if err != nil {
return err
}

var layers []*LayerReader
params := make(map[string]string)
params := make(map[string][]string)

for _, c := range commands {
log.Printf("[%s] - %s\n", c.Name, c.Args)
switch c.Name {
case "model":
fn("looking for model")
fn(api.ProgressResponse{Status: "looking for model"})
mf, err := GetManifest(ParseModelPath(c.Args))
if err != nil {
fp := c.Args
@@ -223,20 +228,40 @@ func CreateModel(name string, path string, fn func(status string)) error {
fp = filepath.Join(filepath.Dir(path), fp)
}

fn("creating model layer")
file, err := os.Open(fp)
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
}
defer file.Close()
if _, err := os.Stat(fp); err != nil {
// the model file does not exist, try pulling it
if errors.Is(err, os.ErrNotExist) {
fn(api.ProgressResponse{Status: "pulling model file"})
if err := PullModel(c.Args, &RegistryOptions{}, fn); err != nil {
return err
}
mf, err = GetManifest(ParseModelPath(c.Args))
if err != nil {
return fmt.Errorf("failed to open file after pull: %v", err)
}

l, err := CreateLayer(file)
if err != nil {
return fmt.Errorf("failed to create layer: %v", err)
} else {
return err
}
} else {
// create a model from this specified file
fn(api.ProgressResponse{Status: "creating model layer"})

file, err := os.Open(fp)
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
}
defer file.Close()

l, err := CreateLayer(file)
if err != nil {
return fmt.Errorf("failed to create layer: %v", err)
}
l.MediaType = "application/vnd.ollama.image.model"
layers = append(layers, l)
}
l.MediaType = "application/vnd.ollama.image.model"
layers = append(layers, l)
} else {
}
if mf != nil {
log.Printf("manifest = %#v", mf)
for _, l := range mf.Layers {
newLayer, err := GetLayerWithBufferFromLayer(l)
@@ -246,8 +271,20 @@ func CreateModel(name string, path string, fn func(status string)) error {
layers = append(layers, newLayer)
}
}
case "license", "template", "system", "prompt":
fn(fmt.Sprintf("creating %s layer", c.Name))
case "license":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
// remove the prompt layer if one exists
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)

layer, err := CreateLayer(strings.NewReader(c.Args))
if err != nil {
return err
}

layer.MediaType = mediaType
layers = append(layers, layer)
case "template", "system", "prompt":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
// remove the prompt layer if one exists
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
layers = removeLayerFromLayers(layers, mediaType)
@@ -260,13 +297,14 @@ func CreateModel(name string, path string, fn func(status string)) error {
layer.MediaType = mediaType
layers = append(layers, layer)
default:
params[c.Name] = c.Args
// runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop tokens)
params[c.Name] = append(params[c.Name], c.Args)
}
}
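The `default` case now accumulates repeated parameters into `map[string][]string` instead of overwriting them. A small sketch of the effect, using hypothetical Modelfile parameter lines:

```go
package main

import "fmt"

func main() {
	// Hypothetical Modelfile parameter lines, in the order parsed.
	commands := []struct{ Name, Args string }{
		{"temperature", "0.7"},
		{"stop", "<|user|>"},
		{"stop", "<|assistant|>"},
	}

	// map[string][]string keeps every occurrence; the old map[string]string
	// would have silently kept only the last "stop" value.
	params := make(map[string][]string)
	for _, c := range commands {
		params[c.Name] = append(params[c.Name], c.Args)
	}

	fmt.Println(params["stop"]) // [<|user|> <|assistant|>]
}
```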

// Create a single layer for the parameters
if len(params) > 0 {
fn("creating parameter layer")
fn(api.ProgressResponse{Status: "creating parameter layer"})
layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
paramData, err := paramsToReader(params)
if err != nil {
@@ -291,7 +329,7 @@ func CreateModel(name string, path string, fn func(status string)) error {
}

// Create a layer for the config object
fn("creating config layer")
fn(api.ProgressResponse{Status: "creating config layer"})
cfg, err := createConfigLayer(digests)
if err != nil {
return err
@@ -304,13 +342,13 @@ func CreateModel(name string, path string, fn func(status string)) error {
}

// Create the manifest
fn("writing manifest")
fn(api.ProgressResponse{Status: "writing manifest"})
err = CreateManifest(name, cfg, manifestLayers)
if err != nil {
return err
}

fn("success")
fn(api.ProgressResponse{Status: "success"})
return nil
}

@@ -325,7 +363,7 @@ func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerRead
return layers[:j]
}

func SaveLayers(layers []*LayerReader, fn func(status string), force bool) error {
func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {
// Write each of the layers to disk
for _, layer := range layers {
fp, err := GetBlobsPath(layer.Digest)
@@ -335,7 +373,8 @@ func SaveLayers(layers []*LayerReader, fn func(status string), force bool) error

_, err = os.Stat(fp)
if os.IsNotExist(err) || force {
fn(fmt.Sprintf("writing layer %s", layer.Digest))
fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})

out, err := os.Create(fp)
if err != nil {
log.Printf("couldn't create %s", fp)
@@ -348,7 +387,7 @@ func SaveLayers(layers []*LayerReader, fn func(status string), force bool) error
}

} else {
fn(fmt.Sprintf("using already created layer %s", layer.Digest))
fn(api.ProgressResponse{Status: fmt.Sprintf("using already created layer %s", layer.Digest)})
}
}

@@ -401,11 +440,13 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
return newLayer, nil
}

func paramsToReader(params map[string]string) (io.ReadSeeker, error) {
opts := api.DefaultOptions()
typeOpts := reflect.TypeOf(opts)
// paramsToReader converts specified parameter options to their correct types, and returns a reader for the json
func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
opts := api.Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct

// build map of json struct tags
// build map of json struct tags to their types
jsonOpts := make(map[string]reflect.StructField)
for _, field := range reflect.VisibleFields(typeOpts) {
jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
@@ -414,36 +455,39 @@ func paramsToReader(params map[string]string) (io.ReadSeeker, error) {
}
}

valueOpts := reflect.ValueOf(&opts).Elem()
out := make(map[string]interface{})
// iterate params and set values based on json struct tags
for key, val := range params {
for key, vals := range params {
if opt, ok := jsonOpts[key]; ok {
field := valueOpts.FieldByName(opt.Name)
if field.IsValid() && field.CanSet() {
switch field.Kind() {
case reflect.Float32:
floatVal, err := strconv.ParseFloat(val, 32)
floatVal, err := strconv.ParseFloat(vals[0], 32)
if err != nil {
return nil, fmt.Errorf("invalid float value %s", val)
return nil, fmt.Errorf("invalid float value %s", vals)
}

field.SetFloat(floatVal)
out[key] = floatVal
case reflect.Int:
intVal, err := strconv.ParseInt(val, 10, 0)
intVal, err := strconv.ParseInt(vals[0], 10, 0)
if err != nil {
return nil, fmt.Errorf("invalid int value %s", val)
return nil, fmt.Errorf("invalid int value %s", vals)
}

field.SetInt(intVal)
out[key] = intVal
case reflect.Bool:
boolVal, err := strconv.ParseBool(val)
boolVal, err := strconv.ParseBool(vals[0])
if err != nil {
return nil, fmt.Errorf("invalid bool value %s", val)
return nil, fmt.Errorf("invalid bool value %s", vals)
}

field.SetBool(boolVal)
out[key] = boolVal
case reflect.String:
field.SetString(val)
out[key] = vals[0]
case reflect.Slice:
// TODO: only string slices are supported right now
out[key] = vals
default:
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
}
@@ -451,7 +495,7 @@ func paramsToReader(params map[string]string) (io.ReadSeeker, error) {
}
}

bts, err := json.Marshal(opts)
bts, err := json.Marshal(out)
if err != nil {
return nil, err
}
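`paramsToReader` coerces the raw string arguments into the types declared on `api.Options` by walking the struct's `json` tags with reflection, and now collects only the explicitly-set keys into `out` rather than serializing a whole defaults struct. A hedged, self-contained sketch of the same pattern, with a cut-down stand-in for the options struct:

```go
package main

import (
	"fmt"
	"reflect"
	"strconv"
	"strings"
)

// Options is a cut-down stand-in for api.Options; these fields are
// assumptions chosen for illustration.
type Options struct {
	Temperature float32  `json:"temperature"`
	NumCtx      int      `json:"num_ctx"`
	Stop        []string `json:"stop"`
}

// coerce converts raw string values into the types declared on Options,
// keyed by json tag, mirroring the reflection walk in paramsToReader.
func coerce(params map[string][]string) (map[string]interface{}, error) {
	byTag := make(map[string]reflect.StructField)
	for _, f := range reflect.VisibleFields(reflect.TypeOf(Options{})) {
		tag := strings.Split(f.Tag.Get("json"), ",")[0]
		byTag[tag] = f
	}

	out := make(map[string]interface{})
	for key, vals := range params {
		f, ok := byTag[key]
		if !ok {
			return nil, fmt.Errorf("unknown parameter %q", key)
		}
		switch f.Type.Kind() {
		case reflect.Float32:
			v, err := strconv.ParseFloat(vals[0], 32)
			if err != nil {
				return nil, err
			}
			out[key] = v
		case reflect.Int:
			v, err := strconv.ParseInt(vals[0], 10, 0)
			if err != nil {
				return nil, err
			}
			out[key] = v
		case reflect.Slice:
			out[key] = vals // multi-valued params pass through as a list
		default:
			return nil, fmt.Errorf("unsupported kind %s", f.Type.Kind())
		}
	}
	return out, nil
}

func main() {
	out, err := coerce(map[string][]string{
		"temperature": {"0.8"},
		"stop":        {"<|user|>", "<|assistant|>"},
	})
	fmt.Println(out, err)
}
```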
@@ -487,7 +531,117 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
return layer, nil
}

func PushModel(name, username, password string, fn func(api.ProgressResponse)) error {
func CopyModel(src, dest string) error {
srcPath, err := ParseModelPath(src).GetManifestPath(false)
if err != nil {
return err
}
destPath, err := ParseModelPath(dest).GetManifestPath(true)
if err != nil {
return err
}

// copy the file
input, err := os.ReadFile(srcPath)
if err != nil {
fmt.Println("Error reading file:", err)
return err
}

err = os.WriteFile(destPath, input, 0o644)
if err != nil {
fmt.Println("Error writing file:", err)
return err
}

return nil
}

func DeleteModel(name string) error {
mp := ParseModelPath(name)

manifest, err := GetManifest(mp)
if err != nil {
return err
}
deleteMap := make(map[string]bool)
for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = true
}
deleteMap[manifest.Config.Digest] = true

fp, err := GetManifestPath()
if err != nil {
return err
}
err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() {
path := path[len(fp)+1:]
slashIndex := strings.LastIndex(path, "/")
if slashIndex == -1 {
return nil
}
tag := path[:slashIndex] + ":" + path[slashIndex+1:]
fmp := ParseModelPath(tag)

// skip the manifest we're trying to delete
if mp.GetFullTagname() == fmp.GetFullTagname() {
return nil
}

// save (i.e. delete from the deleteMap) any files used in other manifests
manifest, err := GetManifest(fmp)
if err != nil {
log.Printf("skipping file: %s", fp)
return nil
}
for _, layer := range manifest.Layers {
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
}
return nil
})
if err != nil {
return err
}

// only delete the files which are still in the deleteMap
for k, v := range deleteMap {
if v {
fp, err := GetBlobsPath(k)
if err != nil {
log.Printf("couldn't get file path for '%s': %v", k, err)
continue
}
if err := os.Remove(fp); err != nil {
log.Printf("couldn't remove file '%s': %v", fp, err)
continue
}
}
}

fp, err = mp.GetManifestPath(false)
if err != nil {
return err
}
err = os.Remove(fp)
if err != nil {
log.Printf("couldn't remove manifest file '%s': %v", fp, err)
return err
}

return nil
}
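`DeleteModel` is effectively a mark-and-sweep over content-addressed blobs: mark every digest the doomed model references, then sweep all other manifests and unmark anything still shared, so only orphaned blobs get removed. A compact sketch of that bookkeeping:

```go
package main

import "fmt"

func main() {
	// Digests referenced by the model being deleted (assumed values).
	deleteMap := map[string]bool{"sha256:aaa": true, "sha256:bbb": true}

	// Digests referenced by every other manifest on disk (assumed inputs).
	otherManifests := [][]string{{"sha256:bbb"}, {"sha256:ccc"}}

	// Sweep: anything another manifest still uses must not be deleted.
	for _, layers := range otherManifests {
		for _, digest := range layers {
			delete(deleteMap, digest)
		}
	}

	fmt.Println(deleteMap) // only sha256:aaa remains deletable
}
```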

func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)

fn(api.ProgressResponse{Status: "retrieving manifest"})
@@ -499,65 +653,47 @@ func PushModel(name, username, password string, fn func(api.ProgressResponse)) e
}

var layers []*Layer
var total int
var completed int
for _, layer := range manifest.Layers {
layers = append(layers, layer)
total += layer.Size
}
layers = append(layers, manifest.Layers...)
layers = append(layers, &manifest.Config)
total += manifest.Config.Size

for _, layer := range layers {
exists, err := checkBlobExistence(mp, layer.Digest, username, password)
exists, err := checkBlobExistence(mp, layer.Digest, regOpts)
if err != nil {
return err
}

if exists {
completed += layer.Size
fn(api.ProgressResponse{
Status: "using existing layer",
Digest: layer.Digest,
Total: total,
Completed: completed,
Total: layer.Size,
Completed: layer.Size,
})
log.Printf("Layer %s already exists", layer.Digest)
continue
}

fn(api.ProgressResponse{
Status: "starting upload",
Digest: layer.Digest,
Total: total,
Completed: completed,
Status: "starting upload",
Digest: layer.Digest,
Total: layer.Size,
})

location, err := startUpload(mp, username, password)
location, err := startUpload(mp, regOpts)
if err != nil {
log.Printf("couldn't start upload: %v", err)
return err
}

err = uploadBlob(location, layer, username, password)
err = uploadBlobChunked(mp, location, layer, regOpts, fn)
if err != nil {
log.Printf("error uploading blob: %v", err)
return err
}
completed += layer.Size
fn(api.ProgressResponse{
Status: "upload complete",
Digest: layer.Digest,
Total: total,
Completed: completed,
})
}

fn(api.ProgressResponse{
Status: "pushing manifest",
Total: total,
Completed: completed,
})
url := fmt.Sprintf("%s://%s/v2/%s/manifests/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
fn(api.ProgressResponse{Status: "pushing manifest"})
url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
headers := map[string]string{
"Content-Type": "application/vnd.docker.distribution.manifest.v2+json",
}
@@ -567,7 +703,7 @@ func PushModel(name, username, password string, fn func(api.ProgressResponse)) e
return err
}

resp, err := makeRequest("PUT", url, headers, bytes.NewReader(manifestJSON), username, password)
resp, err := makeRequest("PUT", url, headers, bytes.NewReader(manifestJSON), regOpts)
if err != nil {
return err
}
@@ -576,26 +712,22 @@ func PushModel(name, username, password string, fn func(api.ProgressResponse)) e
// Check for success: For a successful upload, the Docker registry will respond with a 201 Created
if resp.StatusCode != http.StatusCreated {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
return fmt.Errorf("on push registry responded with code %d: %v", resp.StatusCode, string(body))
}

fn(api.ProgressResponse{
Status: "success",
Total: total,
Completed: completed,
})
fn(api.ProgressResponse{Status: "success"})

return nil
}
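Progress reporting changes shape here: instead of one aggregate `total`/`completed` for the whole push, each `ProgressResponse` now carries the current layer's own `Total` and `Completed`. A sketch of how a consumer might render that — the struct is copied locally for illustration, with only the fields this diff shows:

```go
package main

import "fmt"

// ProgressResponse mirrors the fields visible in the diff (Status, Digest,
// Total, Completed); this standalone copy is an assumption for illustration.
type ProgressResponse struct {
	Status    string
	Digest    string
	Total     int
	Completed int
}

func main() {
	// With per-layer totals, percentage is computed per digest rather than
	// against a single aggregate byte count for the whole push.
	updates := []ProgressResponse{
		{"uploading sha256:aaa", "sha256:aaa", 400, 100},
		{"uploading sha256:aaa", "sha256:aaa", 400, 400},
	}
	for _, u := range updates {
		fmt.Printf("%s: %d%%\n", u.Digest, 100*u.Completed/u.Total)
	}
}
```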

func PullModel(name, username, password string, fn func(api.ProgressResponse)) error {
func PullModel(name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)

fn(api.ProgressResponse{Status: "pulling manifest"})

manifest, err := pullModelManifest(mp, username, password)
manifest, err := pullModelManifest(mp, regOpts)
if err != nil {
return fmt.Errorf("pull model manifest: %q", err)
return fmt.Errorf("pull model manifest: %s", err)
}

var layers []*Layer
@@ -603,7 +735,7 @@ func PullModel(name, username, password string, fn func(api.ProgressResponse)) e
layers = append(layers, &manifest.Config)

for _, layer := range layers {
if err := downloadBlob(mp, layer.Digest, username, password, fn); err != nil {
if err := downloadBlob(mp, layer.Digest, regOpts, fn); err != nil {
return err
}
}
@@ -611,6 +743,17 @@ func PullModel(name, username, password string, fn func(api.ProgressResponse)) e
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers {
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(fp); err != nil {
// log this, but return the original error
log.Printf("couldn't remove file with digest mismatch '%s': %v", fp, err)
}
}
return err
}
}
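Pull now deletes a corrupt blob and relies on `errors.Is` to recognize the sentinel `errDigestMismatch` defined further down, which `verifyBlob` wraps with `%w`. A minimal sketch of that wrap-and-match pattern:

```go
package main

import (
	"errors"
	"fmt"
)

// Sentinel error, matching the errDigestMismatch pattern later in the diff.
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")

func verify(want, got string) error {
	if want != got {
		// %w wraps the sentinel so callers can match it with errors.Is
		// while still seeing the specific digests in the message.
		return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, want, got)
	}
	return nil
}

func main() {
	err := verify("sha256:aaa", "sha256:bbb")
	fmt.Println(errors.Is(err, errDigestMismatch)) // true: caller can delete and re-download
}
```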
@@ -638,13 +781,13 @@ func PullModel(name, username, password string, fn func(api.ProgressResponse)) e
return nil
}

func pullModelManifest(mp ModelPath, username, password string) (*ManifestV2, error) {
url := fmt.Sprintf("%s://%s/v2/%s/manifests/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
func pullModelManifest(mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
headers := map[string]string{
"Accept": "application/vnd.docker.distribution.manifest.v2+json",
}

resp, err := makeRequest("GET", url, headers, nil, username, password)
resp, err := makeRequest("GET", url, headers, nil, regOpts)
if err != nil {
log.Printf("couldn't get manifest: %v", err)
return nil, err
@@ -653,8 +796,11 @@ func pullModelManifest(mp ModelPath, username, password string) (*ManifestV2, er

// Check for success: For a successful upload, the Docker registry will respond with a 201 Created
if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusNotFound {
return nil, fmt.Errorf("model not found")
}
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("registry responded with code %d: %s", resp.StatusCode, body)
return nil, fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
}

var m *ManifestV2
@@ -705,10 +851,10 @@ func GetSHA256Digest(r io.Reader) (string, int) {
return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n)
}

func startUpload(mp ModelPath, username string, password string) (string, error) {
url := fmt.Sprintf("%s://%s/v2/%s/blobs/uploads/", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository())
func startUpload(mp ModelPath, regOpts *RegistryOptions) (string, error) {
url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository())

resp, err := makeRequest("POST", url, nil, nil, username, password)
resp, err := makeRequest("POST", url, nil, nil, regOpts)
if err != nil {
log.Printf("couldn't start upload: %v", err)
return "", err
@@ -718,7 +864,7 @@ func startUpload(mp ModelPath, username string, password string) (string, error)
// Check for success
if resp.StatusCode != http.StatusAccepted {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("registry responded with code %d: %s", resp.StatusCode, body)
return "", fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
}

// Extract UUID location from header
@@ -731,10 +877,10 @@ func startUpload(mp ModelPath, username string, password string) (string, error)
}

// Function to check if a blob already exists in the Docker registry
func checkBlobExistence(mp ModelPath, digest string, username string, password string) (bool, error) {
url := fmt.Sprintf("%s://%s/v2/%s/blobs/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), digest)
func checkBlobExistence(mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest)

resp, err := makeRequest("HEAD", url, nil, nil, username, password)
resp, err := makeRequest("HEAD", url, nil, nil, regOpts)
if err != nil {
log.Printf("couldn't check for blob: %v", err)
return false, err
@@ -745,15 +891,7 @@ func checkBlobExistence(mp ModelPath, digest string, username string, password s
return resp.StatusCode == http.StatusOK, nil
}

func uploadBlob(location string, layer *Layer, username string, password string) error {
// Create URL
url := fmt.Sprintf("%s&digest=%s", location, layer.Digest)

headers := make(map[string]string)
headers["Content-Length"] = fmt.Sprintf("%d", layer.Size)
headers["Content-Type"] = "application/octet-stream"

// TODO change from monolithic uploads to chunked uploads
func uploadBlobChunked(mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
// TODO allow resumability
// TODO allow canceling uploads via DELETE
// TODO allow cross repo blob mount
@@ -768,23 +906,63 @@ func uploadBlob(location string, layer *Layer, username string, password string)
return err
}

resp, err := makeRequest("PUT", url, headers, f, username, password)
totalUploaded := 0

r, w := io.Pipe()
defer r.Close()

go func() {
defer w.Close()
for {
n, err := io.CopyN(w, f, 1024*1024)
if err != nil && !errors.Is(err, io.EOF) {
fn(api.ProgressResponse{
Status: fmt.Sprintf("error copying pipe: %v", err),
Digest: layer.Digest,
Total: layer.Size,
Completed: totalUploaded,
})
return
}

totalUploaded += int(n)

fn(api.ProgressResponse{
Status: fmt.Sprintf("uploading %s", layer.Digest),
Digest: layer.Digest,
Total: layer.Size,
Completed: totalUploaded,
})

if totalUploaded >= layer.Size {
return
}
}
}()

url = fmt.Sprintf("%s&digest=%s", url, layer.Digest)

headers := make(map[string]string)
headers["Content-Type"] = "application/octet-stream"
headers["Content-Range"] = fmt.Sprintf("0-%d", layer.Size-1)
headers["Content-Length"] = strconv.Itoa(int(layer.Size))

// finish the upload
resp, err := makeRequest("PUT", url, headers, r, regOpts)
if err != nil {
log.Printf("couldn't upload blob: %v", err)
log.Printf("couldn't finish upload: %v", err)
return err
}
defer resp.Body.Close()

// Check for success: For a successful upload, the Docker registry will respond with a 201 Created
if resp.StatusCode != http.StatusCreated {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
}

return nil
}
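`uploadBlobChunked` threads the layer file through an `io.Pipe` so a producer goroutine can count bytes and fire progress callbacks while the HTTP request drains the other end of the pipe. A self-contained sketch of that pattern — the sizes, data, and print statements are stand-ins for the real file and callbacks:

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

func main() {
	src := strings.NewReader(strings.Repeat("x", 3000)) // stand-in for the layer file
	total := 3000

	r, w := io.Pipe()

	// Producer: copy fixed-size chunks into the pipe, reporting after each.
	go func() {
		defer w.Close()
		sent := 0
		for sent < total {
			n, err := io.CopyN(w, src, 1024)
			sent += int(n)
			fmt.Printf("uploaded %d/%d\n", sent, total)
			if err != nil {
				return // io.EOF ends the copy loop
			}
		}
	}()

	// Consumer: stands in for the HTTP client reading the request body.
	n, _ := io.Copy(io.Discard, r)
	fmt.Println("request body saw", n, "bytes")
}
```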

func downloadBlob(mp ModelPath, digest string, username, password string, fn func(api.ProgressResponse)) error {
func downloadBlob(mp ModelPath, digest string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
fp, err := GetBlobsPath(digest)
if err != nil {
return err
@@ -802,6 +980,7 @@ func downloadBlob(mp ModelPath, digest string, username, password string, fn fun
}

var size int64
chunkSize := 1024 * 1024 // 1 MiB in bytes

fi, err := os.Stat(fp + "-partial")
switch {
@@ -811,14 +990,21 @@ func downloadBlob(mp ModelPath, digest string, username, password string, fn fun
return fmt.Errorf("stat: %w", err)
default:
size = fi.Size()
// Ensure the size is divisible by the chunk size by removing excess bytes
size -= size % int64(chunkSize)

err := os.Truncate(fp+"-partial", size)
if err != nil {
return fmt.Errorf("truncate: %w", err)
}
}

url := fmt.Sprintf("%s://%s/v2/%s/blobs/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), digest)
url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest)
headers := map[string]string{
"Range": fmt.Sprintf("bytes=%d-", size),
}

resp, err := makeRequest("GET", url, headers, nil, username, password)
resp, err := makeRequest("GET", url, headers, nil, regOpts)
if err != nil {
log.Printf("couldn't download blob: %v", err)
return err
@@ -826,8 +1012,8 @@ func downloadBlob(mp ModelPath, digest string, username, password string, fn fun
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
body, _ := ioutil.ReadAll(resp.Body)
return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("on download registry responded with code %d: %v", resp.StatusCode, string(body))
}

err = os.MkdirAll(path.Dir(fp), 0o700)
@@ -837,7 +1023,7 @@ func downloadBlob(mp ModelPath, digest string, username, password string, fn fun

out, err := os.OpenFile(fp+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
if err != nil {
panic(err)
return fmt.Errorf("open file: %w", err)
}
defer out.Close()

@@ -871,7 +1057,7 @@ func downloadBlob(mp ModelPath, digest string, username, password string, fn fun
break
}

n, err := io.CopyN(out, resp.Body, 8192)
n, err := io.CopyN(out, resp.Body, int64(chunkSize))
if err != nil && !errors.Is(err, io.EOF) {
return err
}
@@ -882,7 +1068,15 @@ func downloadBlob(mp ModelPath, digest string, username, password string, fn fun
return nil
}
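The download path becomes resumable: the `-partial` file is truncated down to a whole number of 1 MiB chunks, and the request continues from that offset with an HTTP `Range` header. A tiny sketch of the offset arithmetic:

```go
package main

import "fmt"

func main() {
	const chunkSize int64 = 1024 * 1024

	// Assume a previous run left a 2.5 MiB partial file on disk.
	partialSize := int64(2*1024*1024 + 512*1024)

	// Drop the trailing partial chunk so the resume point is clean,
	// then ask the registry for everything from that byte onward.
	resumeFrom := partialSize - partialSize%chunkSize
	fmt.Printf("truncate to %d bytes\n", resumeFrom)
	fmt.Printf("Range: bytes=%d-\n", resumeFrom) // Range: bytes=2097152-
}
```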

func makeRequest(method, url string, headers map[string]string, body io.Reader, username, password string) (*http.Response, error) {
func makeRequest(method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
if !strings.HasPrefix(url, "http") {
if regOpts.Insecure {
url = "http://" + url
} else {
url = "https://" + url
}
}

req, err := http.NewRequest(method, url, body)
if err != nil {
return nil, err
@@ -893,8 +1087,8 @@ func makeRequest(method, url string, headers map[string]string, body io.Reader,
}

// TODO: better auth
if username != "" && password != "" {
req.SetBasicAuth(username, password)
if regOpts.Username != "" && regOpts.Password != "" {
req.SetBasicAuth(regOpts.Username, regOpts.Password)
}

client := &http.Client{
@@ -914,6 +1108,8 @@ func makeRequest(method, url string, headers map[string]string, body io.Reader,
return resp, nil
}

var errDigestMismatch = fmt.Errorf("digest mismatch, file must be downloaded again")

func verifyBlob(digest string) error {
fp, err := GetBlobsPath(digest)
if err != nil {
@@ -928,7 +1124,7 @@ func verifyBlob(digest string) error {

fileDigest, _ := GetSHA256Digest(f)
if digest != fileDigest {
return fmt.Errorf("digest mismatch: want %s, got %s", digest, fileDigest)
return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest)
}

return nil

@@ -45,7 +45,7 @@ func ParseModelPath(name string) ModelPath {
return ModelPath{}
}

colonParts := strings.Split(name, ":")
colonParts := strings.Split(slashParts[len(slashParts)-1], ":")
if len(colonParts) == 2 {
tag = colonParts[1]
} else {
@@ -70,10 +70,13 @@ func (mp ModelPath) GetFullTagname() string {
}

func (mp ModelPath) GetShortTagname() string {
if mp.Registry == DefaultRegistry && mp.Namespace == DefaultNamespace {
return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
if mp.Registry == DefaultRegistry {
if mp.Namespace == DefaultNamespace {
return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}
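`GetShortTagname` now has three tiers: default registry and namespace, default registry only, and fully qualified. A runnable sketch of that logic in isolation — the default registry and namespace values here are assumptions for illustration:

```go
package main

import "fmt"

// Minimal stand-ins so the example runs on its own; the values are assumed.
const (
	DefaultRegistry  = "registry.ollama.ai"
	DefaultNamespace = "library"
)

type ModelPath struct {
	Registry, Namespace, Repository, Tag string
}

// ShortTagname drops whatever matches the defaults, mirroring the
// three-tier logic in the rewritten GetShortTagname above.
func ShortTagname(mp ModelPath) string {
	if mp.Registry == DefaultRegistry {
		if mp.Namespace == DefaultNamespace {
			return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
		}
		return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
	}
	return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}

func main() {
	fmt.Println(ShortTagname(ModelPath{DefaultRegistry, DefaultNamespace, "llama2", "latest"})) // llama2:latest
	fmt.Println(ShortTagname(ModelPath{DefaultRegistry, "jmorganca", "llama2", "latest"}))      // jmorganca/llama2:latest
	fmt.Println(ShortTagname(ModelPath{"example.com", "ns", "llama2", "latest"}))               // example.com/ns/llama2:latest
}
```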

func (mp ModelPath) GetManifestPath(createDir bool) (string, error) {

server/routes.go
@@ -2,24 +2,43 @@ package server

import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"path/filepath"
"reflect"
"strings"
"sync"
"time"

"dario.cat/mergo"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"

"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llama"
)

func generate(c *gin.Context) {
start := time.Now()
var loaded struct {
mu sync.Mutex

llm *llama.LLM

expireAt time.Time
expireTimer *time.Timer

digest string
options api.Options
}

func GenerateHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()

checkpointStart := time.Now()

var req api.GenerateRequest
if err := c.ShouldBindJSON(&req); err != nil {
@@ -34,43 +53,85 @@ func generate(c *gin.Context) {
}

opts := api.DefaultOptions()
if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
if err := opts.FromMap(model.Options); err != nil {
log.Printf("could not load model options: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
if err := opts.FromMap(req.Options); err != nil {
log.Printf("could not merge model options: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

if model.Digest != loaded.digest || !reflect.DeepEqual(loaded.options, opts) {
if loaded.llm != nil {
loaded.llm.Close()
loaded.llm = nil
loaded.digest = ""
}

llm, err := llama.New(model.ModelPath, opts)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

loaded.llm = llm
loaded.digest = model.Digest
loaded.options = opts
}

sessionDuration := 5 * time.Minute

loaded.expireAt = time.Now().Add(sessionDuration)
if loaded.expireTimer == nil {
loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
loaded.mu.Lock()
defer loaded.mu.Unlock()

if time.Now().Before(loaded.expireAt) {
return
}

if loaded.llm == nil {
return
}

loaded.llm.Close()
loaded.llm = nil
loaded.digest = ""
})
}
loaded.expireTimer.Reset(sessionDuration)

checkpointLoaded := time.Now()

prompt, err := model.Prompt(req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

llm, err := llama.New(model.ModelPath, opts)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
defer llm.Close()

ch := make(chan any)
go func() {
defer close(ch)
fn := func(r api.GenerateResponse) {
loaded.expireAt = time.Now().Add(sessionDuration)
loaded.expireTimer.Reset(sessionDuration)

r.Model = req.Model
r.CreatedAt = time.Now().UTC()
if r.Done {
r.TotalDuration = time.Since(start)
r.TotalDuration = time.Since(checkpointStart)
r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}

ch <- r
}

if err := llm.Predict(req.Context, prompt, fn); err != nil {
if err := loaded.llm.Predict(req.Context, prompt, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@@ -78,7 +139,7 @@ func generate(c *gin.Context) {
streamResponse(c, ch)
}
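The server now keeps at most one model resident, guarded by `loaded.mu`, and unloads it after five idle minutes via a resettable `time.AfterFunc` timer; each request (and each generated token) pushes `expireAt` forward. A hedged sketch of the idle-expiry pattern in isolation:

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// idleCloser frees a resource after a quiet period, in the spirit of the
// loaded-model cache above; "unloading" here just prints for demonstration.
type idleCloser struct {
	mu       sync.Mutex
	loaded   bool
	expireAt time.Time
	timer    *time.Timer
}

func (c *idleCloser) Touch(ttl time.Duration) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.loaded = true
	c.expireAt = time.Now().Add(ttl)
	if c.timer == nil {
		c.timer = time.AfterFunc(ttl, func() {
			c.mu.Lock()
			defer c.mu.Unlock()
			// A later request may have pushed the deadline forward.
			if time.Now().Before(c.expireAt) || !c.loaded {
				return
			}
			c.loaded = false
			fmt.Println("model unloaded after idle timeout")
		})
	}
	c.timer.Reset(ttl)
}

func main() {
	var c idleCloser
	c.Touch(50 * time.Millisecond) // each request extends the deadline
	time.Sleep(100 * time.Millisecond)
}
```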

func pull(c *gin.Context) {
func PullModelHandler(c *gin.Context) {
var req api.PullRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -92,7 +153,13 @@ func pull(c *gin.Context) {
ch <- r
}

if err := PullModel(req.Name, req.Username, req.Password, fn); err != nil {
regOpts := &RegistryOptions{
Insecure: req.Insecure,
Username: req.Username,
Password: req.Password,
}

if err := PullModel(req.Name, regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@@ -100,7 +167,7 @@ func pull(c *gin.Context) {
streamResponse(c, ch)
}

func push(c *gin.Context) {
func PushModelHandler(c *gin.Context) {
var req api.PushRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -114,7 +181,13 @@ func push(c *gin.Context) {
ch <- r
}

if err := PushModel(req.Name, req.Username, req.Password, fn); err != nil {
regOpts := &RegistryOptions{
Insecure: req.Insecure,
Username: req.Username,
Password: req.Password,
}

if err := PushModel(req.Name, regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@@ -122,7 +195,7 @@ func push(c *gin.Context) {
streamResponse(c, ch)
}

func create(c *gin.Context) {
func CreateModelHandler(c *gin.Context) {
var req api.CreateRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
@@ -132,10 +205,8 @@ func create(c *gin.Context) {
ch := make(chan any)
go func() {
defer close(ch)
fn := func(status string) {
ch <- api.CreateProgress{
Status: status,
}
fn := func(resp api.ProgressResponse) {
ch <- resp
}

if err := CreateModel(req.Name, req.Path, fn); err != nil {
@@ -146,7 +217,24 @@ func create(c *gin.Context) {
streamResponse(c, ch)
}

func list(c *gin.Context) {
func DeleteModelHandler(c *gin.Context) {
var req api.DeleteRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

if err := DeleteModel(req.Name); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Name)})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
}

func ListModelsHandler(c *gin.Context) {
var models []api.ListResponseModel
fp, err := GetManifestPath()
if err != nil {
@@ -155,6 +243,10 @@ func list(c *gin.Context) {
}
err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
if err != nil {
if errors.Is(err, os.ErrNotExist) {
log.Printf("manifest file does not exist: %s", fp)
return nil
}
return err
}
if !info.IsDir() {
@@ -189,21 +281,58 @@ func list(c *gin.Context) {
return
}

c.JSON(http.StatusOK, api.ListResponse{models})
c.JSON(http.StatusOK, api.ListResponse{Models: models})
}

func CopyModelHandler(c *gin.Context) {
var req api.CopyRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

if err := CopyModel(req.Source, req.Destination); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
}

func Serve(ln net.Listener) error {
config := cors.DefaultConfig()
config.AllowWildcard = true
// only allow http/https from localhost
config.AllowOrigins = []string{
"http://localhost",
"http://localhost:*",
"https://localhost",
"https://localhost:*",
"http://127.0.0.1",
"http://127.0.0.1:*",
"https://127.0.0.1",
"https://127.0.0.1:*",
}

r := gin.Default()
r.Use(cors.New(config))

r.GET("/", func(c *gin.Context) {
c.String(http.StatusOK, "Ollama is running")
})
r.HEAD("/", func(c *gin.Context) {
c.Status(http.StatusOK)
})

r.POST("/api/pull", pull)
r.POST("/api/generate", generate)
r.POST("/api/create", create)
r.POST("/api/push", push)
r.GET("/api/tags", list)
r.POST("/api/pull", PullModelHandler)
r.POST("/api/generate", GenerateHandler)
r.POST("/api/create", CreateModelHandler)
r.POST("/api/push", PushModelHandler)
r.POST("/api/copy", CopyModelHandler)
r.GET("/api/tags", ListModelsHandler)
r.DELETE("/api/delete", DeleteModelHandler)

log.Printf("Listening on %s", ln.Addr())
s := &http.Server{
@@ -222,11 +351,13 @@ func streamResponse(c *gin.Context, ch chan any) {

bts, err := json.Marshal(val)
if err != nil {
log.Printf("streamResponse: json.Marshal failed with %s", err)
return false
}

bts = append(bts, '\n')
if _, err := w.Write(bts); err != nil {
log.Printf("streamResponse: w.Write failed with %s", err)
return false
}

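`streamResponse` emits newline-delimited JSON, one object per line, so clients can decode results incrementally as tokens arrive. A sketch of a minimal Go client for the generate endpoint — the port, path, and field names below are assumptions inferred from this diff, not a documented API:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// POST a generate request; the body fields follow api.GenerateRequest
	// as seen in this diff (Model, Prompt); other fields are omitted.
	body := bytes.NewBufferString(`{"model": "llama2", "prompt": "hello"}`)
	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", body)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Each line is a complete JSON object; json.Decoder consumes them
	// one at a time until the stream ends with io.EOF.
	dec := json.NewDecoder(resp.Body)
	for {
		var chunk map[string]any
		if err := dec.Decode(&chunk); err != nil {
			break
		}
		fmt.Println(chunk["response"])
	}
}
```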
@@ -14,11 +14,12 @@ export async function GET(req: Request) {
const res = await fetch('https://api.github.com/repos/jmorganca/ollama/releases', { next: { revalidate: 60 } })
const data = await res.json()

if (data.length === 0) {
const latest = data?.filter((f: any) => !f.prerelease)?.[0]

if (!latest) {
return new Response('not found', { status: 404 })
}

const latest = data[0]
const assets = latest.assets || []

if (assets.length === 0) {

@@ -1,8 +1,8 @@
import Link from "next/link"

const navigation = [
{ name: 'Discord', href: 'https://discord.gg/MrfB5FbNWN' },
{ name: 'Github', href: 'https://github.com/jmorganca/ollama' },
{ name: 'Discord', href: 'https://discord.com/invite/ollama' },
{ name: 'GitHub', href: 'https://github.com/jmorganca/ollama' },
{ name: 'Download', href: '/download' },
]

@@ -21,6 +21,6 @@ export default function Header() {
))}
</div>
</nav>
</header >
</header>
)
}
}

@@ -11,18 +11,23 @@ export default async function Home() {
<Image src='/ollama.png' width={64} height={64} alt='ollamaIcon' />
<section className='my-12 text-center'>
<div className='flex flex-col space-y-2'>
<h2 className='md:max-w-[18rem] mx-auto my-2 text-3xl tracking-tight'>Portable large language models</h2>
<h2 className='md:max-w-md mx-auto my-2 text-3xl tracking-tight'>
Get up and running with large language models, locally.
</h2>
<h3 className='md:max-w-xs mx-auto text-base text-neutral-500'>
Bundle a model’s weights, configuration, prompts, data and more into self-contained packages that run anywhere.
Run Llama 2 and other models on macOS. Customize and create your own.
</h3>
</div>
<div className='mx-auto flex flex-col space-y-4 mt-12'>
<Link href='/download' className='md:mx-10 lg:mx-14 bg-black text-white rounded-full px-4 py-2 focus:outline-none cursor-pointer'>
<div className='mx-auto max-w-xs flex flex-col space-y-4 mt-12'>
<Link
href='/download'
className='md:mx-10 lg:mx-14 bg-black text-white rounded-full px-4 py-2 focus:outline-none cursor-pointer'
>
Download
</Link>
<p className='text-neutral-500 text-sm '>
Available for macOS with Apple Silicon <br />
Windows & Linux support coming soon.
Available for macOS with Apple Silicon <br />
Windows & Linux support coming soon.
</p>
</div>
</section>