Compare commits
74 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
7db5bcf73b | ||
![]() |
fa2f095bd9 | ||
![]() |
045b855db9 | ||
![]() |
32064a0646 | ||
![]() |
d9a250e9b5 | ||
![]() |
944519ed16 | ||
![]() |
2dd040d04c | ||
![]() |
bbe41ce41a | ||
![]() |
9e1406e4ed | ||
![]() |
b74580c913 | ||
![]() |
7e9405fd07 | ||
![]() |
3b0b8930d4 | ||
![]() |
e3f925fc1b | ||
![]() |
2a2289fb6b | ||
![]() |
dd427f499a | ||
![]() |
2ae573c7ed | ||
![]() |
02fe26c44b | ||
![]() |
16c7548460 | ||
![]() |
fa75998c0d | ||
![]() |
5344f886c8 | ||
![]() |
6cc823c9b5 | ||
![]() |
b84d34e632 | ||
![]() |
30229a913c | ||
![]() |
1ade380bd7 | ||
![]() |
ba264e9da8 | ||
![]() |
a2405ec831 | ||
![]() |
ce809bb529 | ||
![]() |
76bc4d0458 | ||
![]() |
4a02945a15 | ||
![]() |
aec742b6d2 | ||
![]() |
f337642e94 | ||
![]() |
51131cc6e2 | ||
![]() |
43027789dc | ||
![]() |
f9b7d65e2b | ||
![]() |
1f05d77110 | ||
![]() |
c3ff36088b | ||
![]() |
13524b5e72 | ||
![]() |
f1b049fed8 | ||
![]() |
97c5696945 | ||
![]() |
47d4e22673 | ||
![]() |
32f62fbb8e | ||
![]() |
5d75505ebd | ||
![]() |
b9495ea162 | ||
![]() |
409bb9674e | ||
![]() |
d3479c07a1 | ||
![]() |
b12f1b984f | ||
![]() |
195e3d9dbd | ||
![]() |
38fe1a368b | ||
![]() |
4b77fcb2b9 | ||
![]() |
cde13bcdea | ||
![]() |
0f0cd265a7 | ||
![]() |
0db4706ec2 | ||
![]() |
1ebdbd9694 | ||
![]() |
5c59455b59 | ||
![]() |
00d06619a1 | ||
![]() |
f1ef3f9947 | ||
![]() |
5a5dca13b2 | ||
![]() |
7232f1fa41 | ||
![]() |
72e7a49aa9 | ||
![]() |
a3737cbd33 | ||
![]() |
998f1785b6 | ||
![]() |
70a93057cd | ||
![]() |
2cb0fa7d40 | ||
![]() |
b2816bca67 | ||
![]() |
bf704423c5 | ||
![]() |
7a0899d62d | ||
![]() |
0cca1486dd | ||
![]() |
2113c9d31a | ||
![]() |
6deebf2489 | ||
![]() |
95cb38ae47 | ||
![]() |
1f126afb2d | ||
![]() |
f6201a7a6c | ||
![]() |
b3f6c6598f | ||
![]() |
88620e983a |
16
README.md
16
README.md
@@ -214,14 +214,21 @@ curl http://localhost:11434/api/generate -d '{
|
||||
}'
|
||||
```
|
||||
|
||||
Or send a chat message (coming in 0.1.14):
|
||||
|
||||
```
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "mistral",
|
||||
"messages": [
|
||||
{ "role": "user", "content": "why is the sky blue?" }
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
See the [API documentation](./docs/api.md) for all endpoints.
|
||||
|
||||
## Community Integrations
|
||||
|
||||
### Mobile
|
||||
|
||||
- [Mobile Artificial Intelligence Distribution](https://github.com/MaidFoundation/Maid) (Maid)
|
||||
|
||||
### Web & Desktop
|
||||
|
||||
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
|
||||
@@ -277,6 +284,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Logseq Ollama plugin](https://github.com/omagdy7/ollama-logseq)
|
||||
- [Dagger Chatbot](https://github.com/samalba/dagger-chatbot)
|
||||
- [Discord AI Bot](https://github.com/mekb-turtle/discord-ai-bot)
|
||||
- [Ollama Telegram Bot](https://github.com/ruecat/ollama-telegram)
|
||||
- [Hass Ollama Conversation](https://github.com/ej52/hass-ollama-conversation)
|
||||
- [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama)
|
||||
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
|
||||
|
@@ -221,6 +221,19 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate
|
||||
})
|
||||
}
|
||||
|
||||
type ChatResponseFunc func(ChatResponse) error
|
||||
|
||||
func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
|
||||
return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error {
|
||||
var resp ChatResponse
|
||||
if err := json.Unmarshal(bts, &resp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return fn(resp)
|
||||
})
|
||||
}
|
||||
|
||||
type PullProgressFunc func(ProgressResponse) error
|
||||
|
||||
func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
|
||||
@@ -311,3 +324,15 @@ func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) err
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Version(ctx context.Context) (string, error) {
|
||||
var version struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
if err := c.do(ctx, http.MethodGet, "/api/version", nil, &version); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return version.Version, nil
|
||||
}
|
||||
|
70
api/types.go
70
api/types.go
@@ -44,6 +44,39 @@ type GenerateRequest struct {
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Format string `json:"format"`
|
||||
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"` // one of ["system", "user", "assistant"]
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Message *Message `json:"message,omitempty"`
|
||||
|
||||
Done bool `json:"done"`
|
||||
|
||||
Metrics
|
||||
}
|
||||
|
||||
type Metrics struct {
|
||||
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
||||
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
||||
EvalCount int `json:"eval_count,omitempty"`
|
||||
EvalDuration time.Duration `json:"eval_duration,omitempty"`
|
||||
}
|
||||
|
||||
// Options specfied in GenerateRequest, if you add a new option here add it to the API docs also
|
||||
type Options struct {
|
||||
Runner
|
||||
@@ -173,39 +206,34 @@ type GenerateResponse struct {
|
||||
Done bool `json:"done"`
|
||||
Context []int `json:"context,omitempty"`
|
||||
|
||||
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
||||
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
||||
EvalCount int `json:"eval_count,omitempty"`
|
||||
EvalDuration time.Duration `json:"eval_duration,omitempty"`
|
||||
Metrics
|
||||
}
|
||||
|
||||
func (r *GenerateResponse) Summary() {
|
||||
if r.TotalDuration > 0 {
|
||||
fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration)
|
||||
func (m *Metrics) Summary() {
|
||||
if m.TotalDuration > 0 {
|
||||
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
|
||||
}
|
||||
|
||||
if r.LoadDuration > 0 {
|
||||
fmt.Fprintf(os.Stderr, "load duration: %v\n", r.LoadDuration)
|
||||
if m.LoadDuration > 0 {
|
||||
fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
|
||||
}
|
||||
|
||||
if r.PromptEvalCount > 0 {
|
||||
fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount)
|
||||
if m.PromptEvalCount > 0 {
|
||||
fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", m.PromptEvalCount)
|
||||
}
|
||||
|
||||
if r.PromptEvalDuration > 0 {
|
||||
fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", r.PromptEvalDuration)
|
||||
fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(r.PromptEvalCount)/r.PromptEvalDuration.Seconds())
|
||||
if m.PromptEvalDuration > 0 {
|
||||
fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", m.PromptEvalDuration)
|
||||
fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(m.PromptEvalCount)/m.PromptEvalDuration.Seconds())
|
||||
}
|
||||
|
||||
if r.EvalCount > 0 {
|
||||
fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", r.EvalCount)
|
||||
if m.EvalCount > 0 {
|
||||
fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", m.EvalCount)
|
||||
}
|
||||
|
||||
if r.EvalDuration > 0 {
|
||||
fmt.Fprintf(os.Stderr, "eval duration: %s\n", r.EvalDuration)
|
||||
fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(r.EvalCount)/r.EvalDuration.Seconds())
|
||||
if m.EvalDuration > 0 {
|
||||
fmt.Fprintf(os.Stderr, "eval duration: %s\n", m.EvalDuration)
|
||||
fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(m.EvalCount)/m.EvalDuration.Seconds())
|
||||
}
|
||||
}
|
||||
|
||||
|
111
cmd/cmd.go
111
cmd/cmd.go
@@ -133,7 +133,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile)}
|
||||
if err := client.Create(context.Background(), &request, fn); err != nil {
|
||||
if err := client.Create(cmd.Context(), &request, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -148,7 +148,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
name := args[0]
|
||||
// check if the model exists on the server
|
||||
_, err = client.Show(context.Background(), &api.ShowRequest{Name: name})
|
||||
_, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
|
||||
var statusError api.StatusError
|
||||
switch {
|
||||
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
|
||||
@@ -208,7 +208,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
request := api.PushRequest{Name: args[0], Insecure: insecure}
|
||||
if err := client.Push(context.Background(), &request, fn); err != nil {
|
||||
if err := client.Push(cmd.Context(), &request, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -222,7 +222,7 @@ func ListHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
models, err := client.List(context.Background())
|
||||
models, err := client.List(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -257,7 +257,7 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
for _, name := range args {
|
||||
req := api.DeleteRequest{Name: name}
|
||||
if err := client.Delete(context.Background(), &req); err != nil {
|
||||
if err := client.Delete(cmd.Context(), &req); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("deleted '%s'\n", name)
|
||||
@@ -322,7 +322,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
req := api.ShowRequest{Name: args[0]}
|
||||
resp, err := client.Show(context.Background(), &req)
|
||||
resp, err := client.Show(cmd.Context(), &req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -350,7 +350,7 @@ func CopyHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
req := api.CopyRequest{Source: args[0], Destination: args[1]}
|
||||
if err := client.Copy(context.Background(), &req); err != nil {
|
||||
if err := client.Copy(cmd.Context(), &req); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("copied '%s' to '%s'\n", args[0], args[1])
|
||||
@@ -404,7 +404,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
request := api.PullRequest{Name: args[0], Insecure: insecure}
|
||||
if err := client.Pull(context.Background(), &request, fn); err != nil {
|
||||
if err := client.Pull(cmd.Context(), &request, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -493,7 +493,7 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
|
||||
opts.WordWrap = false
|
||||
}
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
@@ -507,23 +507,22 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
|
||||
var currentLineLength int
|
||||
var wordBuffer string
|
||||
|
||||
request := api.GenerateRequest{
|
||||
Model: opts.Model,
|
||||
Prompt: opts.Prompt,
|
||||
Context: generateContext,
|
||||
Format: opts.Format,
|
||||
System: opts.System,
|
||||
Template: opts.Template,
|
||||
Options: opts.Options,
|
||||
}
|
||||
fn := func(response api.GenerateResponse) error {
|
||||
p.StopAndClear()
|
||||
|
||||
latest = response
|
||||
|
||||
if opts.WordWrap {
|
||||
termWidth, _, _ = term.GetSize(int(os.Stdout.Fd()))
|
||||
if opts.WordWrap && termWidth >= 10 {
|
||||
for _, ch := range response.Response {
|
||||
if currentLineLength+1 > termWidth-5 {
|
||||
if len(wordBuffer) > termWidth-10 {
|
||||
fmt.Printf("%s%c", wordBuffer, ch)
|
||||
wordBuffer = ""
|
||||
currentLineLength = 0
|
||||
continue
|
||||
}
|
||||
|
||||
// backtrack the length of the last word and clear to the end of the line
|
||||
fmt.Printf("\x1b[%dD\x1b[K\n", len(wordBuffer))
|
||||
fmt.Printf("%s%c", wordBuffer, ch)
|
||||
@@ -543,13 +542,26 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
fmt.Print(response.Response)
|
||||
fmt.Printf("%s%s", wordBuffer, response.Response)
|
||||
if len(wordBuffer) > 0 {
|
||||
wordBuffer = ""
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := client.Generate(cancelCtx, &request, fn); err != nil {
|
||||
request := api.GenerateRequest{
|
||||
Model: opts.Model,
|
||||
Prompt: opts.Prompt,
|
||||
Context: generateContext,
|
||||
Format: opts.Format,
|
||||
System: opts.System,
|
||||
Template: opts.Template,
|
||||
Options: opts.Options,
|
||||
}
|
||||
|
||||
if err := client.Generate(ctx, &request, fn); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
@@ -573,10 +585,7 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
|
||||
latest.Summary()
|
||||
}
|
||||
|
||||
ctx := cmd.Context()
|
||||
ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
|
||||
cmd.SetContext(ctx)
|
||||
|
||||
cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -705,11 +714,11 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
|
||||
case MultilineSystem:
|
||||
opts.System = prompt
|
||||
prompt = ""
|
||||
fmt.Println("Set system template.\n")
|
||||
fmt.Println("Set system template.")
|
||||
case MultilineTemplate:
|
||||
opts.Template = prompt
|
||||
prompt = ""
|
||||
fmt.Println("Set model template.\n")
|
||||
fmt.Println("Set model template.")
|
||||
}
|
||||
multiline = MultilineNone
|
||||
case strings.HasPrefix(line, `"""`) && len(prompt) == 0:
|
||||
@@ -784,9 +793,9 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
|
||||
if found {
|
||||
opts.System = prompt
|
||||
if args[1] == "system" {
|
||||
fmt.Println("Set system template.\n")
|
||||
fmt.Println("Set system template.")
|
||||
} else {
|
||||
fmt.Println("Set prompt template.\n")
|
||||
fmt.Println("Set prompt template.")
|
||||
}
|
||||
prompt = ""
|
||||
} else {
|
||||
@@ -800,7 +809,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
|
||||
}
|
||||
} else {
|
||||
opts.System = line
|
||||
fmt.Println("Set system template.\n")
|
||||
fmt.Println("Set system template.")
|
||||
}
|
||||
default:
|
||||
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
|
||||
@@ -977,7 +986,7 @@ func initializeKeypair() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func startMacApp(client *api.Client) error {
|
||||
func startMacApp(ctx context.Context, client *api.Client) error {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1001,24 +1010,24 @@ func startMacApp(client *api.Client) error {
|
||||
case <-timeout:
|
||||
return errors.New("timed out waiting for server to start")
|
||||
case <-tick:
|
||||
if err := client.Heartbeat(context.Background()); err == nil {
|
||||
if err := client.Heartbeat(ctx); err == nil {
|
||||
return nil // server has started
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkServerHeartbeat(_ *cobra.Command, _ []string) error {
|
||||
func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := client.Heartbeat(context.Background()); err != nil {
|
||||
if err := client.Heartbeat(cmd.Context()); err != nil {
|
||||
if !strings.Contains(err.Error(), "connection refused") {
|
||||
return err
|
||||
}
|
||||
if runtime.GOOS == "darwin" {
|
||||
if err := startMacApp(client); err != nil {
|
||||
if err := startMacApp(cmd.Context(), client); err != nil {
|
||||
return fmt.Errorf("could not connect to ollama app, is it running?")
|
||||
}
|
||||
} else {
|
||||
@@ -1028,8 +1037,29 @@ func checkServerHeartbeat(_ *cobra.Command, _ []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func versionHandler(cmd *cobra.Command, _ []string) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
serverVersion, err := client.Version(cmd.Context())
|
||||
if err != nil {
|
||||
fmt.Println("Warning: could not connect to a running Ollama instance")
|
||||
}
|
||||
|
||||
if serverVersion != "" {
|
||||
fmt.Printf("ollama version is %s\n", serverVersion)
|
||||
}
|
||||
|
||||
if serverVersion != version.Version {
|
||||
fmt.Printf("Warning: client version is %s\n", version.Version)
|
||||
}
|
||||
}
|
||||
|
||||
func NewCLI() *cobra.Command {
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
cobra.EnableCommandSorting = false
|
||||
|
||||
rootCmd := &cobra.Command{
|
||||
Use: "ollama",
|
||||
@@ -1039,10 +1069,17 @@ func NewCLI() *cobra.Command {
|
||||
CompletionOptions: cobra.CompletionOptions{
|
||||
DisableDefaultCmd: true,
|
||||
},
|
||||
Version: version.Version,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if version, _ := cmd.Flags().GetBool("version"); version {
|
||||
versionHandler(cmd, args)
|
||||
return
|
||||
}
|
||||
|
||||
cmd.Print(cmd.UsageString())
|
||||
},
|
||||
}
|
||||
|
||||
cobra.EnableCommandSorting = false
|
||||
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
|
||||
|
||||
createCmd := &cobra.Command{
|
||||
Use: "create MODEL",
|
||||
|
148
docs/api.md
148
docs/api.md
@@ -24,7 +24,7 @@ All durations are returned in nanoseconds.
|
||||
|
||||
### Streaming responses
|
||||
|
||||
Certain endpoints stream responses as JSON objects delineated with the newline (`\n`) character.
|
||||
Certain endpoints stream responses as JSON objects.
|
||||
|
||||
## Generate a completion
|
||||
|
||||
@@ -32,7 +32,7 @@ Certain endpoints stream responses as JSON objects delineated with the newline (
|
||||
POST /api/generate
|
||||
```
|
||||
|
||||
Generate a response for a given prompt with a provided model. This is a streaming endpoint, so will be a series of responses. The final response object will include statistics and additional data from the request.
|
||||
Generate a response for a given prompt with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request.
|
||||
|
||||
### Parameters
|
||||
|
||||
@@ -47,7 +47,7 @@ Advanced parameters (optional):
|
||||
- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
|
||||
- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
|
||||
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||
- `raw`: if `true` no formatting will be applied to the prompt and no context will be returned. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API, and are managing history yourself.
|
||||
- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API.
|
||||
|
||||
### JSON mode
|
||||
|
||||
@@ -114,6 +114,8 @@ To calculate how fast the response is generated in tokens per second (token/s),
|
||||
|
||||
#### Request (No streaming)
|
||||
|
||||
A response can be recieved in one reply when streaming is off.
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/generate -d '{
|
||||
"model": "llama2",
|
||||
@@ -144,9 +146,9 @@ If `stream` is set to `false`, the response will be a single JSON object:
|
||||
}
|
||||
```
|
||||
|
||||
#### Request (Raw mode)
|
||||
#### Request (Raw Mode)
|
||||
|
||||
In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting and context.
|
||||
In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting.
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/generate -d '{
|
||||
@@ -164,6 +166,7 @@ curl http://localhost:11434/api/generate -d '{
|
||||
"model": "mistral",
|
||||
"created_at": "2023-11-03T15:36:02.583064Z",
|
||||
"response": " The sky appears blue because of a phenomenon called Rayleigh scattering.",
|
||||
"context": [1, 2, 3],
|
||||
"done": true,
|
||||
"total_duration": 14648695333,
|
||||
"load_duration": 3302671417,
|
||||
@@ -249,7 +252,7 @@ curl http://localhost:11434/api/generate -d '{
|
||||
"penalize_newline": true,
|
||||
"stop": ["\n", "user:"],
|
||||
"numa": false,
|
||||
"num_ctx": 4,
|
||||
"num_ctx": 1024,
|
||||
"num_batch": 2,
|
||||
"num_gqa": 1,
|
||||
"num_gpu": 1,
|
||||
@@ -264,7 +267,7 @@ curl http://localhost:11434/api/generate -d '{
|
||||
"rope_frequency_base": 1.1,
|
||||
"rope_frequency_scale": 0.8,
|
||||
"num_thread": 8
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
@@ -275,7 +278,6 @@ curl http://localhost:11434/api/generate -d '{
|
||||
"model": "llama2",
|
||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||
"response": "The sky is blue because it is the color of the sky.",
|
||||
"context": [1, 2, 3],
|
||||
"done": true,
|
||||
"total_duration": 5589157167,
|
||||
"load_duration": 3013701500,
|
||||
@@ -288,6 +290,136 @@ curl http://localhost:11434/api/generate -d '{
|
||||
}
|
||||
```
|
||||
|
||||
## Send Chat Messages (coming in 0.1.14)
|
||||
|
||||
```shell
|
||||
POST /api/chat
|
||||
```
|
||||
|
||||
Generate the next message in a chat with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request.
|
||||
|
||||
### Parameters
|
||||
|
||||
- `model`: (required) the [model name](#model-names)
|
||||
- `messages`: the messages of the chat, this can be used to keep a chat memory
|
||||
|
||||
Advanced parameters (optional):
|
||||
|
||||
- `format`: the format to return a response in. Currently the only accepted value is `json`
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
|
||||
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||
|
||||
### Examples
|
||||
|
||||
#### Request
|
||||
|
||||
Send a chat message with a streaming response.
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "llama2",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "why is the sky blue?"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
#### Response
|
||||
|
||||
A stream of JSON objects is returned:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama2",
|
||||
"created_at": "2023-08-04T08:52:19.385406455-07:00",
|
||||
"message": {
|
||||
"role": "assisant",
|
||||
"content": "The"
|
||||
},
|
||||
"done": false
|
||||
}
|
||||
```
|
||||
|
||||
Final response:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama2",
|
||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||
"done": true,
|
||||
"total_duration": 5589157167,
|
||||
"load_duration": 3013701500,
|
||||
"sample_count": 114,
|
||||
"sample_duration": 81442000,
|
||||
"prompt_eval_count": 46,
|
||||
"prompt_eval_duration": 1160282000,
|
||||
"eval_count": 113,
|
||||
"eval_duration": 1325948000
|
||||
}
|
||||
```
|
||||
|
||||
#### Request (With History)
|
||||
|
||||
Send a chat message with a conversation history.
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "llama2",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "why is the sky blue?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "due to rayleigh scattering."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "how is that different than mie scattering?"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
#### Response
|
||||
|
||||
A stream of JSON objects is returned:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama2",
|
||||
"created_at": "2023-08-04T08:52:19.385406455-07:00",
|
||||
"message": {
|
||||
"role": "assisant",
|
||||
"content": "The"
|
||||
},
|
||||
"done": false
|
||||
}
|
||||
```
|
||||
|
||||
Final response:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama2",
|
||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||
"done": true,
|
||||
"total_duration": 5589157167,
|
||||
"load_duration": 3013701500,
|
||||
"sample_count": 114,
|
||||
"sample_duration": 81442000,
|
||||
"prompt_eval_count": 46,
|
||||
"prompt_eval_duration": 1160282000,
|
||||
"eval_count": 113,
|
||||
"eval_duration": 1325948000
|
||||
}
|
||||
```
|
||||
|
||||
## Create a Model
|
||||
|
||||
```shell
|
||||
|
@@ -43,7 +43,6 @@ Ollama supports a set of model architectures, with support for more coming soon:
|
||||
|
||||
- Llama & Mistral
|
||||
- Falcon & RW
|
||||
- GPT-NeoX
|
||||
- BigCode
|
||||
|
||||
To view a model's architecture, check the `config.json` file in its HuggingFace repo. You should see an entry under `architectures` (e.g. `LlamaForCausalLM`).
|
||||
@@ -184,9 +183,6 @@ python convert.py <path to model directory>
|
||||
# FalconForCausalLM
|
||||
python convert-falcon-hf-to-gguf.py <path to model directory>
|
||||
|
||||
# GPTNeoXForCausalLM
|
||||
python convert-gptneox-hf-to-gguf.py <path to model directory>
|
||||
|
||||
# GPTBigCodeForCausalLM
|
||||
python convert-starcoder-hf-to-gguf.py <path to model directory>
|
||||
```
|
||||
|
83
docs/tutorials/fly-gpu.md
Normal file
83
docs/tutorials/fly-gpu.md
Normal file
@@ -0,0 +1,83 @@
|
||||
# Running Ollama on Fly.io GPU Instances
|
||||
|
||||
Ollama runs with little to no configuration on [Fly.io GPU instances](https://fly.io/docs/gpus/gpu-quickstart/). If you don't have access to GPUs yet, you'll need to [apply for access](https://fly.io/gpu/) on the waitlist. Once you're accepted, you'll get an email with instructions on how to get started.
|
||||
|
||||
Create a new app with `fly apps create`:
|
||||
|
||||
```bash
|
||||
fly apps create
|
||||
```
|
||||
|
||||
Then create a `fly.toml` file in a new folder that looks like this:
|
||||
|
||||
```toml
|
||||
app = "sparkling-violet-709"
|
||||
primary_region = "ord"
|
||||
vm.size = "a100-40gb" # see https://fly.io/docs/gpus/gpu-quickstart/ for more info
|
||||
|
||||
[build]
|
||||
image = "ollama/ollama"
|
||||
|
||||
[http_service]
|
||||
internal_port = 11434
|
||||
force_https = false
|
||||
auto_stop_machines = true
|
||||
auto_start_machines = true
|
||||
min_machines_running = 0
|
||||
processes = ["app"]
|
||||
|
||||
[mounts]
|
||||
source = "models"
|
||||
destination = "/root/.ollama"
|
||||
initial_size = "100gb"
|
||||
```
|
||||
|
||||
Then create a [new private IPv6 address](https://fly.io/docs/reference/private-networking/#flycast-private-load-balancing) for your app:
|
||||
|
||||
```bash
|
||||
fly ips allocate-v6 --private
|
||||
```
|
||||
|
||||
Then deploy your app:
|
||||
|
||||
```bash
|
||||
fly deploy
|
||||
```
|
||||
|
||||
And finally you can access it interactively with a new Fly.io Machine:
|
||||
|
||||
```
|
||||
fly machine run -e OLLAMA_HOST=http://your-app-name.flycast --shell ollama/ollama
|
||||
```
|
||||
|
||||
```bash
|
||||
$ ollama run openchat:7b-v3.5-fp16
|
||||
>>> How do I bake chocolate chip cookies?
|
||||
To bake chocolate chip cookies, follow these steps:
|
||||
|
||||
1. Preheat the oven to 375°F (190°C) and line a baking sheet with parchment paper or silicone baking mat.
|
||||
|
||||
2. In a large bowl, mix together 1 cup of unsalted butter (softened), 3/4 cup granulated sugar, and 3/4
|
||||
cup packed brown sugar until light and fluffy.
|
||||
|
||||
3. Add 2 large eggs, one at a time, to the butter mixture, beating well after each addition. Stir in 1
|
||||
teaspoon of pure vanilla extract.
|
||||
|
||||
4. In a separate bowl, whisk together 2 cups all-purpose flour, 1/2 teaspoon baking soda, and 1/2 teaspoon
|
||||
salt. Gradually add the dry ingredients to the wet ingredients, stirring until just combined.
|
||||
|
||||
5. Fold in 2 cups of chocolate chips (or chunks) into the dough.
|
||||
|
||||
6. Drop rounded tablespoons of dough onto the prepared baking sheet, spacing them about 2 inches apart.
|
||||
|
||||
7. Bake for 10-12 minutes, or until the edges are golden brown. The centers should still be slightly soft.
|
||||
|
||||
8. Allow the cookies to cool on the baking sheet for a few minutes before transferring them to a wire rack
|
||||
to cool completely.
|
||||
|
||||
Enjoy your homemade chocolate chip cookies!
|
||||
```
|
||||
|
||||
When you set it up like this, it will automatically turn off when you're done using it. Then when you access it again, it will automatically turn back on. This is a great way to save money on GPU instances when you're not using them. If you want a persistent wake-on-use connection to your Ollama instance, you can set up a [connection to your Fly network using WireGuard](https://fly.io/docs/reference/private-networking/#discovering-apps-through-dns-on-a-wireguard-connection). Then you can access your Ollama instance at `http://your-app-name.flycast`.
|
||||
|
||||
And that's it!
|
46
examples/python-simplechat/client.py
Normal file
46
examples/python-simplechat/client.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import json
|
||||
import requests
|
||||
|
||||
# NOTE: ollama must be running for this to work, start the ollama app or run `ollama serve`
|
||||
model = "llama2" # TODO: update this for whatever model you wish to use
|
||||
|
||||
|
||||
def chat(messages):
|
||||
r = requests.post(
|
||||
"http://0.0.0.0:11434/api/chat",
|
||||
json={"model": model, "messages": messages, "stream": True},
|
||||
)
|
||||
r.raise_for_status()
|
||||
output = ""
|
||||
|
||||
for line in r.iter_lines():
|
||||
body = json.loads(line)
|
||||
if "error" in body:
|
||||
raise Exception(body["error"])
|
||||
if body.get("done") is False:
|
||||
message = body.get("message", "")
|
||||
content = message.get("content", "")
|
||||
output += content
|
||||
# the response streams one token at a time, print that as we receive it
|
||||
print(content, end="", flush=True)
|
||||
|
||||
|
||||
if body.get("done", False):
|
||||
message["content"] = output
|
||||
return message
|
||||
|
||||
|
||||
def main():
|
||||
messages = []
|
||||
|
||||
while True:
|
||||
user_input = input("Enter a prompt: ")
|
||||
print()
|
||||
messages.append({"role": "user", "content": user_input})
|
||||
message = chat(messages)
|
||||
messages.append(message)
|
||||
print("\n\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
24
examples/python-simplechat/readme.md
Normal file
24
examples/python-simplechat/readme.md
Normal file
@@ -0,0 +1,24 @@
|
||||
# Simple Chat Example
|
||||
|
||||
The **chat** endpoint is one of two ways to generate text from an LLM with Ollama. At a high level you provide the endpoint an array of objects with a role and content specified. Then with each output and prompt, you add more of those role/content objects, which builds up the history.
|
||||
|
||||
## Review the Code
|
||||
|
||||
You can see in the **chat** function that actually calling the endpoint is done simply with:
|
||||
|
||||
```python
|
||||
r = requests.post(
|
||||
"http://0.0.0.0:11434/api/chat",
|
||||
json={"model": model, "messages": messages, "stream": True},
|
||||
)
|
||||
```
|
||||
|
||||
With the **generate** endpoint, you need to provide a `prompt`. But with **chat**, you provide `messages`. And the resulting stream of responses includes a `message` object with a `content` field.
|
||||
|
||||
The final JSON object doesn't provide the full content, so you will need to build the content yourself.
|
||||
|
||||
In the **main** function, we collect `user_input` and add it as a message to our messages and that is passed to the chat function. When the LLM is done responding the output is added as another message.
|
||||
|
||||
## Next Steps
|
||||
|
||||
In this example, all generations are kept. You might want to experiment with summarizing everything older than 10 conversations to enable longer history with less context being used.
|
77
examples/typescript-simplechat/client.ts
Normal file
77
examples/typescript-simplechat/client.ts
Normal file
@@ -0,0 +1,77 @@
|
||||
import * as readline from "readline";
|
||||
|
||||
const model = "llama2";
|
||||
type Message = {
|
||||
role: "assistant" | "user" | "system";
|
||||
content: string;
|
||||
}
|
||||
const messages: Message[] = [{
|
||||
role: "system",
|
||||
content: "You are a helpful AI agent."
|
||||
}]
|
||||
|
||||
const rl = readline.createInterface({
|
||||
input: process.stdin,
|
||||
output: process.stdout
|
||||
})
|
||||
|
||||
async function chat(messages: Message[]): Promise<Message> {
|
||||
const body = {
|
||||
model: model,
|
||||
messages: messages
|
||||
}
|
||||
|
||||
const response = await fetch("http://localhost:11434/api/chat", {
|
||||
method: "POST",
|
||||
body: JSON.stringify(body)
|
||||
})
|
||||
|
||||
const reader = response.body?.getReader()
|
||||
if (!reader) {
|
||||
throw new Error("Failed to read response body")
|
||||
}
|
||||
let content = ""
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
const rawjson = new TextDecoder().decode(value);
|
||||
const json = JSON.parse(rawjson)
|
||||
|
||||
if (json.done === false) {
|
||||
process.stdout.write(json.message.content);
|
||||
content += json.message.content
|
||||
}
|
||||
|
||||
}
|
||||
return { role: "assistant", content: content };
|
||||
}
|
||||
|
||||
async function askQuestion(): Promise<void> {
|
||||
return new Promise<void>((resolve) => {
|
||||
rl.question("\n\nAsk a question: (press enter alone to quit)\n\n", async (user_input) => {
|
||||
if (user_input.trim() === "") {
|
||||
rl.close();
|
||||
console.log("Thankyou. Goodbye.\n")
|
||||
console.log("=======\nHere is the message history that was used in this conversation.\n=======\n")
|
||||
messages.forEach(message => {
|
||||
console.log(message)
|
||||
})
|
||||
resolve();
|
||||
} else {
|
||||
console.log();
|
||||
messages.push({ role: "user", content: user_input });
|
||||
messages.push(await chat(messages));
|
||||
await askQuestion(); // Ask the next question
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
async function main() {
|
||||
await askQuestion();
|
||||
|
||||
}
|
||||
|
||||
main();
|
1
examples/typescript-simplechat/package.json
Normal file
1
examples/typescript-simplechat/package.json
Normal file
@@ -0,0 +1 @@
|
||||
{ "dependencies": { "@types/node": "^20.10.4", "prompt-sync": "^4.2.0", "readline": "^1.3.0" } }
|
39
examples/typescript-simplechat/readme.md
Normal file
39
examples/typescript-simplechat/readme.md
Normal file
@@ -0,0 +1,39 @@
|
||||
# Simple Chat Example
|
||||
|
||||
The **chat** endpoint is one of two ways to generate text from an LLM with Ollama. At a high level you provide the endpoint an array of message objects with a role and content specified. Then with each output and prompt, you add more messages, which builds up the history.
|
||||
|
||||
## Run the Example
|
||||
|
||||
There are a few ways to run this, just like any Typescript code:
|
||||
|
||||
1. Compile with `tsc` and then run it with `node client.js`.
|
||||
2. Install `tsx` and run it with `tsx client.ts`.
|
||||
3. Install `bun` and run it with `bun client.ts`.
|
||||
|
||||
## Review the Code
|
||||
|
||||
You can see in the **chat** function that is actually calling the endpoint is simply done with:
|
||||
|
||||
```typescript
|
||||
const body = {
|
||||
model: model,
|
||||
messages: messages
|
||||
}
|
||||
|
||||
const response = await fetch("http://localhost:11434/api/chat", {
|
||||
method: "POST",
|
||||
body: JSON.stringify(body)
|
||||
})
|
||||
```
|
||||
|
||||
With the **generate** endpoint, you need to provide a `prompt`. But with **chat**, you provide `messages`. And the resulting stream of responses includes a `message` object with a `content` field.
|
||||
|
||||
The final JSON object doesn't provide the full content, so you will need to build the content yourself. In this example, **chat** takes the full array of messages and outputs the resulting message from this call of the chat endpoint.
|
||||
|
||||
In the **askQuestion** function, we collect `user_input` and add it as a message to our messages and that is passed to the chat function. When the LLM is done responding the output is added as another message to the messages array.
|
||||
|
||||
At the end, you will see a printout of all the messages.
|
||||
|
||||
## Next Steps
|
||||
|
||||
In this example, all generations are kept. You might want to experiment with summarizing everything older than 10 conversations to enable longer history with less context being used.
|
7
go.mod
7
go.mod
@@ -5,14 +5,15 @@ go 1.20
|
||||
require (
|
||||
github.com/emirpasic/gods v1.18.1
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
github.com/spf13/cobra v1.7.0
|
||||
golang.org/x/sync v0.3.0
|
||||
)
|
||||
|
||||
require github.com/rivo/uniseg v0.2.0 // indirect
|
||||
require (
|
||||
github.com/mattn/go-runewidth v0.0.14 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
|
2
go.sum
2
go.sum
@@ -63,8 +63,6 @@ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
|
||||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
|
||||
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ=
|
||||
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
|
89
llm/ggml.go
89
llm/ggml.go
@@ -7,9 +7,10 @@ import (
|
||||
)
|
||||
|
||||
type GGML struct {
|
||||
magic uint32
|
||||
container
|
||||
model
|
||||
|
||||
Size int64
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -82,7 +83,7 @@ type model interface {
|
||||
|
||||
type container interface {
|
||||
Name() string
|
||||
Decode(io.Reader) (model, error)
|
||||
Decode(*readSeekOffset) (model, error)
|
||||
}
|
||||
|
||||
type containerGGML struct{}
|
||||
@@ -91,7 +92,9 @@ func (c *containerGGML) Name() string {
|
||||
return "ggml"
|
||||
}
|
||||
|
||||
func (c *containerGGML) Decode(r io.Reader) (model, error) {
|
||||
func (c *containerGGML) Decode(ro *readSeekOffset) (model, error) {
|
||||
// file contents aren't decoded
|
||||
ro.Seek(0, io.SeekEnd)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -103,9 +106,9 @@ func (c *containerGGMF) Name() string {
|
||||
return "ggmf"
|
||||
}
|
||||
|
||||
func (c *containerGGMF) Decode(r io.Reader) (model, error) {
|
||||
func (c *containerGGMF) Decode(ro *readSeekOffset) (model, error) {
|
||||
var version uint32
|
||||
binary.Read(r, binary.LittleEndian, &version)
|
||||
binary.Read(ro, binary.LittleEndian, &version)
|
||||
|
||||
switch version {
|
||||
case 1:
|
||||
@@ -114,6 +117,10 @@ func (c *containerGGMF) Decode(r io.Reader) (model, error) {
|
||||
}
|
||||
|
||||
c.version = version
|
||||
|
||||
// remaining file contents aren't decoded
|
||||
ro.Seek(0, io.SeekEnd)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -125,9 +132,9 @@ func (c *containerGGJT) Name() string {
|
||||
return "ggjt"
|
||||
}
|
||||
|
||||
func (c *containerGGJT) Decode(r io.Reader) (model, error) {
|
||||
func (c *containerGGJT) Decode(ro *readSeekOffset) (model, error) {
|
||||
var version uint32
|
||||
binary.Read(r, binary.LittleEndian, &version)
|
||||
binary.Read(ro, binary.LittleEndian, &version)
|
||||
|
||||
switch version {
|
||||
case 1, 2, 3:
|
||||
@@ -139,7 +146,11 @@ func (c *containerGGJT) Decode(r io.Reader) (model, error) {
|
||||
|
||||
// different model types may have different layouts for hyperparameters
|
||||
var llama llamaModel
|
||||
binary.Read(r, binary.LittleEndian, &llama.hyperparameters)
|
||||
binary.Read(ro, binary.LittleEndian, &llama.hyperparameters)
|
||||
|
||||
// remaining file contents aren't decoded
|
||||
ro.Seek(0, io.SeekEnd)
|
||||
|
||||
return &llama, nil
|
||||
}
|
||||
|
||||
@@ -151,9 +162,9 @@ func (c *containerLORA) Name() string {
|
||||
return "ggla"
|
||||
}
|
||||
|
||||
func (c *containerLORA) Decode(r io.Reader) (model, error) {
|
||||
func (c *containerLORA) Decode(ro *readSeekOffset) (model, error) {
|
||||
var version uint32
|
||||
binary.Read(r, binary.LittleEndian, &version)
|
||||
binary.Read(ro, binary.LittleEndian, &version)
|
||||
|
||||
switch version {
|
||||
case 1:
|
||||
@@ -162,6 +173,10 @@ func (c *containerLORA) Decode(r io.Reader) (model, error) {
|
||||
}
|
||||
|
||||
c.version = version
|
||||
|
||||
// remaining file contents aren't decoded
|
||||
ro.Seek(0, io.SeekEnd)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -180,33 +195,61 @@ const (
|
||||
)
|
||||
|
||||
func DecodeGGML(r io.ReadSeeker) (*GGML, error) {
|
||||
var ggml GGML
|
||||
binary.Read(r, binary.LittleEndian, &ggml.magic)
|
||||
ro := readSeekOffset{ReadSeeker: r}
|
||||
|
||||
switch ggml.magic {
|
||||
var magic uint32
|
||||
if err := binary.Read(&ro, binary.LittleEndian, &magic); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var c container
|
||||
switch magic {
|
||||
case FILE_MAGIC_GGML:
|
||||
ggml.container = &containerGGML{}
|
||||
c = &containerGGML{}
|
||||
case FILE_MAGIC_GGMF:
|
||||
ggml.container = &containerGGMF{}
|
||||
c = &containerGGMF{}
|
||||
case FILE_MAGIC_GGJT:
|
||||
ggml.container = &containerGGJT{}
|
||||
c = &containerGGJT{}
|
||||
case FILE_MAGIC_GGLA:
|
||||
ggml.container = &containerLORA{}
|
||||
c = &containerLORA{}
|
||||
case FILE_MAGIC_GGUF_LE:
|
||||
ggml.container = &containerGGUF{bo: binary.LittleEndian}
|
||||
c = &containerGGUF{bo: binary.LittleEndian}
|
||||
case FILE_MAGIC_GGUF_BE:
|
||||
ggml.container = &containerGGUF{bo: binary.BigEndian}
|
||||
c = &containerGGUF{bo: binary.BigEndian}
|
||||
default:
|
||||
return nil, errors.New("invalid file magic")
|
||||
}
|
||||
|
||||
model, err := ggml.Decode(r)
|
||||
model, err := c.Decode(&ro)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ggml.model = model
|
||||
|
||||
// final model type
|
||||
return &ggml, nil
|
||||
return &GGML{
|
||||
container: c,
|
||||
model: model,
|
||||
Size: ro.offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type readSeekOffset struct {
|
||||
io.ReadSeeker
|
||||
offset int64
|
||||
}
|
||||
|
||||
func (rso *readSeekOffset) Seek(offset int64, whence int) (int64, error) {
|
||||
offset, err := rso.ReadSeeker.Seek(offset, whence)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
rso.offset = offset
|
||||
return offset, nil
|
||||
}
|
||||
|
||||
func (rso *readSeekOffset) Read(p []byte) (int, error) {
|
||||
n, err := rso.ReadSeeker.Read(p)
|
||||
rso.offset += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
144
llm/gguf.go
144
llm/gguf.go
@@ -23,26 +23,24 @@ type containerGGUF struct {
|
||||
NumTensor uint64
|
||||
NumKV uint64
|
||||
}
|
||||
|
||||
parameters uint64
|
||||
}
|
||||
|
||||
func (c *containerGGUF) Name() string {
|
||||
return "gguf"
|
||||
}
|
||||
|
||||
func (c *containerGGUF) Decode(r io.Reader) (model, error) {
|
||||
binary.Read(r, c.bo, &c.Version)
|
||||
func (c *containerGGUF) Decode(rso *readSeekOffset) (model, error) {
|
||||
binary.Read(rso, c.bo, &c.Version)
|
||||
|
||||
switch c.Version {
|
||||
case 1:
|
||||
binary.Read(r, c.bo, &c.V1)
|
||||
binary.Read(rso, c.bo, &c.V1)
|
||||
default:
|
||||
binary.Read(r, c.bo, &c.V2)
|
||||
binary.Read(rso, c.bo, &c.V2)
|
||||
}
|
||||
|
||||
model := newGGUFModel(c)
|
||||
if err := model.Decode(r); err != nil {
|
||||
if err := model.Decode(rso); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -67,9 +65,23 @@ const (
|
||||
|
||||
type kv map[string]any
|
||||
|
||||
type tensor struct {
|
||||
name string
|
||||
kind uint32
|
||||
offset uint64
|
||||
size uint64
|
||||
|
||||
// shape is the number of elements in each dimension
|
||||
shape [4]uint64
|
||||
}
|
||||
|
||||
type ggufModel struct {
|
||||
*containerGGUF
|
||||
|
||||
kv
|
||||
tensors []tensor
|
||||
|
||||
parameters uint64
|
||||
}
|
||||
|
||||
func newGGUFModel(container *containerGGUF) *ggufModel {
|
||||
@@ -96,8 +108,7 @@ func (llm *ggufModel) NumKV() uint64 {
|
||||
}
|
||||
|
||||
func (llm *ggufModel) ModelFamily() string {
|
||||
t, ok := llm.kv["general.architecture"].(string)
|
||||
if ok {
|
||||
if t, ok := llm.kv["general.architecture"].(string); ok {
|
||||
return t
|
||||
}
|
||||
|
||||
@@ -134,57 +145,56 @@ func (llm *ggufModel) ModelType() string {
|
||||
}
|
||||
|
||||
func (llm *ggufModel) FileType() string {
|
||||
t, ok := llm.kv["general.file_type"].(uint32)
|
||||
if ok {
|
||||
if t, ok := llm.kv["general.file_type"].(uint32); ok {
|
||||
return fileType(t)
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (llm *ggufModel) Decode(r io.Reader) error {
|
||||
func (llm *ggufModel) Decode(rso *readSeekOffset) error {
|
||||
// decode key-values
|
||||
for i := 0; uint64(i) < llm.NumKV(); i++ {
|
||||
k, err := llm.readString(r)
|
||||
k, err := llm.readString(rso)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
vtype := llm.readU32(r)
|
||||
vtype := llm.readU32(rso)
|
||||
|
||||
var v any
|
||||
switch vtype {
|
||||
case ggufTypeUint8:
|
||||
v = llm.readU8(r)
|
||||
v = llm.readU8(rso)
|
||||
case ggufTypeInt8:
|
||||
v = llm.readI8(r)
|
||||
v = llm.readI8(rso)
|
||||
case ggufTypeUint16:
|
||||
v = llm.readU16(r)
|
||||
v = llm.readU16(rso)
|
||||
case ggufTypeInt16:
|
||||
v = llm.readI16(r)
|
||||
v = llm.readI16(rso)
|
||||
case ggufTypeUint32:
|
||||
v = llm.readU32(r)
|
||||
v = llm.readU32(rso)
|
||||
case ggufTypeInt32:
|
||||
v = llm.readI32(r)
|
||||
v = llm.readI32(rso)
|
||||
case ggufTypeUint64:
|
||||
v = llm.readU64(r)
|
||||
v = llm.readU64(rso)
|
||||
case ggufTypeInt64:
|
||||
v = llm.readI64(r)
|
||||
v = llm.readI64(rso)
|
||||
case ggufTypeFloat32:
|
||||
v = llm.readF32(r)
|
||||
v = llm.readF32(rso)
|
||||
case ggufTypeFloat64:
|
||||
v = llm.readF64(r)
|
||||
v = llm.readF64(rso)
|
||||
case ggufTypeBool:
|
||||
v = llm.readBool(r)
|
||||
v = llm.readBool(rso)
|
||||
case ggufTypeString:
|
||||
s, err := llm.readString(r)
|
||||
s, err := llm.readString(rso)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v = s
|
||||
case ggufTypeArray:
|
||||
a, err := llm.readArray(r)
|
||||
a, err := llm.readArray(rso)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -199,21 +209,85 @@ func (llm *ggufModel) Decode(r io.Reader) error {
|
||||
|
||||
// decode tensors
|
||||
for i := 0; uint64(i) < llm.NumTensor(); i++ {
|
||||
if _, err := llm.readString(r); err != nil {
|
||||
name, err := llm.readString(rso)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dimensions := llm.readU32(r)
|
||||
// dims is the number of dimensions in the tensor
|
||||
dims := llm.readU32(rso)
|
||||
|
||||
var elements uint64 = 1
|
||||
for i := 0; uint32(i) < dimensions; i++ {
|
||||
elements *= llm.readU64(r)
|
||||
shape := [4]uint64{1, 1, 1, 1}
|
||||
for i := 0; uint32(i) < dims; i++ {
|
||||
shape[i] = llm.readU64(rso)
|
||||
}
|
||||
|
||||
llm.readU32(r) // type
|
||||
llm.readU64(r) // offset
|
||||
kind := llm.readU32(rso)
|
||||
offset := llm.readU64(rso)
|
||||
|
||||
llm.parameters += elements
|
||||
var blockSize uint64
|
||||
switch {
|
||||
case kind < 2:
|
||||
blockSize = 1
|
||||
case kind < 10:
|
||||
blockSize = 32
|
||||
default:
|
||||
blockSize = 256
|
||||
}
|
||||
|
||||
var typeSize uint64
|
||||
switch kind {
|
||||
case 0: // FP32
|
||||
typeSize = 4
|
||||
case 1: // FP16
|
||||
typeSize = 2
|
||||
case 2: // Q4_0
|
||||
typeSize = 2 + blockSize/2
|
||||
case 3: // Q4_1
|
||||
typeSize = 2 + 2 + blockSize/2
|
||||
case 6: // Q5_0
|
||||
typeSize = 2 + 4 + blockSize/2
|
||||
case 7: // Q5_1
|
||||
typeSize = 2 + 2 + 4 + blockSize/2
|
||||
case 8: // Q8_0
|
||||
typeSize = 2 + blockSize
|
||||
case 9: // Q8_1
|
||||
typeSize = 4 + 4 + blockSize
|
||||
case 10: // Q2_K
|
||||
typeSize = blockSize/16 + blockSize/4 + 2 + 2
|
||||
case 11: // Q3_K
|
||||
typeSize = blockSize/8 + blockSize/4 + 12 + 2
|
||||
case 12: // Q4_K
|
||||
typeSize = 2 + 2 + 12 + blockSize/2
|
||||
case 13: // Q5_K
|
||||
typeSize = 2 + 2 + 12 + blockSize/8 + blockSize/2
|
||||
case 14: // Q6_K
|
||||
typeSize = blockSize/2 + blockSize/4 + blockSize/16 + 2
|
||||
}
|
||||
|
||||
parameters := shape[0] * shape[1] * shape[2] * shape[3]
|
||||
size := parameters * typeSize / blockSize
|
||||
|
||||
llm.tensors = append(llm.tensors, tensor{
|
||||
name: name,
|
||||
kind: kind,
|
||||
offset: offset,
|
||||
size: size,
|
||||
shape: shape,
|
||||
})
|
||||
|
||||
llm.parameters += parameters
|
||||
}
|
||||
|
||||
alignment, ok := llm.kv["general.alignment"].(uint32)
|
||||
if !ok {
|
||||
alignment = 32
|
||||
}
|
||||
|
||||
rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent)
|
||||
for _, tensor := range llm.tensors {
|
||||
padded := (int64(tensor.size) + int64(alignment) - 1) & ^(int64(alignment) - 1)
|
||||
rso.Seek(padded, io.SeekCurrent)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
Submodule llm/llama.cpp/gguf updated: 9656026b53...23b5e12eb5
113
llm/llama.go
113
llm/llama.go
@@ -59,6 +59,7 @@ ws ::= ([ \t\n] ws)?
|
||||
var llamaCppEmbed embed.FS
|
||||
|
||||
type ModelRunner struct {
|
||||
Type string // "gguf" or "ggml"
|
||||
Path string // path to the model runner executable
|
||||
Accelerated bool
|
||||
}
|
||||
@@ -72,25 +73,25 @@ func chooseRunners(workDir, runnerType string) []ModelRunner {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
if runtime.GOARCH == "arm64" {
|
||||
runners = []ModelRunner{{Path: path.Join(buildPath, "metal", "bin", "ollama-runner")}}
|
||||
runners = []ModelRunner{{Type: runnerType, Path: path.Join(buildPath, "metal", "bin", "ollama-runner")}}
|
||||
} else {
|
||||
runners = []ModelRunner{{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")}}
|
||||
runners = []ModelRunner{{Type: runnerType, Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")}}
|
||||
}
|
||||
case "linux":
|
||||
runners = []ModelRunner{
|
||||
{Path: path.Join(buildPath, "cuda", "bin", "ollama-runner"), Accelerated: true},
|
||||
{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
|
||||
{Type: runnerType, Path: path.Join(buildPath, "cuda", "bin", "ollama-runner"), Accelerated: true},
|
||||
{Type: runnerType, Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
|
||||
}
|
||||
case "windows":
|
||||
// TODO: select windows GPU runner here when available
|
||||
runners = []ModelRunner{
|
||||
{Path: path.Join(buildPath, "cuda", "bin", "Release", "ollama-runner.exe"), Accelerated: true},
|
||||
{Path: path.Join(buildPath, "cpu", "bin", "Release", "ollama-runner.exe")},
|
||||
{Type: runnerType, Path: path.Join(buildPath, "cuda", "bin", "Release", "ollama-runner.exe"), Accelerated: true},
|
||||
{Type: runnerType, Path: path.Join(buildPath, "cpu", "bin", "Release", "ollama-runner.exe")},
|
||||
}
|
||||
default:
|
||||
log.Printf("unknown OS, running on CPU: %s", runtime.GOOS)
|
||||
runners = []ModelRunner{
|
||||
{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
|
||||
{Type: runnerType, Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -148,6 +149,7 @@ func chooseRunners(workDir, runnerType string) []ModelRunner {
|
||||
for _, r := range runners {
|
||||
// clean the ModelRunner paths so that they match the OS we are running on
|
||||
localRunnersByPriority = append(localRunnersByPriority, ModelRunner{
|
||||
Type: r.Type,
|
||||
Path: filepath.Clean(path.Join(workDir, r.Path)),
|
||||
Accelerated: r.Accelerated,
|
||||
})
|
||||
@@ -325,7 +327,7 @@ func (w *StatusWriter) Write(b []byte) (int, error) {
|
||||
return os.Stderr.Write(b)
|
||||
}
|
||||
|
||||
func newLlama(model string, adapters []string, runners []ModelRunner, numLayers int64, opts api.Options) (*llama, error) {
|
||||
func newLlama(model string, adapters, projectors []string, runners []ModelRunner, numLayers int64, opts api.Options) (*llama, error) {
|
||||
fileInfo, err := os.Stat(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -365,6 +367,11 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
|
||||
params = append(params, "--lora", adapters[0])
|
||||
}
|
||||
|
||||
if len(projectors) > 0 {
|
||||
// TODO: applying multiple projectors is not supported by the llama.cpp server yet
|
||||
params = append(params, "--mmproj", projectors[0])
|
||||
}
|
||||
|
||||
if opts.NumThread > 0 {
|
||||
params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread))
|
||||
}
|
||||
@@ -397,11 +404,17 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
|
||||
}
|
||||
|
||||
port := rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
|
||||
params := append(params, "--port", strconv.Itoa(port))
|
||||
|
||||
if runner.Type == "gguf" {
|
||||
params = append(params, "--parallel", "2")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cmd := exec.CommandContext(
|
||||
ctx,
|
||||
runner.Path,
|
||||
append(params, "--port", strconv.Itoa(port))...,
|
||||
params...,
|
||||
)
|
||||
|
||||
var libraryPaths []string
|
||||
@@ -531,21 +544,28 @@ type prediction struct {
|
||||
|
||||
const maxBufferSize = 512 * format.KiloByte
|
||||
|
||||
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error {
|
||||
prevConvo, err := llm.Decode(ctx, prevContext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
type PredictOpts struct {
|
||||
Prompt string
|
||||
Format string
|
||||
CheckpointStart time.Time
|
||||
CheckpointLoaded time.Time
|
||||
}
|
||||
|
||||
// Remove leading spaces from prevConvo if present
|
||||
prevConvo = strings.TrimPrefix(prevConvo, " ")
|
||||
|
||||
var nextContext strings.Builder
|
||||
nextContext.WriteString(prevConvo)
|
||||
nextContext.WriteString(prompt)
|
||||
type PredictResult struct {
|
||||
CreatedAt time.Time
|
||||
TotalDuration time.Duration
|
||||
LoadDuration time.Duration
|
||||
Content string
|
||||
Done bool
|
||||
PromptEvalCount int
|
||||
PromptEvalDuration time.Duration
|
||||
EvalCount int
|
||||
EvalDuration time.Duration
|
||||
}
|
||||
|
||||
func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
|
||||
request := map[string]any{
|
||||
"prompt": nextContext.String(),
|
||||
"prompt": predict.Prompt,
|
||||
"stream": true,
|
||||
"n_predict": llm.NumPredict,
|
||||
"n_keep": llm.NumKeep,
|
||||
@@ -567,7 +587,7 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
|
||||
"stop": llm.Stop,
|
||||
}
|
||||
|
||||
if format == "json" {
|
||||
if predict.Format == "json" {
|
||||
request["grammar"] = jsonGrammar
|
||||
}
|
||||
|
||||
@@ -617,34 +637,35 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
|
||||
continue
|
||||
}
|
||||
|
||||
if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
|
||||
var p prediction
|
||||
if err := json.Unmarshal(evt, &p); err != nil {
|
||||
return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
|
||||
}
|
||||
evt, ok := bytes.CutPrefix(line, []byte("data: "))
|
||||
if !ok {
|
||||
return fmt.Errorf("error parsing llm response stream: %s", line)
|
||||
}
|
||||
|
||||
if p.Content != "" {
|
||||
fn(api.GenerateResponse{Response: p.Content})
|
||||
nextContext.WriteString(p.Content)
|
||||
}
|
||||
var p prediction
|
||||
if err := json.Unmarshal(evt, &p); err != nil {
|
||||
return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
|
||||
}
|
||||
|
||||
if p.Stop {
|
||||
embd, err := llm.Encode(ctx, nextContext.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoding context: %v", err)
|
||||
}
|
||||
if p.Content != "" {
|
||||
fn(PredictResult{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Content: p.Content,
|
||||
})
|
||||
}
|
||||
|
||||
fn(api.GenerateResponse{
|
||||
Done: true,
|
||||
Context: embd,
|
||||
PromptEvalCount: p.Timings.PromptN,
|
||||
PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
|
||||
EvalCount: p.Timings.PredictedN,
|
||||
EvalDuration: parseDurationMs(p.Timings.PredictedMS),
|
||||
})
|
||||
if p.Stop {
|
||||
fn(PredictResult{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
TotalDuration: time.Since(predict.CheckpointStart),
|
||||
|
||||
return nil
|
||||
}
|
||||
Done: true,
|
||||
PromptEvalCount: p.Timings.PromptN,
|
||||
PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
|
||||
EvalCount: p.Timings.PredictedN,
|
||||
EvalDuration: parseDurationMs(p.Timings.PredictedMS),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -14,7 +14,7 @@ import (
|
||||
)
|
||||
|
||||
type LLM interface {
|
||||
Predict(context.Context, []int, string, string, func(api.GenerateResponse)) error
|
||||
Predict(context.Context, PredictOpts, func(PredictResult)) error
|
||||
Embedding(context.Context, string) ([]float64, error)
|
||||
Encode(context.Context, string) ([]int, error)
|
||||
Decode(context.Context, []int) (string, error)
|
||||
@@ -23,7 +23,7 @@ type LLM interface {
|
||||
Ping(context.Context) error
|
||||
}
|
||||
|
||||
func New(workDir, model string, adapters []string, opts api.Options) (LLM, error) {
|
||||
func New(workDir, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
|
||||
if _, err := os.Stat(model); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -82,9 +82,9 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error
|
||||
opts.NumGQA = 0
|
||||
opts.RopeFrequencyBase = 0.0
|
||||
opts.RopeFrequencyScale = 0.0
|
||||
return newLlama(model, adapters, chooseRunners(workDir, "gguf"), ggml.NumLayers(), opts)
|
||||
return newLlama(model, adapters, projectors, chooseRunners(workDir, "gguf"), ggml.NumLayers(), opts)
|
||||
case "ggml", "ggmf", "ggjt", "ggla":
|
||||
return newLlama(model, adapters, chooseRunners(workDir, "ggml"), ggml.NumLayers(), opts)
|
||||
return newLlama(model, adapters, projectors, chooseRunners(workDir, "ggml"), ggml.NumLayers(), opts)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily())
|
||||
}
|
||||
|
@@ -37,10 +37,13 @@ func Parse(reader io.Reader) ([]Command, error) {
|
||||
switch string(bytes.ToUpper(fields[0])) {
|
||||
case "FROM":
|
||||
command.Name = "model"
|
||||
command.Args = string(fields[1])
|
||||
command.Args = string(bytes.TrimSpace(fields[1]))
|
||||
// copy command for validation
|
||||
modelCommand = command
|
||||
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "ADAPTER":
|
||||
case "ADAPTER":
|
||||
command.Name = string(bytes.ToLower(fields[0]))
|
||||
command.Args = string(bytes.TrimSpace(fields[1]))
|
||||
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
|
||||
command.Name = string(bytes.ToLower(fields[0]))
|
||||
command.Args = string(fields[1])
|
||||
case "PARAMETER":
|
||||
@@ -50,7 +53,7 @@ func Parse(reader io.Reader) ([]Command, error) {
|
||||
}
|
||||
|
||||
command.Name = string(fields[0])
|
||||
command.Args = string(fields[1])
|
||||
command.Args = string(bytes.TrimSpace(fields[1]))
|
||||
case "EMBED":
|
||||
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
|
||||
default:
|
||||
|
@@ -191,6 +191,15 @@ func (i *Instance) Readline() (string, error) {
|
||||
buf.ClearScreen()
|
||||
case CharCtrlW:
|
||||
buf.DeleteWord()
|
||||
case CharCtrlZ:
|
||||
if err := UnsetRawMode(fd, termios); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
syscall.Kill(0, syscall.SIGSTOP)
|
||||
|
||||
// on resume...
|
||||
return "", nil
|
||||
case CharEnter:
|
||||
output := buf.String()
|
||||
if output != "" {
|
||||
|
@@ -217,7 +217,7 @@ fi
|
||||
|
||||
if ! check_gpu nvidia-smi || [ -z "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
|
||||
case $OS_NAME in
|
||||
centos|rhel) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -d '.' -f 1) ;;
|
||||
centos|rhel) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -d '.' -f 1) ;;
|
||||
rocky) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -c1) ;;
|
||||
fedora) install_cuda_driver_yum $OS_NAME $OS_VERSION ;;
|
||||
amzn) install_cuda_driver_yum 'fedora' '35' ;;
|
||||
@@ -230,7 +230,8 @@ fi
|
||||
if ! lsmod | grep -q nvidia; then
|
||||
KERNEL_RELEASE="$(uname -r)"
|
||||
case $OS_NAME in
|
||||
centos|rhel|rocky|amzn) $SUDO $PACKAGE_MANAGER -y install kernel-devel-$KERNEL_RELEASE kernel-headers-$KERNEL_RELEASE ;;
|
||||
rocky) $SUDO $PACKAGE_MANAGER -y install kernel-devel kernel-headers ;;
|
||||
centos|rhel|amzn) $SUDO $PACKAGE_MANAGER -y install kernel-devel-$KERNEL_RELEASE kernel-headers-$KERNEL_RELEASE ;;
|
||||
fedora) $SUDO $PACKAGE_MANAGER -y install kernel-devel-$KERNEL_RELEASE ;;
|
||||
debian|ubuntu) $SUDO apt-get -y install linux-headers-$KERNEL_RELEASE ;;
|
||||
*) exit ;;
|
||||
|
466
server/images.go
466
server/images.go
@@ -35,80 +35,157 @@ type RegistryOptions struct {
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
Name string `json:"name"`
|
||||
ShortName string
|
||||
ModelPath string
|
||||
OriginalModel string
|
||||
AdapterPaths []string
|
||||
Template string
|
||||
System string
|
||||
License []string
|
||||
Digest string
|
||||
Options map[string]interface{}
|
||||
Name string `json:"name"`
|
||||
Config ConfigV2
|
||||
ShortName string
|
||||
ModelPath string
|
||||
OriginalModel string
|
||||
AdapterPaths []string
|
||||
ProjectorPaths []string
|
||||
Template string
|
||||
System string
|
||||
License []string
|
||||
Digest string
|
||||
Options map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
|
||||
t := m.Template
|
||||
if request.Template != "" {
|
||||
t = request.Template
|
||||
}
|
||||
type PromptVars struct {
|
||||
System string
|
||||
Prompt string
|
||||
Response string
|
||||
First bool
|
||||
}
|
||||
|
||||
tmpl, err := template.New("").Parse(t)
|
||||
func (m *Model) Prompt(p PromptVars) (string, error) {
|
||||
var prompt strings.Builder
|
||||
// Use the "missingkey=zero" option to handle missing variables without panicking
|
||||
tmpl, err := template.New("").Option("missingkey=zero").Parse(m.Template)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var vars struct {
|
||||
First bool
|
||||
System string
|
||||
Prompt string
|
||||
if p.System == "" {
|
||||
// use the default system prompt for this model if one is not specified
|
||||
p.System = m.System
|
||||
}
|
||||
|
||||
vars.First = len(request.Context) == 0
|
||||
vars.System = m.System
|
||||
vars.Prompt = request.Prompt
|
||||
|
||||
if request.System != "" {
|
||||
vars.System = request.System
|
||||
vars := map[string]any{
|
||||
"System": p.System,
|
||||
"Prompt": p.Prompt,
|
||||
"Response": p.Response,
|
||||
"First": p.First,
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
if err := tmpl.Execute(&sb, vars); err != nil {
|
||||
return "", err
|
||||
}
|
||||
prompt.WriteString(sb.String())
|
||||
prompt.WriteString(p.Response)
|
||||
return prompt.String(), nil
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
|
||||
// build the prompt from the list of messages
|
||||
var prompt strings.Builder
|
||||
currentVars := PromptVars{
|
||||
First: true,
|
||||
}
|
||||
|
||||
writePrompt := func() error {
|
||||
p, err := m.Prompt(currentVars)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
prompt.WriteString(p)
|
||||
currentVars = PromptVars{}
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, msg := range msgs {
|
||||
switch strings.ToLower(msg.Role) {
|
||||
case "system":
|
||||
if currentVars.System != "" {
|
||||
if err := writePrompt(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
currentVars.System = msg.Content
|
||||
case "user":
|
||||
if currentVars.Prompt != "" {
|
||||
if err := writePrompt(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
currentVars.Prompt = msg.Content
|
||||
case "assistant":
|
||||
currentVars.Response = msg.Content
|
||||
if err := writePrompt(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
default:
|
||||
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
||||
}
|
||||
}
|
||||
|
||||
// Append the last set of vars if they are non-empty
|
||||
if currentVars.Prompt != "" || currentVars.System != "" {
|
||||
if err := writePrompt(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return prompt.String(), nil
|
||||
}
|
||||
|
||||
type ManifestV2 struct {
|
||||
SchemaVersion int `json:"schemaVersion"`
|
||||
MediaType string `json:"mediaType"`
|
||||
Config Layer `json:"config"`
|
||||
Config *Layer `json:"config"`
|
||||
Layers []*Layer `json:"layers"`
|
||||
}
|
||||
|
||||
type Layer struct {
|
||||
MediaType string `json:"mediaType"`
|
||||
Digest string `json:"digest"`
|
||||
Size int64 `json:"size"`
|
||||
From string `json:"from,omitempty"`
|
||||
}
|
||||
|
||||
type LayerReader struct {
|
||||
Layer
|
||||
io.Reader
|
||||
}
|
||||
|
||||
type ConfigV2 struct {
|
||||
ModelFormat string `json:"model_format"`
|
||||
ModelFamily string `json:"model_family"`
|
||||
ModelType string `json:"model_type"`
|
||||
FileType string `json:"file_type"`
|
||||
RootFS RootFS `json:"rootfs"`
|
||||
ModelFormat string `json:"model_format"`
|
||||
ModelFamily string `json:"model_family"`
|
||||
ModelFamilies []string `json:"model_families"`
|
||||
ModelType string `json:"model_type"`
|
||||
FileType string `json:"file_type"`
|
||||
|
||||
// required by spec
|
||||
Architecture string `json:"architecture"`
|
||||
OS string `json:"os"`
|
||||
RootFS RootFS `json:"rootfs"`
|
||||
}
|
||||
|
||||
func (c *ConfigV2) SetModelFormat(format string) {
|
||||
if c.ModelFormat == "" {
|
||||
c.ModelFormat = format
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConfigV2) SetModelFamily(families ...string) {
|
||||
for _, family := range families {
|
||||
if c.ModelFamily == "" {
|
||||
c.ModelFamily = family
|
||||
}
|
||||
|
||||
if !slices.Contains(c.ModelFamilies, family) {
|
||||
c.ModelFamilies = append(c.ModelFamilies, family)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConfigV2) SetModelType(modelType string) {
|
||||
if c.ModelType == "" {
|
||||
c.ModelType = modelType
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConfigV2) SetFileType(fileType string) {
|
||||
if c.FileType == "" {
|
||||
c.FileType = fileType
|
||||
}
|
||||
}
|
||||
|
||||
type RootFS struct {
|
||||
@@ -167,6 +244,21 @@ func GetModel(name string) (*Model, error) {
|
||||
License: []string{},
|
||||
}
|
||||
|
||||
filename, err := GetBlobsPath(manifest.Config.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
configFile, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer configFile.Close()
|
||||
|
||||
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, layer := range manifest.Layers {
|
||||
filename, err := GetBlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
@@ -183,6 +275,8 @@ func GetModel(name string) (*Model, error) {
|
||||
log.Print("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
|
||||
case "application/vnd.ollama.image.adapter":
|
||||
model.AdapterPaths = append(model.AdapterPaths, filename)
|
||||
case "application/vnd.ollama.image.projector":
|
||||
model.ProjectorPaths = append(model.ProjectorPaths, filename)
|
||||
case "application/vnd.ollama.image.template":
|
||||
bts, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
@@ -256,11 +350,14 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
config := ConfigV2{
|
||||
OS: "linux",
|
||||
Architecture: "amd64",
|
||||
RootFS: RootFS{
|
||||
Type: "layers",
|
||||
},
|
||||
}
|
||||
|
||||
deleteMap := make(map[string]struct{})
|
||||
|
||||
var layers []*LayerReader
|
||||
var layers Layers
|
||||
|
||||
params := make(map[string][]string)
|
||||
fromParams := make(map[string]any)
|
||||
@@ -317,10 +414,10 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
return err
|
||||
}
|
||||
|
||||
config.ModelFormat = fromConfig.ModelFormat
|
||||
config.ModelFamily = fromConfig.ModelFamily
|
||||
config.ModelType = fromConfig.ModelType
|
||||
config.FileType = fromConfig.FileType
|
||||
config.SetModelFormat(fromConfig.ModelFormat)
|
||||
config.SetModelFamily(append(fromConfig.ModelFamilies, fromConfig.ModelFamily)...)
|
||||
config.SetModelType(fromConfig.ModelType)
|
||||
config.SetFileType(fromConfig.FileType)
|
||||
|
||||
for _, layer := range manifest.Layers {
|
||||
deleteMap[layer.Digest] = struct{}{}
|
||||
@@ -341,13 +438,12 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
}
|
||||
}
|
||||
|
||||
layer, err := GetLayerWithBufferFromLayer(layer)
|
||||
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
layer.From = modelpath.GetShortTagname()
|
||||
layers = append(layers, layer)
|
||||
layers.Add(layer)
|
||||
}
|
||||
|
||||
deleteMap[manifest.Config.Digest] = struct{}{}
|
||||
@@ -355,25 +451,38 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
}
|
||||
defer bin.Close()
|
||||
|
||||
fn(api.ProgressResponse{Status: "creating model layer"})
|
||||
ggml, err := llm.DecodeGGML(bin)
|
||||
if err != nil {
|
||||
return err
|
||||
var offset int64
|
||||
for {
|
||||
fn(api.ProgressResponse{Status: "creating model layer"})
|
||||
|
||||
bin.Seek(offset, io.SeekStart)
|
||||
ggml, err := llm.DecodeGGML(bin)
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config.SetModelFormat(ggml.Name())
|
||||
config.SetModelFamily(ggml.ModelFamily())
|
||||
config.SetModelType(ggml.ModelType())
|
||||
config.SetFileType(ggml.FileType())
|
||||
|
||||
mediatype := mediatype
|
||||
if ggml.ModelFamily() == "clip" {
|
||||
mediatype = "application/vnd.ollama.image.projector"
|
||||
}
|
||||
|
||||
sr := io.NewSectionReader(bin, offset, ggml.Size)
|
||||
layer, err := NewLayer(sr, mediatype)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
layers.Add(layer)
|
||||
|
||||
offset += ggml.Size
|
||||
}
|
||||
|
||||
config.ModelFormat = ggml.Name()
|
||||
config.ModelFamily = ggml.ModelFamily()
|
||||
config.ModelType = ggml.ModelType()
|
||||
config.FileType = ggml.FileType()
|
||||
|
||||
bin.Seek(0, io.SeekStart)
|
||||
layer, err := CreateLayer(bin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
layer.MediaType = mediatype
|
||||
layers = append(layers, layer)
|
||||
case "adapter":
|
||||
if strings.HasPrefix(c.Args, "@") {
|
||||
blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
|
||||
@@ -383,7 +492,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
|
||||
c.Args = blobPath
|
||||
}
|
||||
|
||||
|
||||
fn(api.ProgressResponse{Status: "creating adapter layer"})
|
||||
bin, err := os.Open(realpath(modelFileDir, c.Args))
|
||||
if err != nil {
|
||||
@@ -391,41 +500,32 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
}
|
||||
defer bin.Close()
|
||||
|
||||
layer, err := CreateLayer(bin)
|
||||
layer, err := NewLayer(bin, mediatype)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if layer.Size > 0 {
|
||||
layer.MediaType = mediatype
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
layers.Add(layer)
|
||||
case "license":
|
||||
fn(api.ProgressResponse{Status: "creating license layer"})
|
||||
layer, err := CreateLayer(strings.NewReader(c.Args))
|
||||
|
||||
bin := strings.NewReader(c.Args)
|
||||
layer, err := NewLayer(bin, mediatype)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if layer.Size > 0 {
|
||||
layer.MediaType = mediatype
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
layers.Add(layer)
|
||||
case "template", "system":
|
||||
fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
|
||||
|
||||
// remove duplicate layers
|
||||
layers = removeLayerFromLayers(layers, mediatype)
|
||||
|
||||
layer, err := CreateLayer(strings.NewReader(c.Args))
|
||||
bin := strings.NewReader(c.Args)
|
||||
layer, err := NewLayer(bin, mediatype)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if layer.Size > 0 {
|
||||
layer.MediaType = mediatype
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
layers.Replace(layer)
|
||||
default:
|
||||
params[c.Name] = append(params[c.Name], c.Args)
|
||||
}
|
||||
@@ -457,40 +557,51 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "creating config layer"})
|
||||
layer, err := CreateLayer(bytes.NewReader(b.Bytes()))
|
||||
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
layer.MediaType = "application/vnd.ollama.image.params"
|
||||
layers = append(layers, layer)
|
||||
layers.Replace(layer)
|
||||
}
|
||||
|
||||
digests, err := getLayerDigests(layers)
|
||||
digests := make([]string, len(layers.items))
|
||||
for i, layer := range layers.items {
|
||||
digests[i] = layer.Digest
|
||||
}
|
||||
|
||||
config.RootFS.DiffIDs = digests
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configLayer, err := createConfigLayer(config, digests)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
layers = append(layers, configLayer)
|
||||
delete(deleteMap, configLayer.Digest)
|
||||
|
||||
if err := SaveLayers(layers, fn, false); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, layer := range append(layers.items, configLayer) {
|
||||
committed, err := layer.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status := "writing layer"
|
||||
if !committed {
|
||||
status = "using already created layer"
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: fmt.Sprintf("%s %s", status, layer.Digest)})
|
||||
|
||||
var contentLayers []*Layer
|
||||
for _, layer := range layers {
|
||||
contentLayers = append(contentLayers, &layer.Layer)
|
||||
delete(deleteMap, layer.Digest)
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
if err := CreateManifest(name, configLayer, contentLayers); err != nil {
|
||||
if err := WriteManifest(name, configLayer, layers.items); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -504,119 +615,6 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
|
||||
return slices.DeleteFunc(layers, func(layer *LayerReader) bool {
|
||||
return layer.MediaType == mediaType
|
||||
})
|
||||
}
|
||||
|
||||
func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {
|
||||
// Write each of the layers to disk
|
||||
for _, layer := range layers {
|
||||
fp, err := GetBlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = os.Stat(fp)
|
||||
if os.IsNotExist(err) || force {
|
||||
fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})
|
||||
|
||||
out, err := os.Create(fp)
|
||||
if err != nil {
|
||||
log.Printf("couldn't create %s", fp)
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
if _, err = io.Copy(out, layer.Reader); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
} else {
|
||||
fn(api.ProgressResponse{Status: fmt.Sprintf("using already created layer %s", layer.Digest)})
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
|
||||
mp := ParseModelPath(name)
|
||||
manifest := ManifestV2{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||
Config: Layer{
|
||||
MediaType: cfg.MediaType,
|
||||
Size: cfg.Size,
|
||||
Digest: cfg.Digest,
|
||||
},
|
||||
Layers: layers,
|
||||
}
|
||||
|
||||
manifestJSON, err := json.Marshal(manifest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fp, err := mp.GetManifestPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(fp, manifestJSON, 0o644)
|
||||
}
|
||||
|
||||
func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
|
||||
fp, err := GetBlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file, err := os.Open(fp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not open blob: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
newLayer, err := CreateLayer(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newLayer.MediaType = layer.MediaType
|
||||
return newLayer, nil
|
||||
}
|
||||
|
||||
func getLayerDigests(layers []*LayerReader) ([]string, error) {
|
||||
var digests []string
|
||||
for _, l := range layers {
|
||||
if l.Digest == "" {
|
||||
return nil, fmt.Errorf("layer is missing a digest")
|
||||
}
|
||||
digests = append(digests, l.Digest)
|
||||
}
|
||||
return digests, nil
|
||||
}
|
||||
|
||||
// CreateLayer creates a Layer object from a given file
|
||||
func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
|
||||
digest, size := GetSHA256Digest(f)
|
||||
f.Seek(0, io.SeekStart)
|
||||
|
||||
layer := &LayerReader{
|
||||
Layer: Layer{
|
||||
MediaType: "application/vnd.docker.image.rootfs.diff.tar",
|
||||
Digest: digest,
|
||||
Size: size,
|
||||
},
|
||||
Reader: f,
|
||||
}
|
||||
|
||||
return layer, nil
|
||||
}
|
||||
|
||||
func CopyModel(src, dest string) error {
|
||||
srcModelPath := ParseModelPath(src)
|
||||
srcPath, err := srcModelPath.GetManifestPath()
|
||||
@@ -884,7 +882,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
|
||||
var layers []*Layer
|
||||
layers = append(layers, manifest.Layers...)
|
||||
layers = append(layers, &manifest.Config)
|
||||
layers = append(layers, manifest.Config)
|
||||
|
||||
for _, layer := range layers {
|
||||
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
|
||||
@@ -955,7 +953,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
|
||||
var layers []*Layer
|
||||
layers = append(layers, manifest.Layers...)
|
||||
layers = append(layers, &manifest.Config)
|
||||
layers = append(layers, manifest.Config)
|
||||
|
||||
for _, layer := range layers {
|
||||
if err := downloadBlob(
|
||||
@@ -1043,30 +1041,6 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptio
|
||||
return m, err
|
||||
}
|
||||
|
||||
func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
|
||||
config.RootFS = RootFS{
|
||||
Type: "layers",
|
||||
DiffIDs: layers,
|
||||
}
|
||||
|
||||
configJSON, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
digest, size := GetSHA256Digest(bytes.NewBuffer(configJSON))
|
||||
|
||||
layer := &LayerReader{
|
||||
Layer: Layer{
|
||||
MediaType: "application/vnd.docker.container.image.v1+json",
|
||||
Digest: digest,
|
||||
Size: size,
|
||||
},
|
||||
Reader: bytes.NewBuffer(configJSON),
|
||||
}
|
||||
return layer, nil
|
||||
}
|
||||
|
||||
// GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer
|
||||
func GetSHA256Digest(r io.Reader) (string, int64) {
|
||||
h := sha256.New()
|
||||
|
@@ -1,23 +1,98 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
)
|
||||
|
||||
func TestModelPrompt(t *testing.T) {
|
||||
var m Model
|
||||
req := api.GenerateRequest{
|
||||
Template: "a{{ .Prompt }}b",
|
||||
Prompt: "<h1>",
|
||||
func TestChat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
template string
|
||||
msgs []api.Message
|
||||
want string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "Single Message",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "You are a Wizard.",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What are the potion ingredients?",
|
||||
},
|
||||
},
|
||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "Message History",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "You are a Wizard.",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What are the potion ingredients?",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "sugar",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Anything else?",
|
||||
},
|
||||
},
|
||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "Assistant Only",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "everything nice",
|
||||
},
|
||||
},
|
||||
want: "[INST] [/INST]everything nice",
|
||||
},
|
||||
{
|
||||
name: "Invalid Role",
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "not-a-role",
|
||||
Content: "howdy",
|
||||
},
|
||||
},
|
||||
wantErr: "invalid role: not-a-role",
|
||||
},
|
||||
}
|
||||
s, err := m.Prompt(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := "a<h1>b"
|
||||
if s != want {
|
||||
t.Errorf("got %q, want %q", s, want)
|
||||
|
||||
for _, tt := range tests {
|
||||
m := Model{
|
||||
Template: tt.template,
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := m.ChatPrompt(tt.msgs)
|
||||
if tt.wantErr != "" {
|
||||
if err == nil {
|
||||
t.Errorf("ChatPrompt() expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
109
server/layers.go
Normal file
109
server/layers.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
type Layers struct {
|
||||
items []*Layer
|
||||
}
|
||||
|
||||
func (ls *Layers) Add(layer *Layer) {
|
||||
if layer.Size > 0 {
|
||||
ls.items = append(ls.items, layer)
|
||||
}
|
||||
}
|
||||
|
||||
func (ls *Layers) Replace(layer *Layer) {
|
||||
if layer.Size > 0 {
|
||||
mediatype := layer.MediaType
|
||||
layers := slices.DeleteFunc(ls.items, func(l *Layer) bool {
|
||||
return l.MediaType == mediatype
|
||||
})
|
||||
|
||||
ls.items = append(layers, layer)
|
||||
}
|
||||
}
|
||||
|
||||
type Layer struct {
|
||||
MediaType string `json:"mediaType"`
|
||||
Digest string `json:"digest"`
|
||||
Size int64 `json:"size"`
|
||||
From string `json:"from,omitempty"`
|
||||
|
||||
tempFileName string
|
||||
}
|
||||
|
||||
func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
|
||||
blobs, err := GetBlobsPath("")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
delimiter := ":"
|
||||
if runtime.GOOS == "windows" {
|
||||
delimiter = "-"
|
||||
}
|
||||
|
||||
pattern := strings.Join([]string{"sha256", "*-partial"}, delimiter)
|
||||
temp, err := os.CreateTemp(blobs, pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer temp.Close()
|
||||
|
||||
sha256sum := sha256.New()
|
||||
n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Layer{
|
||||
MediaType: mediatype,
|
||||
Digest: fmt.Sprintf("sha256:%x", sha256sum.Sum(nil)),
|
||||
Size: n,
|
||||
tempFileName: temp.Name(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
|
||||
blob, err := GetBlobsPath(digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fi, err := os.Stat(blob)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Layer{
|
||||
MediaType: mediatype,
|
||||
Digest: digest,
|
||||
Size: fi.Size(),
|
||||
From: from,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (l *Layer) Commit() (bool, error) {
|
||||
// always remove temp
|
||||
defer os.Remove(l.tempFileName)
|
||||
|
||||
blob, err := GetBlobsPath(l.Digest)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if _, err := os.Stat(blob); err != nil {
|
||||
return true, os.Rename(l.tempFileName, blob)
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
34
server/manifests.go
Normal file
34
server/manifests.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func WriteManifest(name string, config *Layer, layers []*Layer) error {
|
||||
manifest := ManifestV2{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||
Config: config,
|
||||
Layers: layers,
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(manifest); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
modelpath := ParseModelPath(name)
|
||||
manifestPath, err := modelpath.GetManifestPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(manifestPath), 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(manifestPath, b.Bytes(), 0644)
|
||||
}
|
351
server/routes.go
351
server/routes.go
@@ -2,7 +2,6 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -60,17 +59,26 @@ var loaded struct {
|
||||
var defaultSessionDuration = 5 * time.Minute
|
||||
|
||||
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
|
||||
func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
|
||||
func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error) {
|
||||
model, err := GetModel(modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
workDir := c.GetString("workDir")
|
||||
|
||||
opts := api.DefaultOptions()
|
||||
if err := opts.FromMap(model.Options); err != nil {
|
||||
log.Printf("could not load model options: %v", err)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := opts.FromMap(reqOpts); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// check if the loaded model is still running in a subprocess, in case something unexpected happened
|
||||
if loaded.runner != nil {
|
||||
if err := loaded.runner.Ping(ctx); err != nil {
|
||||
@@ -97,7 +105,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
|
||||
loaded.Options = nil
|
||||
}
|
||||
|
||||
llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts)
|
||||
llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
|
||||
if err != nil {
|
||||
// some older models are not compatible with newer versions of llama.cpp
|
||||
// show a generalized compatibility error until there is a better way to
|
||||
@@ -106,7 +114,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
|
||||
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
|
||||
}
|
||||
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loaded.Model = model
|
||||
@@ -140,7 +148,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
|
||||
}
|
||||
|
||||
loaded.expireTimer.Reset(sessionDuration)
|
||||
return nil
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func GenerateHandler(c *gin.Context) {
|
||||
@@ -173,88 +181,148 @@ func GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
model, err := GetModel(req.Model)
|
||||
sessionDuration := defaultSessionDuration
|
||||
model, err := load(c, req.Model, req.Options, sessionDuration)
|
||||
if err != nil {
|
||||
var pErr *fs.PathError
|
||||
if errors.As(err, &pErr) {
|
||||
switch {
|
||||
case errors.As(err, &pErr):
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
||||
return
|
||||
case errors.Is(err, api.ErrInvalidOpts):
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
workDir := c.GetString("workDir")
|
||||
|
||||
// TODO: set this duration from the request if specified
|
||||
sessionDuration := defaultSessionDuration
|
||||
if err := load(c.Request.Context(), workDir, model, req.Options, sessionDuration); err != nil {
|
||||
if errors.Is(err, api.ErrInvalidOpts) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
// an empty request loads the model
|
||||
if req.Prompt == "" && req.Template == "" && req.System == "" {
|
||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Model: req.Model,
|
||||
Done: true})
|
||||
return
|
||||
}
|
||||
|
||||
checkpointLoaded := time.Now()
|
||||
|
||||
prompt := req.Prompt
|
||||
if !req.Raw {
|
||||
prompt, err = model.Prompt(req)
|
||||
var prompt string
|
||||
switch {
|
||||
case req.Raw:
|
||||
prompt = req.Prompt
|
||||
case req.Prompt != "":
|
||||
if req.Template != "" {
|
||||
// override the default model template
|
||||
model.Template = req.Template
|
||||
}
|
||||
|
||||
var rebuild strings.Builder
|
||||
if req.Context != nil {
|
||||
// TODO: context is deprecated, at some point the context logic within this conditional should be removed
|
||||
prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Remove leading spaces from prevCtx if present
|
||||
prevCtx = strings.TrimPrefix(prevCtx, " ")
|
||||
rebuild.WriteString(prevCtx)
|
||||
}
|
||||
p, err := model.Prompt(PromptVars{
|
||||
System: req.System,
|
||||
Prompt: req.Prompt,
|
||||
First: len(req.Context) == 0,
|
||||
})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
rebuild.WriteString(p)
|
||||
prompt = rebuild.String()
|
||||
}
|
||||
|
||||
ch := make(chan any)
|
||||
var generated strings.Builder
|
||||
go func() {
|
||||
defer close(ch)
|
||||
// an empty request loads the model
|
||||
if req.Prompt == "" && req.Template == "" && req.System == "" {
|
||||
ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}
|
||||
return
|
||||
}
|
||||
|
||||
fn := func(r api.GenerateResponse) {
|
||||
fn := func(r llm.PredictResult) {
|
||||
// Update model expiration
|
||||
loaded.expireAt = time.Now().Add(sessionDuration)
|
||||
loaded.expireTimer.Reset(sessionDuration)
|
||||
|
||||
r.Model = req.Model
|
||||
r.CreatedAt = time.Now().UTC()
|
||||
if r.Done {
|
||||
r.TotalDuration = time.Since(checkpointStart)
|
||||
r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
// Build up the full response
|
||||
if _, err := generated.WriteString(r.Content); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
if req.Raw {
|
||||
// in raw mode the client must manage history on their own
|
||||
r.Context = nil
|
||||
resp := api.GenerateResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: r.CreatedAt,
|
||||
Done: r.Done,
|
||||
Response: r.Content,
|
||||
Metrics: api.Metrics{
|
||||
TotalDuration: r.TotalDuration,
|
||||
LoadDuration: r.LoadDuration,
|
||||
PromptEvalCount: r.PromptEvalCount,
|
||||
PromptEvalDuration: r.PromptEvalDuration,
|
||||
EvalCount: r.EvalCount,
|
||||
EvalDuration: r.EvalDuration,
|
||||
},
|
||||
}
|
||||
|
||||
ch <- r
|
||||
if r.Done && !req.Raw {
|
||||
embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
resp.Context = embd
|
||||
}
|
||||
|
||||
ch <- resp
|
||||
}
|
||||
|
||||
if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, req.Format, fn); err != nil {
|
||||
// Start prediction
|
||||
predictReq := llm.PredictOpts{
|
||||
Prompt: prompt,
|
||||
Format: req.Format,
|
||||
CheckpointStart: checkpointStart,
|
||||
CheckpointLoaded: checkpointLoaded,
|
||||
}
|
||||
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
|
||||
if req.Stream != nil && !*req.Stream {
|
||||
var response api.GenerateResponse
|
||||
generated := ""
|
||||
// Accumulate responses into the final response
|
||||
var final api.GenerateResponse
|
||||
var sb strings.Builder
|
||||
for resp := range ch {
|
||||
if r, ok := resp.(api.GenerateResponse); ok {
|
||||
generated += r.Response
|
||||
response = r
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
switch r := resp.(type) {
|
||||
case api.GenerateResponse:
|
||||
sb.WriteString(r.Response)
|
||||
final = r
|
||||
case gin.H:
|
||||
if errorMsg, ok := r["error"].(string); ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
|
||||
return
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
|
||||
return
|
||||
}
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
|
||||
return
|
||||
}
|
||||
}
|
||||
response.Response = generated
|
||||
c.JSON(http.StatusOK, response)
|
||||
|
||||
final.Response = sb.String()
|
||||
c.JSON(http.StatusOK, final)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -281,15 +349,18 @@ func EmbeddingHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
model, err := GetModel(req.Model)
|
||||
sessionDuration := defaultSessionDuration
|
||||
_, err = load(c, req.Model, req.Options, sessionDuration)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
workDir := c.GetString("workDir")
|
||||
if err := load(c.Request.Context(), workDir, model, req.Options, 5*time.Minute); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
var pErr *fs.PathError
|
||||
switch {
|
||||
case errors.As(err, &pErr):
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
||||
case errors.Is(err, api.ErrInvalidOpts):
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -676,37 +747,18 @@ func HeadBlobHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
func CreateBlobHandler(c *gin.Context) {
|
||||
targetPath, err := GetBlobsPath(c.Param("digest"))
|
||||
layer, err := NewLayer(c.Request.Body, "")
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
hash := sha256.New()
|
||||
temp, err := os.CreateTemp(filepath.Dir(targetPath), c.Param("digest")+"-")
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer temp.Close()
|
||||
defer os.Remove(temp.Name())
|
||||
|
||||
if _, err := io.Copy(temp, io.TeeReader(c.Request.Body, hash)); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
if layer.Digest != c.Param("digest") {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("digest mismatch, expected %q, got %q", c.Param("digest"), layer.Digest)})
|
||||
return
|
||||
}
|
||||
|
||||
if fmt.Sprintf("sha256:%x", hash.Sum(nil)) != c.Param("digest") {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "digest does not match body"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := temp.Close(); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.Rename(temp.Name(), targetPath); err != nil {
|
||||
if _, err := layer.Commit(); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -767,6 +819,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
|
||||
|
||||
r.POST("/api/pull", PullModelHandler)
|
||||
r.POST("/api/generate", GenerateHandler)
|
||||
r.POST("/api/chat", ChatHandler)
|
||||
r.POST("/api/embeddings", EmbeddingHandler)
|
||||
r.POST("/api/create", CreateModelHandler)
|
||||
r.POST("/api/push", PushModelHandler)
|
||||
@@ -782,6 +835,9 @@ func Serve(ln net.Listener, allowOrigins []string) error {
|
||||
})
|
||||
|
||||
r.Handle(method, "/api/tags", ListModelsHandler)
|
||||
r.Handle(method, "/api/version", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"version": version.Version})
|
||||
})
|
||||
}
|
||||
|
||||
log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
|
||||
@@ -804,7 +860,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
|
||||
if runtime.GOOS == "linux" {
|
||||
// check compatibility to log warnings
|
||||
if _, err := llm.CheckVRAM(); err != nil {
|
||||
log.Printf(err.Error())
|
||||
log.Print(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -860,3 +916,136 @@ func streamResponse(c *gin.Context, ch chan any) {
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func ChatHandler(c *gin.Context) {
|
||||
loaded.mu.Lock()
|
||||
defer loaded.mu.Unlock()
|
||||
|
||||
checkpointStart := time.Now()
|
||||
|
||||
var req api.ChatRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
case err != nil:
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// validate the request
|
||||
switch {
|
||||
case req.Model == "":
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
case len(req.Format) > 0 && req.Format != "json":
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
|
||||
return
|
||||
}
|
||||
|
||||
sessionDuration := defaultSessionDuration
|
||||
model, err := load(c, req.Model, req.Options, sessionDuration)
|
||||
if err != nil {
|
||||
var pErr *fs.PathError
|
||||
switch {
|
||||
case errors.As(err, &pErr):
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
||||
case errors.Is(err, api.ErrInvalidOpts):
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// an empty request loads the model
|
||||
if len(req.Messages) == 0 {
|
||||
c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true})
|
||||
return
|
||||
}
|
||||
|
||||
checkpointLoaded := time.Now()
|
||||
|
||||
prompt, err := model.ChatPrompt(req.Messages)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ch := make(chan any)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
fn := func(r llm.PredictResult) {
|
||||
// Update model expiration
|
||||
loaded.expireAt = time.Now().Add(sessionDuration)
|
||||
loaded.expireTimer.Reset(sessionDuration)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: r.CreatedAt,
|
||||
Done: r.Done,
|
||||
Metrics: api.Metrics{
|
||||
TotalDuration: r.TotalDuration,
|
||||
LoadDuration: r.LoadDuration,
|
||||
PromptEvalCount: r.PromptEvalCount,
|
||||
PromptEvalDuration: r.PromptEvalDuration,
|
||||
EvalCount: r.EvalCount,
|
||||
EvalDuration: r.EvalDuration,
|
||||
},
|
||||
}
|
||||
|
||||
if !r.Done {
|
||||
resp.Message = &api.Message{Role: "assistant", Content: r.Content}
|
||||
}
|
||||
|
||||
ch <- resp
|
||||
}
|
||||
|
||||
// Start prediction
|
||||
predictReq := llm.PredictOpts{
|
||||
Prompt: prompt,
|
||||
Format: req.Format,
|
||||
CheckpointStart: checkpointStart,
|
||||
CheckpointLoaded: checkpointLoaded,
|
||||
}
|
||||
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
|
||||
if req.Stream != nil && !*req.Stream {
|
||||
// Accumulate responses into the final response
|
||||
var final api.ChatResponse
|
||||
var sb strings.Builder
|
||||
for resp := range ch {
|
||||
switch r := resp.(type) {
|
||||
case api.ChatResponse:
|
||||
if r.Message != nil {
|
||||
sb.WriteString(r.Message.Content)
|
||||
}
|
||||
|
||||
final = r
|
||||
case gin.H:
|
||||
if errorMsg, ok := r["error"].(string); ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
|
||||
return
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
|
||||
return
|
||||
}
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
final.Message = &api.Message{Role: "assistant", Content: sb.String()}
|
||||
c.JSON(http.StatusOK, final)
|
||||
return
|
||||
}
|
||||
|
||||
streamResponse(c, ch)
|
||||
}
|
||||
|
Reference in New Issue
Block a user