Compare commits

...

42 Commits

Author SHA1 Message Date
Quinn Slack
f4432e1dba treat stop as stop sequences, not exact tokens (#442)
The `stop` option to the generate API is a list of sequences that should cause generation to stop. Although these are commonly called "stop tokens", they do not necessarily correspond to LLM tokens (per the LLM's tokenizer). For example, if the caller sends a generate request with `"stop":["\n"]`, then generation should stop on any token containing `\n` (and trim `\n` from the output), not just if the token exactly matches `\n`. If `stop` were interpreted strictly as LLM tokens, then it would require callers of the generate API to know the LLM's tokenizer and enumerate many tokens in the `stop` list.

Fixes https://github.com/jmorganca/ollama/issues/295.
2023-08-30 11:53:42 -04:00
Michael Yang
982c535428 Merge pull request #428 from jmorganca/mxyng/upload-chunks
update upload chunks
2023-08-30 07:47:17 -07:00
Michael Yang
7df342a6ea Merge pull request #421 from jmorganca/mxyng/f16-metal
allow F16 to use metal
2023-08-29 06:32:59 -07:00
Patrick Devine
8bbff2df98 add model IDs (#439) 2023-08-28 20:50:24 -07:00
Michael Yang
16b06699fd remove unused parameter 2023-08-28 18:35:18 -04:00
Michael Yang
246dc65417 loosen http status code checks 2023-08-28 18:34:53 -04:00
Michael Yang
865fceb73c chunked pipe 2023-08-28 18:34:53 -04:00
Michael Yang
72266c7684 bump chunk size to 95MB 2023-08-28 18:34:53 -04:00
Jeffrey Morgan
d3b838ce60 update orca to orca-mini 2023-08-27 13:26:30 -04:00
Michael Yang
e639a12fa1 Merge pull request #412 from jmorganca/mxyng/update-readme
update README.md
2023-08-26 21:26:34 -07:00
Michael Yang
e82fcf30c6 Merge pull request #420 from jmorganca/mxyng/34b-mem-check
add 34b to mem check
2023-08-26 14:15:52 -07:00
Michael Yang
495e8b0a6a Merge pull request #426 from jmorganca/default-template
set default template
2023-08-26 14:15:38 -07:00
Michael Yang
59734ca24d set default template 2023-08-26 12:20:48 -07:00
Jeffrey Morgan
22ab7f5f88 default host to 127.0.0.1, fixes #424 2023-08-26 11:59:28 -07:00
Michael Yang
b25dd1795d allow F16 to use metal
warning F16 uses significantly more memory than quantized model so the
standard requires don't apply.
2023-08-26 08:38:48 -07:00
Michael Yang
304f2b6c96 add 34b to mem check 2023-08-26 08:29:21 -07:00
Quinn Slack
2ecc3a33c3 delete all models (not just 1st) in ollama rm (#415)
Previously, `ollama rm model1 model2 modelN` would only delete `model1`. The other model command-line arguments would be silently ignored. Now, all models mentioned are deleted.
2023-08-26 00:47:56 -07:00
Jeffrey Morgan
ee6e1df118 add codellama to model list in readme 2023-08-25 20:44:26 -07:00
Jeffrey Morgan
177b69a211 add missing entries for 34B 2023-08-25 18:35:35 -07:00
Michael Yang
dad63f0821 Merge pull request #411 from jmorganca/mxyng/34b
patch llama.cpp for 34B
2023-08-25 11:59:05 -07:00
Michael Yang
041f9ad1a1 update README.md 2023-08-25 11:44:25 -07:00
Michael Yang
7a378f8b66 patch llama.cpp for 34B 2023-08-25 10:06:55 -07:00
Michael Yang
de0bdd7f29 Merge pull request #405 from jmorganca/mxyng/34b
add 34b model type
2023-08-24 10:37:22 -07:00
Michael Yang
b1cececb8e add 34b model type 2023-08-24 10:35:44 -07:00
Michael Yang
e0d39fa3bf Merge pull request #398 from jmorganca/mxyng/cleanup
Mxyng/cleanup
2023-08-22 15:51:41 -07:00
Michael Yang
968ced2e71 Merge pull request #393 from jmorganca/mxyng/net-url
use url.URL
2023-08-22 15:51:33 -07:00
Michael Yang
32d1a00017 remove unused requestContextKey 2023-08-22 10:49:54 -07:00
Michael Yang
04e2128273 move upload funcs to upload.go 2023-08-22 10:49:53 -07:00
Michael Yang
2cc634689b use url.URL 2023-08-22 10:49:07 -07:00
Michael Yang
8f827641b0 Merge pull request #397 from jmorganca/mxyng/release-mode
build release mode
2023-08-22 10:48:44 -07:00
Michael Yang
95187d7e1e build release mode 2023-08-22 09:52:43 -07:00
Michael Yang
9ec7e37534 Merge pull request #392 from jmorganca/mxyng/version
add version
2023-08-22 09:50:25 -07:00
Michael Yang
2c7f956b38 add version 2023-08-22 09:40:58 -07:00
Jeffrey Morgan
a9f6c56652 fix FROM instruction erroring when referring to a file 2023-08-22 09:39:42 -07:00
Ryan Baker
0a892419ad Strip protocol from model path (#377) 2023-08-21 21:56:56 -07:00
Jeffrey Morgan
e3054fc74e add .env to .dockerignore 2023-08-21 09:32:02 -07:00
Michael Yang
23c2485044 Merge pull request #381 from jmorganca/mxyng/fix-push-chunks
retry on unauthorized chunk push
2023-08-18 13:49:25 -07:00
Michael Yang
386c66f285 Merge pull request #378 from jmorganca/mxyng/copy-metadata-from-source
copy metadata from source
2023-08-18 13:49:09 -07:00
Michael Yang
3b49315f97 retry on unauthorized chunk push
The token printed for authorized requests has a lifetime of 1h. If an
upload exceeds 1h, a chunk push will fail since the token is created on
a "start upload" request.

This replaces the Pipe with SectionReader which is simpler and
implements Seek, a requirement for makeRequestWithRetry. This is
slightly worse than using a Pipe since the progress update is directly
tied to the chunk size instead of controlled separately.
2023-08-18 11:23:47 -07:00
Michael Yang
5ca05c2e88 fix ModelType() 2023-08-18 11:23:38 -07:00
Michael Yang
7eda70f23b copy metadata from source 2023-08-17 21:55:25 -07:00
Jeffrey Morgan
3d79b414d3 app: package ggml-metal.metal from correct directory 2023-08-17 23:55:45 -04:00
23 changed files with 665 additions and 317 deletions

View File

@@ -4,4 +4,5 @@ llama/build
.vscode
ollama
app
web
web
.env

View File

@@ -29,9 +29,9 @@ ollama run llama2
## Model library
Ollama supports a list of open-source models available on [ollama.ai/library](https://ollama.ai/library "ollama model library")
Ollama supports a list of open-source models available on [ollama.ai/library](https://ollama.ai/library 'ollama model library')
Here are some example open-source models that can be downloaded:
Here are some example open-source models that can be downloaded:
| Model | Parameters | Size | Download |
| ------------------------ | ---------- | ----- | ------------------------------- |
@@ -39,6 +39,7 @@ Here are some example open-source models that can be downloaded:
| Llama2 13B | 13B | 7.3GB | `ollama pull llama2:13b` |
| Llama2 70B | 70B | 39GB | `ollama pull llama2:70b` |
| Llama2 Uncensored | 7B | 3.8GB | `ollama pull llama2-uncensored` |
| Code Llama | 7B | 3.8GB | `ollama pull codellama` |
| Orca Mini | 3B | 1.9GB | `ollama pull orca-mini` |
| Vicuna | 7B | 3.8GB | `ollama pull vicuna` |
| Nous-Hermes | 7B | 3.8GB | `ollama pull nous-hermes` |
@@ -104,7 +105,7 @@ For more examples, see the [examples](./examples) directory. For more informatio
### Pull a model from the registry
```
ollama pull orca
ollama pull orca-mini
```
### Listing local models
@@ -126,6 +127,8 @@ Ollama bundles model weights, configuration, and data into a single package, def
## Building
You will also need a C/C++ compiler such as GCC for MacOS and Linux or Mingw-w64 GCC for Windows.
```
go build .
```

View File

@@ -10,10 +10,13 @@ import (
"net/http"
"net/url"
"os"
"runtime"
"strings"
"github.com/jmorganca/ollama/version"
)
const DefaultHost = "localhost:11434"
const DefaultHost = "127.0.0.1:11434"
var (
envHost = os.Getenv("OLLAMA_HOST")
@@ -26,7 +29,7 @@ type Client struct {
}
func checkError(resp *http.Response, body []byte) error {
if resp.StatusCode >= 200 && resp.StatusCode < 400 {
if resp.StatusCode < http.StatusBadRequest {
return nil
}
@@ -83,21 +86,21 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
reqBody = bytes.NewReader(data)
}
url := c.Base.JoinPath(path).String()
req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
requestURL := c.Base.JoinPath(path)
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Accept", "application/json")
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
for k, v := range c.Headers {
req.Header[k] = v
request.Header[k] = v
}
respObj, err := c.HTTP.Do(req)
respObj, err := c.HTTP.Do(request)
if err != nil {
return err
}
@@ -131,13 +134,15 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
buf = bytes.NewBuffer(bts)
}
request, err := http.NewRequestWithContext(ctx, method, c.Base.JoinPath(path).String(), buf)
requestURL := c.Base.JoinPath(path)
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
if err != nil {
return err
}
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Accept", "application/json")
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
response, err := http.DefaultClient.Do(request)
if err != nil {
@@ -160,7 +165,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
return fmt.Errorf(errorResponse.Error)
}
if response.StatusCode >= 400 {
if response.StatusCode >= http.StatusBadRequest {
return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,

View File

@@ -96,6 +96,7 @@ type ListResponseModel struct {
Name string `json:"name"`
ModifiedAt time.Time `json:"modified_at"`
Size int `json:"size"`
Digest string `json:"digest"`
}
type TokenResponse struct {

View File

@@ -27,7 +27,7 @@ const config: ForgeConfig = {
path.join(__dirname, './assets/iconDarkTemplate@2x.png'),
path.join(__dirname, './assets/iconDarkUpdateTemplate.png'),
path.join(__dirname, './assets/iconDarkUpdateTemplate@2x.png'),
...(process.platform === 'darwin' ? ['../llama/ggml-metal.metal'] : []),
...(process.platform === 'darwin' ? ['../llm/ggml-metal.metal'] : []),
],
...(process.env.SIGN
? {

View File

@@ -30,6 +30,7 @@ import (
"github.com/jmorganca/ollama/format"
"github.com/jmorganca/ollama/progressbar"
"github.com/jmorganca/ollama/server"
"github.com/jmorganca/ollama/version"
)
func CreateHandler(cmd *cobra.Command, args []string) error {
@@ -97,7 +98,20 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
func RunHandler(cmd *cobra.Command, args []string) error {
insecure, err := cmd.Flags().GetBool("insecure")
if err != nil {
return err
}
mp := server.ParseModelPath(args[0])
if err != nil {
return err
}
if mp.ProtocolScheme == "http" && !insecure {
return fmt.Errorf("insecure protocol http")
}
fp, err := mp.GetManifestPath(false)
if err != nil {
return err
@@ -106,7 +120,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
_, err = os.Stat(fp)
switch {
case errors.Is(err, os.ErrNotExist):
if err := pull(args[0], false); err != nil {
if err := pull(args[0], insecure); err != nil {
var apiStatusError api.StatusError
if !errors.As(err, &apiStatusError) {
return err
@@ -182,12 +196,12 @@ func ListHandler(cmd *cobra.Command, args []string) error {
for _, m := range models.Models {
if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) {
data = append(data, []string{m.Name, humanize.Bytes(uint64(m.Size)), format.HumanTime(m.ModifiedAt, "Never")})
data = append(data, []string{m.Name, m.Digest[:12], humanize.Bytes(uint64(m.Size)), format.HumanTime(m.ModifiedAt, "Never")})
}
}
table := tablewriter.NewWriter(os.Stdout)
table.SetHeader([]string{"NAME", "SIZE", "MODIFIED"})
table.SetHeader([]string{"NAME", "ID", "SIZE", "MODIFIED"})
table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
table.SetAlignment(tablewriter.ALIGN_LEFT)
table.SetHeaderLine(false)
@@ -206,11 +220,13 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
return err
}
req := api.DeleteRequest{Name: args[0]}
if err := client.Delete(context.Background(), &req); err != nil {
return err
for _, name := range args {
req := api.DeleteRequest{Name: name}
if err := client.Delete(context.Background(), &req); err != nil {
return err
}
fmt.Printf("deleted '%s'\n", name)
}
fmt.Printf("deleted '%s'\n", args[0])
return nil
}
@@ -507,7 +523,11 @@ func generateInteractive(cmd *cobra.Command, model string) error {
args := strings.Fields(line)
if len(args) > 1 {
mp := server.ParseModelPath(model)
manifest, err := server.GetManifest(mp)
if err != nil {
return err
}
manifest, _, err := server.GetManifest(mp)
if err != nil {
fmt.Println("error: couldn't get a manifest for this model")
continue
@@ -569,7 +589,7 @@ func generateBatch(cmd *cobra.Command, model string) error {
}
func RunServer(cmd *cobra.Command, _ []string) error {
var host, port = "127.0.0.1", "11434"
host, port := "127.0.0.1", "11434"
parts := strings.Split(os.Getenv("OLLAMA_HOST"), ":")
if ip := net.ParseIP(parts[0]); ip != nil {
@@ -630,7 +650,7 @@ func initializeKeypair() error {
return fmt.Errorf("could not create directory %w", err)
}
err = os.WriteFile(privKeyPath, pem.EncodeToMemory(privKeyBytes), 0600)
err = os.WriteFile(privKeyPath, pem.EncodeToMemory(privKeyBytes), 0o600)
if err != nil {
return err
}
@@ -642,7 +662,7 @@ func initializeKeypair() error {
pubKeyData := ssh.MarshalAuthorizedKey(sshPrivateKey.PublicKey())
err = os.WriteFile(pubKeyPath, pubKeyData, 0644)
err = os.WriteFile(pubKeyPath, pubKeyData, 0o644)
if err != nil {
return err
}
@@ -714,6 +734,7 @@ func NewCLI() *cobra.Command {
CompletionOptions: cobra.CompletionOptions{
DisableDefaultCmd: true,
},
Version: version.Version,
}
cobra.EnableCommandSorting = false
@@ -737,6 +758,7 @@ func NewCLI() *cobra.Command {
}
runCmd.Flags().Bool("verbose", false, "Show timings for response")
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
serveCmd := &cobra.Command{
Use: "serve",

View File

@@ -14,7 +14,7 @@
### Model names
Model names follow a `model:tag` format. Some examples are `orca:3b-q4_1` and `llama2:70b`. The tag is optional and if not provided will default to `latest`. The tag is used to identify a specific version.
Model names follow a `model:tag` format. Some examples are `orca-mini:3b-q4_1` and `llama2:70b`. The tag is optional and if not provided will default to `latest`. The tag is used to identify a specific version.
### Durations

View File

@@ -25,20 +25,3 @@ Now you can run `ollama`:
```
./ollama
```
## Releasing
To release a new version of Ollama you'll need to set some environment variables:
- `GITHUB_TOKEN`: your GitHub token
- `APPLE_IDENTITY`: the Apple signing identity (macOS only)
- `APPLE_ID`: your Apple ID
- `APPLE_PASSWORD`: your Apple ID app-specific password
- `APPLE_TEAM_ID`: the Apple team ID for the signing identity
- `TELEMETRY_WRITE_KEY`: segment write key for telemetry
Then run the publish script with the target version:
```
VERSION=0.0.2 ./scripts/publish.sh
```

View File

@@ -123,7 +123,7 @@ PARAMETER <parameter> <parametervalue>
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
| stop | Sets the stop tokens to use. | string | stop "AI assistant:" |
| stop | Sets the stop sequences to use. | string | stop "AI assistant:" |
| tfs_z | Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1) | float | tfs_z 1 |
| top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |
| top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 |

View File

@@ -15,6 +15,7 @@ const (
ModelType3B ModelType = 26
ModelType7B ModelType = 32
ModelType13B ModelType = 40
ModelType34B ModelType = 48
ModelType30B ModelType = 60
ModelType65B ModelType = 80
)
@@ -27,6 +28,8 @@ func (mt ModelType) String() string {
return "7B"
case ModelType13B:
return "13B"
case ModelType34B:
return "34B"
case ModelType30B:
return "30B"
case ModelType65B:

View File

@@ -105,6 +105,7 @@ enum e_model {
MODEL_7B,
MODEL_13B,
MODEL_30B,
MODEL_34B,
MODEL_65B,
MODEL_70B,
};
@@ -148,6 +149,7 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0(int n_ctx)
{ MODEL_7B, ((size_t) n_ctx / 16ull + 100ull) * MB },
{ MODEL_13B, ((size_t) n_ctx / 12ull + 120ull) * MB },
{ MODEL_30B, ((size_t) n_ctx / 9ull + 160ull) * MB },
{ MODEL_34B, ((size_t) n_ctx / 9ull + 160ull) * MB },
{ MODEL_65B, ((size_t) n_ctx / 6ull + 256ull) * MB }, // guess
{ MODEL_70B, ((size_t) n_ctx / 7ull + 164ull) * MB },
};
@@ -161,6 +163,7 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1()
{ MODEL_7B, 160ull * MB },
{ MODEL_13B, 192ull * MB },
{ MODEL_30B, 256ull * MB },
{ MODEL_34B, 256ull * MB },
{ MODEL_65B, 384ull * MB }, // guess
{ MODEL_70B, 304ull * MB },
};
@@ -175,6 +178,7 @@ static const std::map<e_model, size_t> & MEM_REQ_EVAL()
{ MODEL_7B, 10ull * MB },
{ MODEL_13B, 12ull * MB },
{ MODEL_30B, 16ull * MB },
{ MODEL_34B, 16ull * MB },
{ MODEL_65B, 24ull * MB }, // guess
{ MODEL_70B, 24ull * MB },
};
@@ -190,6 +194,7 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_BASE()
{ MODEL_7B, 512ull * kB },
{ MODEL_13B, 640ull * kB },
{ MODEL_30B, 768ull * kB },
{ MODEL_34B, 768ull * kB },
{ MODEL_65B, 1280ull * kB },
{ MODEL_70B, 1280ull * kB },
};
@@ -205,6 +210,7 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
{ MODEL_7B, 128ull },
{ MODEL_13B, 160ull },
{ MODEL_30B, 208ull },
{ MODEL_34B, 208ull },
{ MODEL_65B, 256ull },
{ MODEL_70B, 256ull },
};
@@ -1053,6 +1059,7 @@ static const char *llama_model_type_name(e_model type) {
case MODEL_7B: return "7B";
case MODEL_13B: return "13B";
case MODEL_30B: return "30B";
case MODEL_34B: return "34B";
case MODEL_65B: return "65B";
case MODEL_70B: return "70B";
default: LLAMA_ASSERT(false);
@@ -1100,6 +1107,7 @@ static void llama_model_load_internal(
case 26: model.type = e_model::MODEL_3B; break;
case 32: model.type = e_model::MODEL_7B; break;
case 40: model.type = e_model::MODEL_13B; break;
case 48: model.type = e_model::MODEL_34B; break;
case 60: model.type = e_model::MODEL_30B; break;
case 80: model.type = e_model::MODEL_65B; break;
default:
@@ -1120,6 +1128,8 @@ static void llama_model_load_internal(
LLAMA_LOG_WARN("%s: warning: assuming 70B model based on GQA == %d\n", __func__, n_gqa);
model.type = e_model::MODEL_70B;
hparams.f_ffn_mult = 1.3f; // from the params.json of the 70B model
} else if (model.type == e_model::MODEL_34B && n_gqa == 8) {
hparams.f_ffn_mult = 1.0f; // from the params.json of the 34B model
}
hparams.rope_freq_base = rope_freq_base;

View File

@@ -117,7 +117,21 @@ func (llm *llamaModel) ModelFamily() ModelFamily {
}
func (llm *llamaModel) ModelType() ModelType {
return ModelType30B
switch llm.hyperparameters.NumLayer {
case 26:
return ModelType3B
case 32:
return ModelType7B
case 40:
return ModelType13B
case 60:
return ModelType30B
case 80:
return ModelType65B
}
// TODO: find a better default
return ModelType7B
}
func (llm *llamaModel) FileType() FileType {
@@ -320,20 +334,18 @@ func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse
b.WriteString(llm.Decode(int(token)))
if err := llm.checkStopConditions(b); err != nil {
if errors.Is(err, io.EOF) {
break
} else if errors.Is(err, errNeedMoreData) {
continue
}
return err
stop, endsWithStopPrefix := handleStopSequences(&b, llm.Stop)
if endsWithStopPrefix {
continue
}
if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
fn(api.GenerateResponse{Response: b.String()})
b.Reset()
}
if stop {
break
}
}
embd := make([]int, len(llm.embd))
@@ -356,16 +368,31 @@ func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse
return nil
}
func (llm *llama) checkStopConditions(b bytes.Buffer) error {
for _, stopCondition := range llm.Stop {
if stopCondition == strings.TrimSpace(b.String()) {
return io.EOF
} else if strings.HasPrefix(stopCondition, strings.TrimSpace(b.String())) {
return errNeedMoreData
// handleStopSequences checks whether b contains any of the stop sequences, or ends with a prefix of
// any stop sequence (and therefore might contain data that should not ultimately be returned to the
// client).
//
// If b contains a stop sequence, it modifies b to remove the stop sequence and all subsequent data.
func handleStopSequences(b *bytes.Buffer, stopSequences []string) (stop bool, endsWithStopPrefix bool) {
s := b.String()
for _, seq := range stopSequences {
// Check for an exact or substring match.
if i := strings.Index(s, seq); i != -1 {
b.Truncate(i)
return true, false
}
// Check if b ends with a prefix of the stop sequence.
if len(seq) > 1 {
for i := 1; i < len(seq); i++ {
if strings.HasSuffix(s, seq[:i]) {
return false, true
}
}
}
}
return nil
return false, false
}
func (llm *llama) marshalPrompt(ctx []int, prompt string) []C.llama_token {

79
llm/llama_test.go Normal file
View File

@@ -0,0 +1,79 @@
package llm
import (
"bytes"
"testing"
)
func TestCheckStopConditions(t *testing.T) {
tests := map[string]struct {
b string
stop []string
wantB string
wantStop bool
wantEndsWithStopPrefix bool
}{
"not present": {
b: "abc",
stop: []string{"x"},
wantStop: false,
wantEndsWithStopPrefix: false,
},
"exact": {
b: "abc",
stop: []string{"abc"},
wantStop: true,
wantEndsWithStopPrefix: false,
},
"substring": {
b: "abc",
stop: []string{"b"},
wantB: "a",
wantStop: true,
wantEndsWithStopPrefix: false,
},
"prefix 1": {
b: "abc",
stop: []string{"abcd"},
wantStop: false,
wantEndsWithStopPrefix: true,
},
"prefix 2": {
b: "abc",
stop: []string{"bcd"},
wantStop: false,
wantEndsWithStopPrefix: true,
},
"prefix 3": {
b: "abc",
stop: []string{"cd"},
wantStop: false,
wantEndsWithStopPrefix: true,
},
"no prefix": {
b: "abc",
stop: []string{"bx"},
wantStop: false,
wantEndsWithStopPrefix: false,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
var b bytes.Buffer
b.WriteString(test.b)
stop, endsWithStopPrefix := handleStopSequences(&b, test.stop)
if test.wantB != "" {
gotB := b.String()
if gotB != test.wantB {
t.Errorf("got b %q, want %q", gotB, test.wantB)
}
}
if stop != test.wantStop {
t.Errorf("got stop %v, want %v", stop, test.wantStop)
}
if endsWithStopPrefix != test.wantEndsWithStopPrefix {
t.Errorf("got endsWithStopPrefix %v, want %v", endsWithStopPrefix, test.wantEndsWithStopPrefix)
}
})
}
}

View File

@@ -36,11 +36,11 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
}
switch ggml.FileType().String() {
case "F32", "F16", "Q5_0", "Q5_1", "Q8_0":
case "F32", "Q5_0", "Q5_1", "Q8_0":
if opts.NumGPU != 0 {
// F32, F16, Q5_0, Q5_1, and Q8_0 do not support Metal API and will
// cause the runner to segmentation fault so disable GPU
log.Printf("WARNING: GPU disabled for F32, F16, Q5_0, Q5_1, and Q8_0")
log.Printf("WARNING: GPU disabled for F32, Q5_0, Q5_1, and Q8_0")
opts.NumGPU = 0
}
}
@@ -48,19 +48,27 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
totalResidentMemory := memory.TotalMemory()
switch ggml.ModelType() {
case ModelType3B, ModelType7B:
if totalResidentMemory < 8*1024*1024 {
if ggml.FileType().String() == "F16" && totalResidentMemory < 16*1024*1024 {
return nil, fmt.Errorf("F16 model requires at least 16GB of memory")
} else if totalResidentMemory < 8*1024*1024 {
return nil, fmt.Errorf("model requires at least 8GB of memory")
}
case ModelType13B:
if totalResidentMemory < 16*1024*1024 {
if ggml.FileType().String() == "F16" && totalResidentMemory < 32*1024*1024 {
return nil, fmt.Errorf("F16 model requires at least 32GB of memory")
} else if totalResidentMemory < 16*1024*1024 {
return nil, fmt.Errorf("model requires at least 16GB of memory")
}
case ModelType30B:
if totalResidentMemory < 32*1024*1024 {
case ModelType30B, ModelType34B:
if ggml.FileType().String() == "F16" && totalResidentMemory < 64*1024*1024 {
return nil, fmt.Errorf("F16 model requires at least 64GB of memory")
} else if totalResidentMemory < 32*1024*1024 {
return nil, fmt.Errorf("model requires at least 32GB of memory")
}
case ModelType65B:
if totalResidentMemory < 64*1024*1024 {
if ggml.FileType().String() == "F16" && totalResidentMemory < 128*1024*1024 {
return nil, fmt.Errorf("F16 model requires at least 128GB of memory")
} else if totalResidentMemory < 64*1024*1024 {
return nil, fmt.Errorf("model requires at least 64GB of memory")
}
}

View File

@@ -2,9 +2,12 @@
mkdir -p dist
GO_LDFLAGS="-X github.com/jmorganca/ollama/version.Version=$VERSION"
GO_LDFLAGS="$GO_LDFLAGS -X github.com/jmorganca/ollama/server.mode=release"
# build universal binary
CGO_ENABLED=1 GOARCH=arm64 go build -o dist/ollama-darwin-arm64
CGO_ENABLED=1 GOARCH=amd64 go build -o dist/ollama-darwin-amd64
CGO_ENABLED=1 GOARCH=arm64 go build -ldflags "$GO_LDFLAGS" -o dist/ollama-darwin-arm64
CGO_ENABLED=1 GOARCH=amd64 go build -ldflags "$GO_LDFLAGS" -o dist/ollama-darwin-amd64
lipo -create -output dist/ollama dist/ollama-darwin-arm64 dist/ollama-darwin-amd64
rm dist/ollama-darwin-amd64 dist/ollama-darwin-arm64
codesign --deep --force --options=runtime --sign "$APPLE_IDENTITY" --timestamp dist/ollama

View File

@@ -12,8 +12,10 @@ import (
"io"
"log"
"net/http"
"net/url"
"os"
"path"
"strconv"
"strings"
"time"
@@ -43,21 +45,34 @@ func generateNonce(length int) (string, error) {
return base64.RawURLEncoding.EncodeToString(nonce), nil
}
func (r AuthRedirect) URL() (string, error) {
func (r AuthRedirect) URL() (*url.URL, error) {
redirectURL, err := url.Parse(r.Realm)
if err != nil {
return nil, err
}
values := redirectURL.Query()
values.Add("service", r.Service)
for _, s := range strings.Split(r.Scope, " ") {
values.Add("scope", s)
}
values.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
nonce, err := generateNonce(16)
if err != nil {
return "", err
return nil, err
}
scopes := []string{}
for _, s := range strings.Split(r.Scope, " ") {
scopes = append(scopes, fmt.Sprintf("scope=%s", s))
}
scopeStr := strings.Join(scopes, "&")
return fmt.Sprintf("%s?service=%s&%s&ts=%d&nonce=%s", r.Realm, r.Service, scopeStr, time.Now().Unix(), nonce), nil
values.Add("nonce", nonce)
redirectURL.RawQuery = values.Encode()
return redirectURL, nil
}
func getAuthToken(ctx context.Context, redirData AuthRedirect, regOpts *RegistryOptions) (string, error) {
url, err := redirData.URL()
redirectURL, err := redirData.URL()
if err != nil {
return "", err
}
@@ -77,34 +92,24 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect, regOpts *Registry
s := SignatureData{
Method: "GET",
Path: url,
Path: redirectURL.String(),
Data: nil,
}
if !strings.HasPrefix(s.Path, "http") {
if regOpts.Insecure {
s.Path = "http://" + url
} else {
s.Path = "https://" + url
}
}
sig, err := s.Sign(rawKey)
if err != nil {
return "", err
}
headers := map[string]string{
"Authorization": sig,
}
resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
headers := make(http.Header)
headers.Set("Authorization", sig)
resp, err := makeRequest(ctx, "GET", redirectURL, headers, nil, regOpts)
if err != nil {
log.Printf("couldn't get token: %q", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if resp.StatusCode >= http.StatusBadRequest {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
}

View File

@@ -155,19 +155,20 @@ func doDownload(ctx context.Context, opts downloadOpts, f *FileDownload) error {
}
}
url := fmt.Sprintf("%s/v2/%s/blobs/%s", opts.mp.Registry, opts.mp.GetNamespaceRepository(), f.Digest)
headers := map[string]string{
"Range": fmt.Sprintf("bytes=%d-", size),
}
requestURL := opts.mp.BaseURL()
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", f.Digest)
resp, err := makeRequest(ctx, "GET", url, headers, nil, opts.regOpts)
headers := make(http.Header)
headers.Set("Range", fmt.Sprintf("bytes=%d-", size))
resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts.regOpts)
if err != nil {
log.Printf("couldn't download blob: %v", err)
return fmt.Errorf("%w: %w", errDownload, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
if resp.StatusCode >= http.StatusBadRequest {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("%w: on download registry responded with code %d: %v", errDownload, resp.StatusCode, string(body))
}

View File

@@ -5,6 +5,7 @@ import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
@@ -12,10 +13,12 @@ import (
"io"
"log"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
@@ -23,6 +26,7 @@ import (
"github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/vector"
"github.com/jmorganca/ollama/version"
)
const MaxRetries = 3
@@ -41,6 +45,7 @@ type Model struct {
Template string
System string
Digest string
ConfigDigest string
Options map[string]interface{}
Embeddings []vector.Embedding
}
@@ -128,41 +133,45 @@ func (m *ManifestV2) GetTotalSize() int {
return total
}
func GetManifest(mp ModelPath) (*ManifestV2, error) {
func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
fp, err := mp.GetManifestPath(false)
if err != nil {
return nil, err
return nil, "", err
}
if _, err = os.Stat(fp); err != nil {
return nil, err
return nil, "", err
}
var manifest *ManifestV2
bts, err := os.ReadFile(fp)
if err != nil {
return nil, fmt.Errorf("couldn't open file '%s'", fp)
return nil, "", fmt.Errorf("couldn't open file '%s'", fp)
}
shaSum := sha256.Sum256(bts)
shaStr := hex.EncodeToString(shaSum[:])
if err := json.Unmarshal(bts, &manifest); err != nil {
return nil, err
return nil, "", err
}
return manifest, nil
return manifest, shaStr, nil
}
func GetModel(name string) (*Model, error) {
mp := ParseModelPath(name)
manifest, err := GetManifest(mp)
manifest, digest, err := GetManifest(mp)
if err != nil {
return nil, err
}
model := &Model{
Name: mp.GetFullTagname(),
Digest: manifest.Config.Digest,
Name: mp.GetFullTagname(),
Digest: digest,
ConfigDigest: manifest.Config.Digest,
Template: "{{ .Prompt }}",
}
for _, layer := range manifest.Layers {
@@ -272,8 +281,9 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
case "model":
fn(api.ProgressResponse{Status: "looking for model"})
embed.model = c.Args
mp := ParseModelPath(c.Args)
mf, err := GetManifest(mp)
mf, _, err := GetManifest(mp)
if err != nil {
modelFile, err := filenameWithPath(path, c.Args)
if err != nil {
@@ -286,7 +296,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
return err
}
mf, err = GetManifest(ParseModelPath(c.Args))
mf, _, err = GetManifest(mp)
if err != nil {
return fmt.Errorf("failed to open file after pull: %v", err)
}
@@ -325,7 +335,27 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
}
if mf != nil {
log.Printf("manifest = %#v", mf)
sourceBlobPath, err := GetBlobsPath(mf.Config.Digest)
if err != nil {
return err
}
sourceBlob, err := os.Open(sourceBlobPath)
if err != nil {
return err
}
defer sourceBlob.Close()
var source ConfigV2
if err := json.NewDecoder(sourceBlob).Decode(&source); err != nil {
return err
}
// copie the model metadata
config.ModelFamily = source.ModelFamily
config.ModelType = source.ModelType
config.FileType = source.FileType
for _, l := range mf.Layers {
newLayer, err := GetLayerWithBufferFromLayer(l)
if err != nil {
@@ -400,7 +430,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
layer.MediaType = mediaType
layers = append(layers, layer)
default:
// runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop tokens)
// runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop sequences)
params[c.Name] = append(params[c.Name], c.Args)
}
}
@@ -655,7 +685,6 @@ func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force
func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
mp := ParseModelPath(name)
manifest := ManifestV2{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
@@ -786,11 +815,14 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
}
func CopyModel(src, dest string) error {
srcPath, err := ParseModelPath(src).GetManifestPath(false)
srcModelPath := ParseModelPath(src)
srcPath, err := srcModelPath.GetManifestPath(false)
if err != nil {
return err
}
destPath, err := ParseModelPath(dest).GetManifestPath(true)
destModelPath := ParseModelPath(dest)
destPath, err := destModelPath.GetManifestPath(true)
if err != nil {
return err
}
@@ -813,8 +845,7 @@ func CopyModel(src, dest string) error {
func DeleteModel(name string) error {
mp := ParseModelPath(name)
manifest, err := GetManifest(mp)
manifest, _, err := GetManifest(mp)
if err != nil {
return err
}
@@ -847,7 +878,7 @@ func DeleteModel(name string) error {
}
// save (i.e. delete from the deleteMap) any files used in other manifests
manifest, err := GetManifest(fmp)
manifest, _, err := GetManifest(fmp)
if err != nil {
log.Printf("skipping file: %s", fp)
return nil
@@ -893,10 +924,13 @@ func DeleteModel(name string) error {
func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
fn(api.ProgressResponse{Status: "retrieving manifest"})
manifest, err := GetManifest(mp)
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return fmt.Errorf("insecure protocol http")
}
manifest, _, err := GetManifest(mp)
if err != nil {
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
return err
@@ -935,8 +969,8 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
return err
}
if strings.HasPrefix(path.Base(location), "sha256:") {
layer.Digest = path.Base(location)
if strings.HasPrefix(path.Base(location.Path), "sha256:") {
layer.Digest = path.Base(location.Path)
fn(api.ProgressResponse{
Status: "using existing layer",
Digest: layer.Digest,
@@ -946,24 +980,24 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
continue
}
if err := uploadBlobChunked(ctx, mp, location, layer, regOpts, fn); err != nil {
if err := uploadBlobChunked(ctx, location, layer, regOpts, fn); err != nil {
log.Printf("error uploading blob: %v", err)
return err
}
}
fn(api.ProgressResponse{Status: "pushing manifest"})
url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
headers := map[string]string{
"Content-Type": "application/vnd.docker.distribution.manifest.v2+json",
}
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
resp, err := makeRequestWithRetry(ctx, "PUT", url, headers, bytes.NewReader(manifestJSON), regOpts)
headers := make(http.Header)
headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequestWithRetry(ctx, "PUT", requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
if err != nil {
return err
}
@@ -977,6 +1011,10 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return fmt.Errorf("insecure protocol http")
}
fn(api.ProgressResponse{Status: "pulling manifest"})
manifest, err := pullModelManifest(ctx, mp, regOpts)
@@ -1043,23 +1081,22 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
}
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
headers := map[string]string{
"Accept": "application/vnd.docker.distribution.manifest.v2+json",
}
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
headers := make(http.Header)
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, regOpts)
if err != nil {
log.Printf("couldn't get manifest: %v", err)
return nil, err
}
defer resp.Body.Close()
// Check for success: For a successful upload, the Docker registry will respond with a 201 Created
if resp.StatusCode != http.StatusOK {
if resp.StatusCode >= http.StatusBadRequest {
if resp.StatusCode == http.StatusNotFound {
return nil, fmt.Errorf("model not found")
}
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
}
@@ -1107,35 +1144,12 @@ func GetSHA256Digest(r io.Reader) (string, int) {
return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n)
}
type requestContextKey string
func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (string, error) {
url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository())
if layer.From != "" {
url = fmt.Sprintf("%s/v2/%s/blobs/uploads/?mount=%s&from=%s", mp.Registry, mp.GetNamespaceRepository(), layer.Digest, layer.From)
}
resp, err := makeRequestWithRetry(ctx, "POST", url, nil, nil, regOpts)
if err != nil {
log.Printf("couldn't start upload: %v", err)
return "", err
}
defer resp.Body.Close()
// Extract UUID location from header
location := resp.Header.Get("Location")
if location == "" {
return "", fmt.Errorf("location header is missing in response")
}
return location, nil
}
// Function to check if a blob already exists in the Docker registry
func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest)
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", digest)
resp, err := makeRequest(ctx, "HEAD", url, nil, nil, regOpts)
resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, regOpts)
if err != nil {
log.Printf("couldn't check for blob: %v", err)
return false, err
@@ -1143,113 +1157,13 @@ func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpt
defer resp.Body.Close()
// Check for success: If the blob exists, the Docker registry will respond with a 200 OK
return resp.StatusCode == http.StatusOK, nil
return resp.StatusCode < http.StatusBadRequest, nil
}
func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
// TODO allow resumability
// TODO allow canceling uploads via DELETE
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
f, err := os.Open(fp)
if err != nil {
return err
}
defer f.Close()
completed := 0
chunkSize := 10 * 1024 * 1024
for {
r, w := io.Pipe()
defer r.Close()
limit := completed + chunkSize
if chunkSize >= layer.Size-completed {
limit = layer.Size
chunkSize = layer.Size - completed
}
go func() {
defer w.Close()
for {
n, err := io.CopyN(w, f, 1024*1024)
if err != nil && !errors.Is(err, io.EOF) {
fn(api.ProgressResponse{
Status: fmt.Sprintf("error copying pipe: %v", err),
Digest: layer.Digest,
Total: layer.Size,
Completed: completed,
})
return
}
completed += int(n)
fn(api.ProgressResponse{
Status: fmt.Sprintf("uploading %s", layer.Digest),
Digest: layer.Digest,
Total: layer.Size,
Completed: completed,
})
if completed >= limit {
return
}
}
}()
headers := make(map[string]string)
headers["Content-Type"] = "application/octet-stream"
headers["Content-Length"] = strconv.Itoa(chunkSize)
headers["Content-Range"] = fmt.Sprintf("%d-%d", completed, limit-1)
resp, err := makeRequest(ctx, "PATCH", url, headers, r, regOpts)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusAccepted {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
}
url = resp.Header.Get("Location")
if completed >= layer.Size {
break
}
}
url = fmt.Sprintf("%s&digest=%s", url, layer.Digest)
headers := make(map[string]string)
headers["Content-Type"] = "application/octet-stream"
headers["Content-Length"] = "0"
// finish the upload
resp, err := makeRequest(ctx, "PUT", url, headers, nil, regOpts)
if err != nil {
log.Printf("couldn't finish upload: %v", err)
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
}
return nil
}
func makeRequestWithRetry(ctx context.Context, method, url string, headers map[string]string, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
var status string
for try := 0; try < MaxRetries; try++ {
resp, err := makeRequest(ctx, method, url, headers, body, regOpts)
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
if err != nil {
log.Printf("couldn't start upload: %v", err)
return nil, err
@@ -1257,10 +1171,8 @@ func makeRequestWithRetry(ctx context.Context, method, url string, headers map[s
status = resp.Status
switch resp.StatusCode {
case http.StatusAccepted, http.StatusCreated:
return resp, nil
case http.StatusUnauthorized:
switch {
case resp.StatusCode == http.StatusUnauthorized:
auth := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth)
token, err := getAuthToken(ctx, authRedir, regOpts)
@@ -1276,38 +1188,38 @@ func makeRequestWithRetry(ctx context.Context, method, url string, headers map[s
}
continue
default:
case resp.StatusCode >= http.StatusBadRequest:
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
default:
return resp, nil
}
}
return nil, fmt.Errorf("max retry exceeded: %v", status)
}
func makeRequest(ctx context.Context, method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
if !strings.HasPrefix(url, "http") {
if regOpts.Insecure {
url = "http://" + url
} else {
url = "https://" + url
}
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
if requestURL.Scheme != "http" && regOpts.Insecure {
requestURL.Scheme = "http"
}
req, err := http.NewRequestWithContext(ctx, method, url, body)
req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
if err != nil {
return nil, err
}
if headers != nil {
req.Header = headers
}
if regOpts.Token != "" {
req.Header.Set("Authorization", "Bearer "+regOpts.Token)
} else if regOpts.Username != "" && regOpts.Password != "" {
req.SetBasicAuth(regOpts.Username, regOpts.Password)
}
for k, v := range headers {
req.Header.Set(k, v)
}
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {

View File

@@ -1,7 +1,9 @@
package server
import (
"errors"
"fmt"
"net/url"
"os"
"path/filepath"
"runtime"
@@ -23,42 +25,46 @@ const (
DefaultProtocolScheme = "https"
)
var (
ErrInvalidImageFormat = errors.New("invalid image format")
ErrInvalidProtocol = errors.New("invalid protocol scheme")
ErrInsecureProtocol = errors.New("insecure protocol http")
)
func ParseModelPath(name string) ModelPath {
slashParts := strings.Split(name, "/")
var registry, namespace, repository, tag string
switch len(slashParts) {
case 3:
registry = slashParts[0]
namespace = slashParts[1]
repository = strings.Split(slashParts[2], ":")[0]
case 2:
registry = DefaultRegistry
namespace = slashParts[0]
repository = strings.Split(slashParts[1], ":")[0]
case 1:
registry = DefaultRegistry
namespace = DefaultNamespace
repository = strings.Split(slashParts[0], ":")[0]
default:
fmt.Println("Invalid image format.")
return ModelPath{}
}
colonParts := strings.Split(slashParts[len(slashParts)-1], ":")
if len(colonParts) == 2 {
tag = colonParts[1]
} else {
tag = DefaultTag
}
return ModelPath{
mp := ModelPath{
ProtocolScheme: DefaultProtocolScheme,
Registry: registry,
Namespace: namespace,
Repository: repository,
Tag: tag,
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "",
Tag: DefaultTag,
}
before, after, found := strings.Cut(name, "://")
if found {
mp.ProtocolScheme = before
name = after
}
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
mp.Registry = parts[0]
mp.Namespace = parts[1]
mp.Repository = parts[2]
case 2:
mp.Namespace = parts[0]
mp.Repository = parts[1]
case 1:
mp.Repository = parts[0]
}
if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
mp.Repository = repo
mp.Tag = tag
}
return mp
}
func (mp ModelPath) GetNamespaceRepository() string {
@@ -95,6 +101,13 @@ func (mp ModelPath) GetManifestPath(createDir bool) (string, error) {
return path, nil
}
func (mp ModelPath) BaseURL() *url.URL {
return &url.URL{
Scheme: mp.ProtocolScheme,
Host: mp.Registry,
}
}
func GetManifestPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {

88
server/modelpath_test.go Normal file
View File

@@ -0,0 +1,88 @@
package server
import "testing"
func TestParseModelPath(t *testing.T) {
tests := []struct {
name string
arg string
want ModelPath
}{
{
"full path https",
"https://example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"full path http",
"http://example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "http",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no protocol",
"example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no registry",
"ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no namespace",
"repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "repo",
Tag: "tag",
},
},
{
"no tag",
"repo",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "repo",
Tag: DefaultTag,
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := ParseModelPath(tc.arg)
if got != tc.want {
t.Errorf("got: %q want: %q", got, tc.want)
}
})
}
}

View File

@@ -25,6 +25,20 @@ import (
"github.com/jmorganca/ollama/vector"
)
var mode string = gin.DebugMode
func init() {
switch mode {
case gin.DebugMode:
case gin.ReleaseMode:
case gin.TestMode:
default:
mode = gin.DebugMode
}
gin.SetMode(mode)
}
var loaded struct {
mu sync.Mutex
@@ -357,8 +371,9 @@ func ListModelsHandler(c *gin.Context) {
return nil
}
tag := path[:slashIndex] + ":" + path[slashIndex+1:]
mp := ParseModelPath(tag)
manifest, err := GetManifest(mp)
manifest, digest, err := GetManifest(mp)
if err != nil {
log.Printf("skipping file: %s", fp)
return nil
@@ -366,6 +381,7 @@ func ListModelsHandler(c *gin.Context) {
model := api.ListResponseModel{
Name: mp.GetShortTagname(),
Size: manifest.GetTotalSize(),
Digest: digest,
ModifiedAt: fi.ModTime(),
}
models = append(models, model)

165
server/upload.go Normal file
View File

@@ -0,0 +1,165 @@
package server
import (
"context"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"strconv"
"github.com/jmorganca/ollama/api"
)
func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, error) {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
if layer.From != "" {
values := requestURL.Query()
values.Add("mount", layer.Digest)
values.Add("from", layer.From)
requestURL.RawQuery = values.Encode()
}
resp, err := makeRequestWithRetry(ctx, "POST", requestURL, nil, nil, regOpts)
if err != nil {
log.Printf("couldn't start upload: %v", err)
return nil, err
}
defer resp.Body.Close()
// Extract UUID location from header
location := resp.Header.Get("Location")
if location == "" {
return nil, fmt.Errorf("location header is missing in response")
}
return url.Parse(location)
}
func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
// TODO allow resumability
// TODO allow canceling uploads via DELETE
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
f, err := os.Open(fp)
if err != nil {
return err
}
defer f.Close()
// 95MB chunk size
chunkSize := 95 * 1024 * 1024
for offset := int64(0); offset < int64(layer.Size); {
chunk := int64(layer.Size) - offset
if chunk > int64(chunkSize) {
chunk = int64(chunkSize)
}
sectionReader := io.NewSectionReader(f, int64(offset), chunk)
for try := 0; try < MaxRetries; try++ {
r, w := io.Pipe()
defer r.Close()
go func() {
defer w.Close()
for chunked := int64(0); chunked < chunk; {
n, err := io.CopyN(w, sectionReader, 1024*1024)
if err != nil && !errors.Is(err, io.EOF) {
fn(api.ProgressResponse{
Status: fmt.Sprintf("error reading chunk: %v", err),
Digest: layer.Digest,
Total: layer.Size,
Completed: int(offset),
})
return
}
chunked += n
fn(api.ProgressResponse{
Status: fmt.Sprintf("uploading %s", layer.Digest),
Digest: layer.Digest,
Total: layer.Size,
Completed: int(offset) + int(chunked),
})
}
}()
headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", strconv.Itoa(int(chunk)))
headers.Set("Content-Range", fmt.Sprintf("%d-%d", offset, offset+sectionReader.Size()-1))
resp, err := makeRequest(ctx, "PATCH", requestURL, headers, r, regOpts)
if err != nil && !errors.Is(err, io.EOF) {
fn(api.ProgressResponse{
Status: fmt.Sprintf("error uploading chunk: %v", err),
Digest: layer.Digest,
Total: layer.Size,
Completed: int(offset),
})
return err
}
defer resp.Body.Close()
switch {
case resp.StatusCode == http.StatusUnauthorized:
auth := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth)
token, err := getAuthToken(ctx, authRedir, regOpts)
if err != nil {
return err
}
regOpts.Token = token
if _, err := sectionReader.Seek(0, io.SeekStart); err != nil {
return err
}
continue
case resp.StatusCode >= http.StatusBadRequest:
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
}
offset += sectionReader.Size()
requestURL, err = url.Parse(resp.Header.Get("Location"))
if err != nil {
return err
}
break
}
}
values := requestURL.Query()
values.Add("digest", layer.Digest)
requestURL.RawQuery = values.Encode()
headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", "0")
// finish the upload
resp, err := makeRequest(ctx, "PUT", requestURL, headers, nil, regOpts)
if err != nil {
log.Printf("couldn't finish upload: %v", err)
return err
}
defer resp.Body.Close()
if resp.StatusCode >= http.StatusBadRequest {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
}
return nil
}

3
version/version.go Normal file
View File

@@ -0,0 +1,3 @@
package version
var Version string = "0.0.0"