Compare commits
42 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
f4432e1dba | ||
![]() |
982c535428 | ||
![]() |
7df342a6ea | ||
![]() |
8bbff2df98 | ||
![]() |
16b06699fd | ||
![]() |
246dc65417 | ||
![]() |
865fceb73c | ||
![]() |
72266c7684 | ||
![]() |
d3b838ce60 | ||
![]() |
e639a12fa1 | ||
![]() |
e82fcf30c6 | ||
![]() |
495e8b0a6a | ||
![]() |
59734ca24d | ||
![]() |
22ab7f5f88 | ||
![]() |
b25dd1795d | ||
![]() |
304f2b6c96 | ||
![]() |
2ecc3a33c3 | ||
![]() |
ee6e1df118 | ||
![]() |
177b69a211 | ||
![]() |
dad63f0821 | ||
![]() |
041f9ad1a1 | ||
![]() |
7a378f8b66 | ||
![]() |
de0bdd7f29 | ||
![]() |
b1cececb8e | ||
![]() |
e0d39fa3bf | ||
![]() |
968ced2e71 | ||
![]() |
32d1a00017 | ||
![]() |
04e2128273 | ||
![]() |
2cc634689b | ||
![]() |
8f827641b0 | ||
![]() |
95187d7e1e | ||
![]() |
9ec7e37534 | ||
![]() |
2c7f956b38 | ||
![]() |
a9f6c56652 | ||
![]() |
0a892419ad | ||
![]() |
e3054fc74e | ||
![]() |
23c2485044 | ||
![]() |
386c66f285 | ||
![]() |
3b49315f97 | ||
![]() |
5ca05c2e88 | ||
![]() |
7eda70f23b | ||
![]() |
3d79b414d3 |
@@ -4,4 +4,5 @@ llama/build
|
||||
.vscode
|
||||
ollama
|
||||
app
|
||||
web
|
||||
web
|
||||
.env
|
||||
|
@@ -29,9 +29,9 @@ ollama run llama2
|
||||
|
||||
## Model library
|
||||
|
||||
Ollama supports a list of open-source models available on [ollama.ai/library](https://ollama.ai/library "ollama model library")
|
||||
Ollama supports a list of open-source models available on [ollama.ai/library](https://ollama.ai/library 'ollama model library')
|
||||
|
||||
Here are some example open-source models that can be downloaded:
|
||||
Here are some example open-source models that can be downloaded:
|
||||
|
||||
| Model | Parameters | Size | Download |
|
||||
| ------------------------ | ---------- | ----- | ------------------------------- |
|
||||
@@ -39,6 +39,7 @@ Here are some example open-source models that can be downloaded:
|
||||
| Llama2 13B | 13B | 7.3GB | `ollama pull llama2:13b` |
|
||||
| Llama2 70B | 70B | 39GB | `ollama pull llama2:70b` |
|
||||
| Llama2 Uncensored | 7B | 3.8GB | `ollama pull llama2-uncensored` |
|
||||
| Code Llama | 7B | 3.8GB | `ollama pull codellama` |
|
||||
| Orca Mini | 3B | 1.9GB | `ollama pull orca-mini` |
|
||||
| Vicuna | 7B | 3.8GB | `ollama pull vicuna` |
|
||||
| Nous-Hermes | 7B | 3.8GB | `ollama pull nous-hermes` |
|
||||
@@ -104,7 +105,7 @@ For more examples, see the [examples](./examples) directory. For more informatio
|
||||
### Pull a model from the registry
|
||||
|
||||
```
|
||||
ollama pull orca
|
||||
ollama pull orca-mini
|
||||
```
|
||||
|
||||
### Listing local models
|
||||
@@ -126,6 +127,8 @@ Ollama bundles model weights, configuration, and data into a single package, def
|
||||
|
||||
## Building
|
||||
|
||||
You will also need a C/C++ compiler such as GCC for MacOS and Linux or Mingw-w64 GCC for Windows.
|
||||
|
||||
```
|
||||
go build .
|
||||
```
|
||||
|
@@ -10,10 +10,13 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/jmorganca/ollama/version"
|
||||
)
|
||||
|
||||
const DefaultHost = "localhost:11434"
|
||||
const DefaultHost = "127.0.0.1:11434"
|
||||
|
||||
var (
|
||||
envHost = os.Getenv("OLLAMA_HOST")
|
||||
@@ -26,7 +29,7 @@ type Client struct {
|
||||
}
|
||||
|
||||
func checkError(resp *http.Response, body []byte) error {
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 400 {
|
||||
if resp.StatusCode < http.StatusBadRequest {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -83,21 +86,21 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
|
||||
reqBody = bytes.NewReader(data)
|
||||
}
|
||||
|
||||
url := c.Base.JoinPath(path).String()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
|
||||
requestURL := c.Base.JoinPath(path)
|
||||
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("Accept", "application/json")
|
||||
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||
|
||||
for k, v := range c.Headers {
|
||||
req.Header[k] = v
|
||||
request.Header[k] = v
|
||||
}
|
||||
|
||||
respObj, err := c.HTTP.Do(req)
|
||||
respObj, err := c.HTTP.Do(request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -131,13 +134,15 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
buf = bytes.NewBuffer(bts)
|
||||
}
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, method, c.Base.JoinPath(path).String(), buf)
|
||||
requestURL := c.Base.JoinPath(path)
|
||||
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("Accept", "application/json")
|
||||
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||
|
||||
response, err := http.DefaultClient.Do(request)
|
||||
if err != nil {
|
||||
@@ -160,7 +165,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
return fmt.Errorf(errorResponse.Error)
|
||||
}
|
||||
|
||||
if response.StatusCode >= 400 {
|
||||
if response.StatusCode >= http.StatusBadRequest {
|
||||
return StatusError{
|
||||
StatusCode: response.StatusCode,
|
||||
Status: response.Status,
|
||||
|
@@ -96,6 +96,7 @@ type ListResponseModel struct {
|
||||
Name string `json:"name"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
Size int `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
}
|
||||
|
||||
type TokenResponse struct {
|
||||
|
@@ -27,7 +27,7 @@ const config: ForgeConfig = {
|
||||
path.join(__dirname, './assets/iconDarkTemplate@2x.png'),
|
||||
path.join(__dirname, './assets/iconDarkUpdateTemplate.png'),
|
||||
path.join(__dirname, './assets/iconDarkUpdateTemplate@2x.png'),
|
||||
...(process.platform === 'darwin' ? ['../llama/ggml-metal.metal'] : []),
|
||||
...(process.platform === 'darwin' ? ['../llm/ggml-metal.metal'] : []),
|
||||
],
|
||||
...(process.env.SIGN
|
||||
? {
|
||||
|
44
cmd/cmd.go
44
cmd/cmd.go
@@ -30,6 +30,7 @@ import (
|
||||
"github.com/jmorganca/ollama/format"
|
||||
"github.com/jmorganca/ollama/progressbar"
|
||||
"github.com/jmorganca/ollama/server"
|
||||
"github.com/jmorganca/ollama/version"
|
||||
)
|
||||
|
||||
func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
@@ -97,7 +98,20 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
insecure, err := cmd.Flags().GetBool("insecure")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mp := server.ParseModelPath(args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if mp.ProtocolScheme == "http" && !insecure {
|
||||
return fmt.Errorf("insecure protocol http")
|
||||
}
|
||||
|
||||
fp, err := mp.GetManifestPath(false)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -106,7 +120,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
_, err = os.Stat(fp)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
if err := pull(args[0], false); err != nil {
|
||||
if err := pull(args[0], insecure); err != nil {
|
||||
var apiStatusError api.StatusError
|
||||
if !errors.As(err, &apiStatusError) {
|
||||
return err
|
||||
@@ -182,12 +196,12 @@ func ListHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
for _, m := range models.Models {
|
||||
if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) {
|
||||
data = append(data, []string{m.Name, humanize.Bytes(uint64(m.Size)), format.HumanTime(m.ModifiedAt, "Never")})
|
||||
data = append(data, []string{m.Name, m.Digest[:12], humanize.Bytes(uint64(m.Size)), format.HumanTime(m.ModifiedAt, "Never")})
|
||||
}
|
||||
}
|
||||
|
||||
table := tablewriter.NewWriter(os.Stdout)
|
||||
table.SetHeader([]string{"NAME", "SIZE", "MODIFIED"})
|
||||
table.SetHeader([]string{"NAME", "ID", "SIZE", "MODIFIED"})
|
||||
table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
|
||||
table.SetAlignment(tablewriter.ALIGN_LEFT)
|
||||
table.SetHeaderLine(false)
|
||||
@@ -206,11 +220,13 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
req := api.DeleteRequest{Name: args[0]}
|
||||
if err := client.Delete(context.Background(), &req); err != nil {
|
||||
return err
|
||||
for _, name := range args {
|
||||
req := api.DeleteRequest{Name: name}
|
||||
if err := client.Delete(context.Background(), &req); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("deleted '%s'\n", name)
|
||||
}
|
||||
fmt.Printf("deleted '%s'\n", args[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -507,7 +523,11 @@ func generateInteractive(cmd *cobra.Command, model string) error {
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
mp := server.ParseModelPath(model)
|
||||
manifest, err := server.GetManifest(mp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manifest, _, err := server.GetManifest(mp)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't get a manifest for this model")
|
||||
continue
|
||||
@@ -569,7 +589,7 @@ func generateBatch(cmd *cobra.Command, model string) error {
|
||||
}
|
||||
|
||||
func RunServer(cmd *cobra.Command, _ []string) error {
|
||||
var host, port = "127.0.0.1", "11434"
|
||||
host, port := "127.0.0.1", "11434"
|
||||
|
||||
parts := strings.Split(os.Getenv("OLLAMA_HOST"), ":")
|
||||
if ip := net.ParseIP(parts[0]); ip != nil {
|
||||
@@ -630,7 +650,7 @@ func initializeKeypair() error {
|
||||
return fmt.Errorf("could not create directory %w", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(privKeyPath, pem.EncodeToMemory(privKeyBytes), 0600)
|
||||
err = os.WriteFile(privKeyPath, pem.EncodeToMemory(privKeyBytes), 0o600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -642,7 +662,7 @@ func initializeKeypair() error {
|
||||
|
||||
pubKeyData := ssh.MarshalAuthorizedKey(sshPrivateKey.PublicKey())
|
||||
|
||||
err = os.WriteFile(pubKeyPath, pubKeyData, 0644)
|
||||
err = os.WriteFile(pubKeyPath, pubKeyData, 0o644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -714,6 +734,7 @@ func NewCLI() *cobra.Command {
|
||||
CompletionOptions: cobra.CompletionOptions{
|
||||
DisableDefaultCmd: true,
|
||||
},
|
||||
Version: version.Version,
|
||||
}
|
||||
|
||||
cobra.EnableCommandSorting = false
|
||||
@@ -737,6 +758,7 @@ func NewCLI() *cobra.Command {
|
||||
}
|
||||
|
||||
runCmd.Flags().Bool("verbose", false, "Show timings for response")
|
||||
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
|
||||
serveCmd := &cobra.Command{
|
||||
Use: "serve",
|
||||
|
@@ -14,7 +14,7 @@
|
||||
|
||||
### Model names
|
||||
|
||||
Model names follow a `model:tag` format. Some examples are `orca:3b-q4_1` and `llama2:70b`. The tag is optional and if not provided will default to `latest`. The tag is used to identify a specific version.
|
||||
Model names follow a `model:tag` format. Some examples are `orca-mini:3b-q4_1` and `llama2:70b`. The tag is optional and if not provided will default to `latest`. The tag is used to identify a specific version.
|
||||
|
||||
### Durations
|
||||
|
||||
|
@@ -25,20 +25,3 @@ Now you can run `ollama`:
|
||||
```
|
||||
./ollama
|
||||
```
|
||||
|
||||
## Releasing
|
||||
|
||||
To release a new version of Ollama you'll need to set some environment variables:
|
||||
|
||||
- `GITHUB_TOKEN`: your GitHub token
|
||||
- `APPLE_IDENTITY`: the Apple signing identity (macOS only)
|
||||
- `APPLE_ID`: your Apple ID
|
||||
- `APPLE_PASSWORD`: your Apple ID app-specific password
|
||||
- `APPLE_TEAM_ID`: the Apple team ID for the signing identity
|
||||
- `TELEMETRY_WRITE_KEY`: segment write key for telemetry
|
||||
|
||||
Then run the publish script with the target version:
|
||||
|
||||
```
|
||||
VERSION=0.0.2 ./scripts/publish.sh
|
||||
```
|
||||
|
@@ -123,7 +123,7 @@ PARAMETER <parameter> <parametervalue>
|
||||
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
|
||||
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
|
||||
| stop | Sets the stop tokens to use. | string | stop "AI assistant:" |
|
||||
| stop | Sets the stop sequences to use. | string | stop "AI assistant:" |
|
||||
| tfs_z | Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1) | float | tfs_z 1 |
|
||||
| top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |
|
||||
| top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 |
|
||||
|
@@ -15,6 +15,7 @@ const (
|
||||
ModelType3B ModelType = 26
|
||||
ModelType7B ModelType = 32
|
||||
ModelType13B ModelType = 40
|
||||
ModelType34B ModelType = 48
|
||||
ModelType30B ModelType = 60
|
||||
ModelType65B ModelType = 80
|
||||
)
|
||||
@@ -27,6 +28,8 @@ func (mt ModelType) String() string {
|
||||
return "7B"
|
||||
case ModelType13B:
|
||||
return "13B"
|
||||
case ModelType34B:
|
||||
return "34B"
|
||||
case ModelType30B:
|
||||
return "30B"
|
||||
case ModelType65B:
|
||||
|
@@ -105,6 +105,7 @@ enum e_model {
|
||||
MODEL_7B,
|
||||
MODEL_13B,
|
||||
MODEL_30B,
|
||||
MODEL_34B,
|
||||
MODEL_65B,
|
||||
MODEL_70B,
|
||||
};
|
||||
@@ -148,6 +149,7 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0(int n_ctx)
|
||||
{ MODEL_7B, ((size_t) n_ctx / 16ull + 100ull) * MB },
|
||||
{ MODEL_13B, ((size_t) n_ctx / 12ull + 120ull) * MB },
|
||||
{ MODEL_30B, ((size_t) n_ctx / 9ull + 160ull) * MB },
|
||||
{ MODEL_34B, ((size_t) n_ctx / 9ull + 160ull) * MB },
|
||||
{ MODEL_65B, ((size_t) n_ctx / 6ull + 256ull) * MB }, // guess
|
||||
{ MODEL_70B, ((size_t) n_ctx / 7ull + 164ull) * MB },
|
||||
};
|
||||
@@ -161,6 +163,7 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1()
|
||||
{ MODEL_7B, 160ull * MB },
|
||||
{ MODEL_13B, 192ull * MB },
|
||||
{ MODEL_30B, 256ull * MB },
|
||||
{ MODEL_34B, 256ull * MB },
|
||||
{ MODEL_65B, 384ull * MB }, // guess
|
||||
{ MODEL_70B, 304ull * MB },
|
||||
};
|
||||
@@ -175,6 +178,7 @@ static const std::map<e_model, size_t> & MEM_REQ_EVAL()
|
||||
{ MODEL_7B, 10ull * MB },
|
||||
{ MODEL_13B, 12ull * MB },
|
||||
{ MODEL_30B, 16ull * MB },
|
||||
{ MODEL_34B, 16ull * MB },
|
||||
{ MODEL_65B, 24ull * MB }, // guess
|
||||
{ MODEL_70B, 24ull * MB },
|
||||
};
|
||||
@@ -190,6 +194,7 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_BASE()
|
||||
{ MODEL_7B, 512ull * kB },
|
||||
{ MODEL_13B, 640ull * kB },
|
||||
{ MODEL_30B, 768ull * kB },
|
||||
{ MODEL_34B, 768ull * kB },
|
||||
{ MODEL_65B, 1280ull * kB },
|
||||
{ MODEL_70B, 1280ull * kB },
|
||||
};
|
||||
@@ -205,6 +210,7 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
|
||||
{ MODEL_7B, 128ull },
|
||||
{ MODEL_13B, 160ull },
|
||||
{ MODEL_30B, 208ull },
|
||||
{ MODEL_34B, 208ull },
|
||||
{ MODEL_65B, 256ull },
|
||||
{ MODEL_70B, 256ull },
|
||||
};
|
||||
@@ -1053,6 +1059,7 @@ static const char *llama_model_type_name(e_model type) {
|
||||
case MODEL_7B: return "7B";
|
||||
case MODEL_13B: return "13B";
|
||||
case MODEL_30B: return "30B";
|
||||
case MODEL_34B: return "34B";
|
||||
case MODEL_65B: return "65B";
|
||||
case MODEL_70B: return "70B";
|
||||
default: LLAMA_ASSERT(false);
|
||||
@@ -1100,6 +1107,7 @@ static void llama_model_load_internal(
|
||||
case 26: model.type = e_model::MODEL_3B; break;
|
||||
case 32: model.type = e_model::MODEL_7B; break;
|
||||
case 40: model.type = e_model::MODEL_13B; break;
|
||||
case 48: model.type = e_model::MODEL_34B; break;
|
||||
case 60: model.type = e_model::MODEL_30B; break;
|
||||
case 80: model.type = e_model::MODEL_65B; break;
|
||||
default:
|
||||
@@ -1120,6 +1128,8 @@ static void llama_model_load_internal(
|
||||
LLAMA_LOG_WARN("%s: warning: assuming 70B model based on GQA == %d\n", __func__, n_gqa);
|
||||
model.type = e_model::MODEL_70B;
|
||||
hparams.f_ffn_mult = 1.3f; // from the params.json of the 70B model
|
||||
} else if (model.type == e_model::MODEL_34B && n_gqa == 8) {
|
||||
hparams.f_ffn_mult = 1.0f; // from the params.json of the 34B model
|
||||
}
|
||||
|
||||
hparams.rope_freq_base = rope_freq_base;
|
||||
|
59
llm/llama.go
59
llm/llama.go
@@ -117,7 +117,21 @@ func (llm *llamaModel) ModelFamily() ModelFamily {
|
||||
}
|
||||
|
||||
func (llm *llamaModel) ModelType() ModelType {
|
||||
return ModelType30B
|
||||
switch llm.hyperparameters.NumLayer {
|
||||
case 26:
|
||||
return ModelType3B
|
||||
case 32:
|
||||
return ModelType7B
|
||||
case 40:
|
||||
return ModelType13B
|
||||
case 60:
|
||||
return ModelType30B
|
||||
case 80:
|
||||
return ModelType65B
|
||||
}
|
||||
|
||||
// TODO: find a better default
|
||||
return ModelType7B
|
||||
}
|
||||
|
||||
func (llm *llamaModel) FileType() FileType {
|
||||
@@ -320,20 +334,18 @@ func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse
|
||||
|
||||
b.WriteString(llm.Decode(int(token)))
|
||||
|
||||
if err := llm.checkStopConditions(b); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
} else if errors.Is(err, errNeedMoreData) {
|
||||
continue
|
||||
}
|
||||
|
||||
return err
|
||||
stop, endsWithStopPrefix := handleStopSequences(&b, llm.Stop)
|
||||
if endsWithStopPrefix {
|
||||
continue
|
||||
}
|
||||
|
||||
if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
|
||||
fn(api.GenerateResponse{Response: b.String()})
|
||||
b.Reset()
|
||||
}
|
||||
if stop {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
embd := make([]int, len(llm.embd))
|
||||
@@ -356,16 +368,31 @@ func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse
|
||||
return nil
|
||||
}
|
||||
|
||||
func (llm *llama) checkStopConditions(b bytes.Buffer) error {
|
||||
for _, stopCondition := range llm.Stop {
|
||||
if stopCondition == strings.TrimSpace(b.String()) {
|
||||
return io.EOF
|
||||
} else if strings.HasPrefix(stopCondition, strings.TrimSpace(b.String())) {
|
||||
return errNeedMoreData
|
||||
// handleStopSequences checks whether b contains any of the stop sequences, or ends with a prefix of
|
||||
// any stop sequence (and therefore might contain data that should not ultimately be returned to the
|
||||
// client).
|
||||
//
|
||||
// If b contains a stop sequence, it modifies b to remove the stop sequence and all subsequent data.
|
||||
func handleStopSequences(b *bytes.Buffer, stopSequences []string) (stop bool, endsWithStopPrefix bool) {
|
||||
s := b.String()
|
||||
for _, seq := range stopSequences {
|
||||
// Check for an exact or substring match.
|
||||
if i := strings.Index(s, seq); i != -1 {
|
||||
b.Truncate(i)
|
||||
return true, false
|
||||
}
|
||||
|
||||
// Check if b ends with a prefix of the stop sequence.
|
||||
if len(seq) > 1 {
|
||||
for i := 1; i < len(seq); i++ {
|
||||
if strings.HasSuffix(s, seq[:i]) {
|
||||
return false, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return false, false
|
||||
}
|
||||
|
||||
func (llm *llama) marshalPrompt(ctx []int, prompt string) []C.llama_token {
|
||||
|
79
llm/llama_test.go
Normal file
79
llm/llama_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCheckStopConditions(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
b string
|
||||
stop []string
|
||||
wantB string
|
||||
wantStop bool
|
||||
wantEndsWithStopPrefix bool
|
||||
}{
|
||||
"not present": {
|
||||
b: "abc",
|
||||
stop: []string{"x"},
|
||||
wantStop: false,
|
||||
wantEndsWithStopPrefix: false,
|
||||
},
|
||||
"exact": {
|
||||
b: "abc",
|
||||
stop: []string{"abc"},
|
||||
wantStop: true,
|
||||
wantEndsWithStopPrefix: false,
|
||||
},
|
||||
"substring": {
|
||||
b: "abc",
|
||||
stop: []string{"b"},
|
||||
wantB: "a",
|
||||
wantStop: true,
|
||||
wantEndsWithStopPrefix: false,
|
||||
},
|
||||
"prefix 1": {
|
||||
b: "abc",
|
||||
stop: []string{"abcd"},
|
||||
wantStop: false,
|
||||
wantEndsWithStopPrefix: true,
|
||||
},
|
||||
"prefix 2": {
|
||||
b: "abc",
|
||||
stop: []string{"bcd"},
|
||||
wantStop: false,
|
||||
wantEndsWithStopPrefix: true,
|
||||
},
|
||||
"prefix 3": {
|
||||
b: "abc",
|
||||
stop: []string{"cd"},
|
||||
wantStop: false,
|
||||
wantEndsWithStopPrefix: true,
|
||||
},
|
||||
"no prefix": {
|
||||
b: "abc",
|
||||
stop: []string{"bx"},
|
||||
wantStop: false,
|
||||
wantEndsWithStopPrefix: false,
|
||||
},
|
||||
}
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
b.WriteString(test.b)
|
||||
stop, endsWithStopPrefix := handleStopSequences(&b, test.stop)
|
||||
if test.wantB != "" {
|
||||
gotB := b.String()
|
||||
if gotB != test.wantB {
|
||||
t.Errorf("got b %q, want %q", gotB, test.wantB)
|
||||
}
|
||||
}
|
||||
if stop != test.wantStop {
|
||||
t.Errorf("got stop %v, want %v", stop, test.wantStop)
|
||||
}
|
||||
if endsWithStopPrefix != test.wantEndsWithStopPrefix {
|
||||
t.Errorf("got endsWithStopPrefix %v, want %v", endsWithStopPrefix, test.wantEndsWithStopPrefix)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
22
llm/llm.go
22
llm/llm.go
@@ -36,11 +36,11 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
|
||||
}
|
||||
|
||||
switch ggml.FileType().String() {
|
||||
case "F32", "F16", "Q5_0", "Q5_1", "Q8_0":
|
||||
case "F32", "Q5_0", "Q5_1", "Q8_0":
|
||||
if opts.NumGPU != 0 {
|
||||
// F32, F16, Q5_0, Q5_1, and Q8_0 do not support Metal API and will
|
||||
// cause the runner to segmentation fault so disable GPU
|
||||
log.Printf("WARNING: GPU disabled for F32, F16, Q5_0, Q5_1, and Q8_0")
|
||||
log.Printf("WARNING: GPU disabled for F32, Q5_0, Q5_1, and Q8_0")
|
||||
opts.NumGPU = 0
|
||||
}
|
||||
}
|
||||
@@ -48,19 +48,27 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
|
||||
totalResidentMemory := memory.TotalMemory()
|
||||
switch ggml.ModelType() {
|
||||
case ModelType3B, ModelType7B:
|
||||
if totalResidentMemory < 8*1024*1024 {
|
||||
if ggml.FileType().String() == "F16" && totalResidentMemory < 16*1024*1024 {
|
||||
return nil, fmt.Errorf("F16 model requires at least 16GB of memory")
|
||||
} else if totalResidentMemory < 8*1024*1024 {
|
||||
return nil, fmt.Errorf("model requires at least 8GB of memory")
|
||||
}
|
||||
case ModelType13B:
|
||||
if totalResidentMemory < 16*1024*1024 {
|
||||
if ggml.FileType().String() == "F16" && totalResidentMemory < 32*1024*1024 {
|
||||
return nil, fmt.Errorf("F16 model requires at least 32GB of memory")
|
||||
} else if totalResidentMemory < 16*1024*1024 {
|
||||
return nil, fmt.Errorf("model requires at least 16GB of memory")
|
||||
}
|
||||
case ModelType30B:
|
||||
if totalResidentMemory < 32*1024*1024 {
|
||||
case ModelType30B, ModelType34B:
|
||||
if ggml.FileType().String() == "F16" && totalResidentMemory < 64*1024*1024 {
|
||||
return nil, fmt.Errorf("F16 model requires at least 64GB of memory")
|
||||
} else if totalResidentMemory < 32*1024*1024 {
|
||||
return nil, fmt.Errorf("model requires at least 32GB of memory")
|
||||
}
|
||||
case ModelType65B:
|
||||
if totalResidentMemory < 64*1024*1024 {
|
||||
if ggml.FileType().String() == "F16" && totalResidentMemory < 128*1024*1024 {
|
||||
return nil, fmt.Errorf("F16 model requires at least 128GB of memory")
|
||||
} else if totalResidentMemory < 64*1024*1024 {
|
||||
return nil, fmt.Errorf("model requires at least 64GB of memory")
|
||||
}
|
||||
}
|
||||
|
@@ -2,9 +2,12 @@
|
||||
|
||||
mkdir -p dist
|
||||
|
||||
GO_LDFLAGS="-X github.com/jmorganca/ollama/version.Version=$VERSION"
|
||||
GO_LDFLAGS="$GO_LDFLAGS -X github.com/jmorganca/ollama/server.mode=release"
|
||||
|
||||
# build universal binary
|
||||
CGO_ENABLED=1 GOARCH=arm64 go build -o dist/ollama-darwin-arm64
|
||||
CGO_ENABLED=1 GOARCH=amd64 go build -o dist/ollama-darwin-amd64
|
||||
CGO_ENABLED=1 GOARCH=arm64 go build -ldflags "$GO_LDFLAGS" -o dist/ollama-darwin-arm64
|
||||
CGO_ENABLED=1 GOARCH=amd64 go build -ldflags "$GO_LDFLAGS" -o dist/ollama-darwin-amd64
|
||||
lipo -create -output dist/ollama dist/ollama-darwin-arm64 dist/ollama-darwin-amd64
|
||||
rm dist/ollama-darwin-amd64 dist/ollama-darwin-arm64
|
||||
codesign --deep --force --options=runtime --sign "$APPLE_IDENTITY" --timestamp dist/ollama
|
||||
|
@@ -12,8 +12,10 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -43,21 +45,34 @@ func generateNonce(length int) (string, error) {
|
||||
return base64.RawURLEncoding.EncodeToString(nonce), nil
|
||||
}
|
||||
|
||||
func (r AuthRedirect) URL() (string, error) {
|
||||
func (r AuthRedirect) URL() (*url.URL, error) {
|
||||
redirectURL, err := url.Parse(r.Realm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
values := redirectURL.Query()
|
||||
|
||||
values.Add("service", r.Service)
|
||||
|
||||
for _, s := range strings.Split(r.Scope, " ") {
|
||||
values.Add("scope", s)
|
||||
}
|
||||
|
||||
values.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
||||
|
||||
nonce, err := generateNonce(16)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
scopes := []string{}
|
||||
for _, s := range strings.Split(r.Scope, " ") {
|
||||
scopes = append(scopes, fmt.Sprintf("scope=%s", s))
|
||||
}
|
||||
scopeStr := strings.Join(scopes, "&")
|
||||
return fmt.Sprintf("%s?service=%s&%s&ts=%d&nonce=%s", r.Realm, r.Service, scopeStr, time.Now().Unix(), nonce), nil
|
||||
values.Add("nonce", nonce)
|
||||
|
||||
redirectURL.RawQuery = values.Encode()
|
||||
return redirectURL, nil
|
||||
}
|
||||
|
||||
func getAuthToken(ctx context.Context, redirData AuthRedirect, regOpts *RegistryOptions) (string, error) {
|
||||
url, err := redirData.URL()
|
||||
redirectURL, err := redirData.URL()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -77,34 +92,24 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect, regOpts *Registry
|
||||
|
||||
s := SignatureData{
|
||||
Method: "GET",
|
||||
Path: url,
|
||||
Path: redirectURL.String(),
|
||||
Data: nil,
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(s.Path, "http") {
|
||||
if regOpts.Insecure {
|
||||
s.Path = "http://" + url
|
||||
} else {
|
||||
s.Path = "https://" + url
|
||||
}
|
||||
}
|
||||
|
||||
sig, err := s.Sign(rawKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
headers := map[string]string{
|
||||
"Authorization": sig,
|
||||
}
|
||||
|
||||
resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
|
||||
headers := make(http.Header)
|
||||
headers.Set("Authorization", sig)
|
||||
resp, err := makeRequest(ctx, "GET", redirectURL, headers, nil, regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't get token: %q", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
|
@@ -155,19 +155,20 @@ func doDownload(ctx context.Context, opts downloadOpts, f *FileDownload) error {
|
||||
}
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v2/%s/blobs/%s", opts.mp.Registry, opts.mp.GetNamespaceRepository(), f.Digest)
|
||||
headers := map[string]string{
|
||||
"Range": fmt.Sprintf("bytes=%d-", size),
|
||||
}
|
||||
requestURL := opts.mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", f.Digest)
|
||||
|
||||
resp, err := makeRequest(ctx, "GET", url, headers, nil, opts.regOpts)
|
||||
headers := make(http.Header)
|
||||
headers.Set("Range", fmt.Sprintf("bytes=%d-", size))
|
||||
|
||||
resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts.regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't download blob: %v", err)
|
||||
return fmt.Errorf("%w: %w", errDownload, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("%w: on download registry responded with code %d: %v", errDownload, resp.StatusCode, string(body))
|
||||
}
|
||||
|
276
server/images.go
276
server/images.go
@@ -5,6 +5,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -12,10 +13,12 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -23,6 +26,7 @@ import (
|
||||
"github.com/jmorganca/ollama/llm"
|
||||
"github.com/jmorganca/ollama/parser"
|
||||
"github.com/jmorganca/ollama/vector"
|
||||
"github.com/jmorganca/ollama/version"
|
||||
)
|
||||
|
||||
const MaxRetries = 3
|
||||
@@ -41,6 +45,7 @@ type Model struct {
|
||||
Template string
|
||||
System string
|
||||
Digest string
|
||||
ConfigDigest string
|
||||
Options map[string]interface{}
|
||||
Embeddings []vector.Embedding
|
||||
}
|
||||
@@ -128,41 +133,45 @@ func (m *ManifestV2) GetTotalSize() int {
|
||||
return total
|
||||
}
|
||||
|
||||
func GetManifest(mp ModelPath) (*ManifestV2, error) {
|
||||
func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
|
||||
fp, err := mp.GetManifestPath(false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
if _, err = os.Stat(fp); err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
var manifest *ManifestV2
|
||||
|
||||
bts, err := os.ReadFile(fp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("couldn't open file '%s'", fp)
|
||||
return nil, "", fmt.Errorf("couldn't open file '%s'", fp)
|
||||
}
|
||||
|
||||
shaSum := sha256.Sum256(bts)
|
||||
shaStr := hex.EncodeToString(shaSum[:])
|
||||
|
||||
if err := json.Unmarshal(bts, &manifest); err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return manifest, nil
|
||||
return manifest, shaStr, nil
|
||||
}
|
||||
|
||||
func GetModel(name string) (*Model, error) {
|
||||
mp := ParseModelPath(name)
|
||||
|
||||
manifest, err := GetManifest(mp)
|
||||
manifest, digest, err := GetManifest(mp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
model := &Model{
|
||||
Name: mp.GetFullTagname(),
|
||||
Digest: manifest.Config.Digest,
|
||||
Name: mp.GetFullTagname(),
|
||||
Digest: digest,
|
||||
ConfigDigest: manifest.Config.Digest,
|
||||
Template: "{{ .Prompt }}",
|
||||
}
|
||||
|
||||
for _, layer := range manifest.Layers {
|
||||
@@ -272,8 +281,9 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
||||
case "model":
|
||||
fn(api.ProgressResponse{Status: "looking for model"})
|
||||
embed.model = c.Args
|
||||
|
||||
mp := ParseModelPath(c.Args)
|
||||
mf, err := GetManifest(mp)
|
||||
mf, _, err := GetManifest(mp)
|
||||
if err != nil {
|
||||
modelFile, err := filenameWithPath(path, c.Args)
|
||||
if err != nil {
|
||||
@@ -286,7 +296,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
||||
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
mf, err = GetManifest(ParseModelPath(c.Args))
|
||||
mf, _, err = GetManifest(mp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file after pull: %v", err)
|
||||
}
|
||||
@@ -325,7 +335,27 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
||||
}
|
||||
|
||||
if mf != nil {
|
||||
log.Printf("manifest = %#v", mf)
|
||||
sourceBlobPath, err := GetBlobsPath(mf.Config.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sourceBlob, err := os.Open(sourceBlobPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sourceBlob.Close()
|
||||
|
||||
var source ConfigV2
|
||||
if err := json.NewDecoder(sourceBlob).Decode(&source); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// copie the model metadata
|
||||
config.ModelFamily = source.ModelFamily
|
||||
config.ModelType = source.ModelType
|
||||
config.FileType = source.FileType
|
||||
|
||||
for _, l := range mf.Layers {
|
||||
newLayer, err := GetLayerWithBufferFromLayer(l)
|
||||
if err != nil {
|
||||
@@ -400,7 +430,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
||||
layer.MediaType = mediaType
|
||||
layers = append(layers, layer)
|
||||
default:
|
||||
// runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop tokens)
|
||||
// runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop sequences)
|
||||
params[c.Name] = append(params[c.Name], c.Args)
|
||||
}
|
||||
}
|
||||
@@ -655,7 +685,6 @@ func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force
|
||||
|
||||
func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
|
||||
mp := ParseModelPath(name)
|
||||
|
||||
manifest := ManifestV2{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||
@@ -786,11 +815,14 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
|
||||
}
|
||||
|
||||
func CopyModel(src, dest string) error {
|
||||
srcPath, err := ParseModelPath(src).GetManifestPath(false)
|
||||
srcModelPath := ParseModelPath(src)
|
||||
srcPath, err := srcModelPath.GetManifestPath(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
destPath, err := ParseModelPath(dest).GetManifestPath(true)
|
||||
|
||||
destModelPath := ParseModelPath(dest)
|
||||
destPath, err := destModelPath.GetManifestPath(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -813,8 +845,7 @@ func CopyModel(src, dest string) error {
|
||||
|
||||
func DeleteModel(name string) error {
|
||||
mp := ParseModelPath(name)
|
||||
|
||||
manifest, err := GetManifest(mp)
|
||||
manifest, _, err := GetManifest(mp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -847,7 +878,7 @@ func DeleteModel(name string) error {
|
||||
}
|
||||
|
||||
// save (i.e. delete from the deleteMap) any files used in other manifests
|
||||
manifest, err := GetManifest(fmp)
|
||||
manifest, _, err := GetManifest(fmp)
|
||||
if err != nil {
|
||||
log.Printf("skipping file: %s", fp)
|
||||
return nil
|
||||
@@ -893,10 +924,13 @@ func DeleteModel(name string) error {
|
||||
|
||||
func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
mp := ParseModelPath(name)
|
||||
|
||||
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
||||
|
||||
manifest, err := GetManifest(mp)
|
||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
return fmt.Errorf("insecure protocol http")
|
||||
}
|
||||
|
||||
manifest, _, err := GetManifest(mp)
|
||||
if err != nil {
|
||||
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
|
||||
return err
|
||||
@@ -935,8 +969,8 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
return err
|
||||
}
|
||||
|
||||
if strings.HasPrefix(path.Base(location), "sha256:") {
|
||||
layer.Digest = path.Base(location)
|
||||
if strings.HasPrefix(path.Base(location.Path), "sha256:") {
|
||||
layer.Digest = path.Base(location.Path)
|
||||
fn(api.ProgressResponse{
|
||||
Status: "using existing layer",
|
||||
Digest: layer.Digest,
|
||||
@@ -946,24 +980,24 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
continue
|
||||
}
|
||||
|
||||
if err := uploadBlobChunked(ctx, mp, location, layer, regOpts, fn); err != nil {
|
||||
if err := uploadBlobChunked(ctx, location, layer, regOpts, fn); err != nil {
|
||||
log.Printf("error uploading blob: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "pushing manifest"})
|
||||
url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
|
||||
headers := map[string]string{
|
||||
"Content-Type": "application/vnd.docker.distribution.manifest.v2+json",
|
||||
}
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
||||
|
||||
manifestJSON, err := json.Marshal(manifest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := makeRequestWithRetry(ctx, "PUT", url, headers, bytes.NewReader(manifestJSON), regOpts)
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
|
||||
resp, err := makeRequestWithRetry(ctx, "PUT", requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -977,6 +1011,10 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
mp := ParseModelPath(name)
|
||||
|
||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
return fmt.Errorf("insecure protocol http")
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "pulling manifest"})
|
||||
|
||||
manifest, err := pullModelManifest(ctx, mp, regOpts)
|
||||
@@ -1043,23 +1081,22 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
}
|
||||
|
||||
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
|
||||
url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
|
||||
headers := map[string]string{
|
||||
"Accept": "application/vnd.docker.distribution.manifest.v2+json",
|
||||
}
|
||||
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
||||
|
||||
resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
|
||||
headers := make(http.Header)
|
||||
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
|
||||
resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't get manifest: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check for success: For a successful upload, the Docker registry will respond with a 201 Created
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, fmt.Errorf("model not found")
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
@@ -1107,35 +1144,12 @@ func GetSHA256Digest(r io.Reader) (string, int) {
|
||||
return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n)
|
||||
}
|
||||
|
||||
type requestContextKey string
|
||||
|
||||
func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (string, error) {
|
||||
url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository())
|
||||
if layer.From != "" {
|
||||
url = fmt.Sprintf("%s/v2/%s/blobs/uploads/?mount=%s&from=%s", mp.Registry, mp.GetNamespaceRepository(), layer.Digest, layer.From)
|
||||
}
|
||||
|
||||
resp, err := makeRequestWithRetry(ctx, "POST", url, nil, nil, regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't start upload: %v", err)
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Extract UUID location from header
|
||||
location := resp.Header.Get("Location")
|
||||
if location == "" {
|
||||
return "", fmt.Errorf("location header is missing in response")
|
||||
}
|
||||
|
||||
return location, nil
|
||||
}
|
||||
|
||||
// Function to check if a blob already exists in the Docker registry
|
||||
func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
|
||||
url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest)
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", digest)
|
||||
|
||||
resp, err := makeRequest(ctx, "HEAD", url, nil, nil, regOpts)
|
||||
resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't check for blob: %v", err)
|
||||
return false, err
|
||||
@@ -1143,113 +1157,13 @@ func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpt
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check for success: If the blob exists, the Docker registry will respond with a 200 OK
|
||||
return resp.StatusCode == http.StatusOK, nil
|
||||
return resp.StatusCode < http.StatusBadRequest, nil
|
||||
}
|
||||
|
||||
func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
// TODO allow resumability
|
||||
// TODO allow canceling uploads via DELETE
|
||||
|
||||
fp, err := GetBlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err := os.Open(fp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
completed := 0
|
||||
chunkSize := 10 * 1024 * 1024
|
||||
|
||||
for {
|
||||
r, w := io.Pipe()
|
||||
defer r.Close()
|
||||
|
||||
limit := completed + chunkSize
|
||||
if chunkSize >= layer.Size-completed {
|
||||
limit = layer.Size
|
||||
chunkSize = layer.Size - completed
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer w.Close()
|
||||
for {
|
||||
n, err := io.CopyN(w, f, 1024*1024)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("error copying pipe: %v", err),
|
||||
Digest: layer.Digest,
|
||||
Total: layer.Size,
|
||||
Completed: completed,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
completed += int(n)
|
||||
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("uploading %s", layer.Digest),
|
||||
Digest: layer.Digest,
|
||||
Total: layer.Size,
|
||||
Completed: completed,
|
||||
})
|
||||
|
||||
if completed >= limit {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
headers := make(map[string]string)
|
||||
headers["Content-Type"] = "application/octet-stream"
|
||||
headers["Content-Length"] = strconv.Itoa(chunkSize)
|
||||
headers["Content-Range"] = fmt.Sprintf("%d-%d", completed, limit-1)
|
||||
|
||||
resp, err := makeRequest(ctx, "PATCH", url, headers, r, regOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusAccepted {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
url = resp.Header.Get("Location")
|
||||
if completed >= layer.Size {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
url = fmt.Sprintf("%s&digest=%s", url, layer.Digest)
|
||||
|
||||
headers := make(map[string]string)
|
||||
headers["Content-Type"] = "application/octet-stream"
|
||||
headers["Content-Length"] = "0"
|
||||
|
||||
// finish the upload
|
||||
resp, err := makeRequest(ctx, "PUT", url, headers, nil, regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't finish upload: %v", err)
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeRequestWithRetry(ctx context.Context, method, url string, headers map[string]string, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
|
||||
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
|
||||
var status string
|
||||
for try := 0; try < MaxRetries; try++ {
|
||||
resp, err := makeRequest(ctx, method, url, headers, body, regOpts)
|
||||
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't start upload: %v", err)
|
||||
return nil, err
|
||||
@@ -1257,10 +1171,8 @@ func makeRequestWithRetry(ctx context.Context, method, url string, headers map[s
|
||||
|
||||
status = resp.Status
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusAccepted, http.StatusCreated:
|
||||
return resp, nil
|
||||
case http.StatusUnauthorized:
|
||||
switch {
|
||||
case resp.StatusCode == http.StatusUnauthorized:
|
||||
auth := resp.Header.Get("www-authenticate")
|
||||
authRedir := ParseAuthRedirectString(auth)
|
||||
token, err := getAuthToken(ctx, authRedir, regOpts)
|
||||
@@ -1276,38 +1188,38 @@ func makeRequestWithRetry(ctx context.Context, method, url string, headers map[s
|
||||
}
|
||||
|
||||
continue
|
||||
default:
|
||||
case resp.StatusCode >= http.StatusBadRequest:
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
|
||||
default:
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("max retry exceeded: %v", status)
|
||||
}
|
||||
|
||||
func makeRequest(ctx context.Context, method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
|
||||
if !strings.HasPrefix(url, "http") {
|
||||
if regOpts.Insecure {
|
||||
url = "http://" + url
|
||||
} else {
|
||||
url = "https://" + url
|
||||
}
|
||||
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
|
||||
if requestURL.Scheme != "http" && regOpts.Insecure {
|
||||
requestURL.Scheme = "http"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||
req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if headers != nil {
|
||||
req.Header = headers
|
||||
}
|
||||
|
||||
if regOpts.Token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+regOpts.Token)
|
||||
} else if regOpts.Username != "" && regOpts.Password != "" {
|
||||
req.SetBasicAuth(regOpts.Username, regOpts.Password)
|
||||
}
|
||||
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||
|
||||
client := &http.Client{
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
|
@@ -1,7 +1,9 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -23,42 +25,46 @@ const (
|
||||
DefaultProtocolScheme = "https"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidImageFormat = errors.New("invalid image format")
|
||||
ErrInvalidProtocol = errors.New("invalid protocol scheme")
|
||||
ErrInsecureProtocol = errors.New("insecure protocol http")
|
||||
)
|
||||
|
||||
func ParseModelPath(name string) ModelPath {
|
||||
slashParts := strings.Split(name, "/")
|
||||
var registry, namespace, repository, tag string
|
||||
|
||||
switch len(slashParts) {
|
||||
case 3:
|
||||
registry = slashParts[0]
|
||||
namespace = slashParts[1]
|
||||
repository = strings.Split(slashParts[2], ":")[0]
|
||||
case 2:
|
||||
registry = DefaultRegistry
|
||||
namespace = slashParts[0]
|
||||
repository = strings.Split(slashParts[1], ":")[0]
|
||||
case 1:
|
||||
registry = DefaultRegistry
|
||||
namespace = DefaultNamespace
|
||||
repository = strings.Split(slashParts[0], ":")[0]
|
||||
default:
|
||||
fmt.Println("Invalid image format.")
|
||||
return ModelPath{}
|
||||
}
|
||||
|
||||
colonParts := strings.Split(slashParts[len(slashParts)-1], ":")
|
||||
if len(colonParts) == 2 {
|
||||
tag = colonParts[1]
|
||||
} else {
|
||||
tag = DefaultTag
|
||||
}
|
||||
|
||||
return ModelPath{
|
||||
mp := ModelPath{
|
||||
ProtocolScheme: DefaultProtocolScheme,
|
||||
Registry: registry,
|
||||
Namespace: namespace,
|
||||
Repository: repository,
|
||||
Tag: tag,
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: DefaultNamespace,
|
||||
Repository: "",
|
||||
Tag: DefaultTag,
|
||||
}
|
||||
|
||||
before, after, found := strings.Cut(name, "://")
|
||||
if found {
|
||||
mp.ProtocolScheme = before
|
||||
name = after
|
||||
}
|
||||
|
||||
parts := strings.Split(name, "/")
|
||||
switch len(parts) {
|
||||
case 3:
|
||||
mp.Registry = parts[0]
|
||||
mp.Namespace = parts[1]
|
||||
mp.Repository = parts[2]
|
||||
case 2:
|
||||
mp.Namespace = parts[0]
|
||||
mp.Repository = parts[1]
|
||||
case 1:
|
||||
mp.Repository = parts[0]
|
||||
}
|
||||
|
||||
if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
|
||||
mp.Repository = repo
|
||||
mp.Tag = tag
|
||||
}
|
||||
|
||||
return mp
|
||||
}
|
||||
|
||||
func (mp ModelPath) GetNamespaceRepository() string {
|
||||
@@ -95,6 +101,13 @@ func (mp ModelPath) GetManifestPath(createDir bool) (string, error) {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func (mp ModelPath) BaseURL() *url.URL {
|
||||
return &url.URL{
|
||||
Scheme: mp.ProtocolScheme,
|
||||
Host: mp.Registry,
|
||||
}
|
||||
}
|
||||
|
||||
func GetManifestPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
|
88
server/modelpath_test.go
Normal file
88
server/modelpath_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package server
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseModelPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
arg string
|
||||
want ModelPath
|
||||
}{
|
||||
{
|
||||
"full path https",
|
||||
"https://example.com/ns/repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: "example.com",
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"full path http",
|
||||
"http://example.com/ns/repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "http",
|
||||
Registry: "example.com",
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"no protocol",
|
||||
"example.com/ns/repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: "example.com",
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"no registry",
|
||||
"ns/repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"no namespace",
|
||||
"repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: DefaultNamespace,
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"no tag",
|
||||
"repo",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: DefaultNamespace,
|
||||
Repository: "repo",
|
||||
Tag: DefaultTag,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := ParseModelPath(tc.arg)
|
||||
|
||||
if got != tc.want {
|
||||
t.Errorf("got: %q want: %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -25,6 +25,20 @@ import (
|
||||
"github.com/jmorganca/ollama/vector"
|
||||
)
|
||||
|
||||
var mode string = gin.DebugMode
|
||||
|
||||
func init() {
|
||||
switch mode {
|
||||
case gin.DebugMode:
|
||||
case gin.ReleaseMode:
|
||||
case gin.TestMode:
|
||||
default:
|
||||
mode = gin.DebugMode
|
||||
}
|
||||
|
||||
gin.SetMode(mode)
|
||||
}
|
||||
|
||||
var loaded struct {
|
||||
mu sync.Mutex
|
||||
|
||||
@@ -357,8 +371,9 @@ func ListModelsHandler(c *gin.Context) {
|
||||
return nil
|
||||
}
|
||||
tag := path[:slashIndex] + ":" + path[slashIndex+1:]
|
||||
|
||||
mp := ParseModelPath(tag)
|
||||
manifest, err := GetManifest(mp)
|
||||
manifest, digest, err := GetManifest(mp)
|
||||
if err != nil {
|
||||
log.Printf("skipping file: %s", fp)
|
||||
return nil
|
||||
@@ -366,6 +381,7 @@ func ListModelsHandler(c *gin.Context) {
|
||||
model := api.ListResponseModel{
|
||||
Name: mp.GetShortTagname(),
|
||||
Size: manifest.GetTotalSize(),
|
||||
Digest: digest,
|
||||
ModifiedAt: fi.ModTime(),
|
||||
}
|
||||
models = append(models, model)
|
||||
|
165
server/upload.go
Normal file
165
server/upload.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
)
|
||||
|
||||
func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, error) {
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
|
||||
if layer.From != "" {
|
||||
values := requestURL.Query()
|
||||
values.Add("mount", layer.Digest)
|
||||
values.Add("from", layer.From)
|
||||
requestURL.RawQuery = values.Encode()
|
||||
}
|
||||
|
||||
resp, err := makeRequestWithRetry(ctx, "POST", requestURL, nil, nil, regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't start upload: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Extract UUID location from header
|
||||
location := resp.Header.Get("Location")
|
||||
if location == "" {
|
||||
return nil, fmt.Errorf("location header is missing in response")
|
||||
}
|
||||
|
||||
return url.Parse(location)
|
||||
}
|
||||
|
||||
func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
// TODO allow resumability
|
||||
// TODO allow canceling uploads via DELETE
|
||||
|
||||
fp, err := GetBlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err := os.Open(fp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// 95MB chunk size
|
||||
chunkSize := 95 * 1024 * 1024
|
||||
|
||||
for offset := int64(0); offset < int64(layer.Size); {
|
||||
chunk := int64(layer.Size) - offset
|
||||
if chunk > int64(chunkSize) {
|
||||
chunk = int64(chunkSize)
|
||||
}
|
||||
|
||||
sectionReader := io.NewSectionReader(f, int64(offset), chunk)
|
||||
for try := 0; try < MaxRetries; try++ {
|
||||
r, w := io.Pipe()
|
||||
defer r.Close()
|
||||
go func() {
|
||||
defer w.Close()
|
||||
|
||||
for chunked := int64(0); chunked < chunk; {
|
||||
n, err := io.CopyN(w, sectionReader, 1024*1024)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("error reading chunk: %v", err),
|
||||
Digest: layer.Digest,
|
||||
Total: layer.Size,
|
||||
Completed: int(offset),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
chunked += n
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("uploading %s", layer.Digest),
|
||||
Digest: layer.Digest,
|
||||
Total: layer.Size,
|
||||
Completed: int(offset) + int(chunked),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "application/octet-stream")
|
||||
headers.Set("Content-Length", strconv.Itoa(int(chunk)))
|
||||
headers.Set("Content-Range", fmt.Sprintf("%d-%d", offset, offset+sectionReader.Size()-1))
|
||||
resp, err := makeRequest(ctx, "PATCH", requestURL, headers, r, regOpts)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("error uploading chunk: %v", err),
|
||||
Digest: layer.Digest,
|
||||
Total: layer.Size,
|
||||
Completed: int(offset),
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
switch {
|
||||
case resp.StatusCode == http.StatusUnauthorized:
|
||||
auth := resp.Header.Get("www-authenticate")
|
||||
authRedir := ParseAuthRedirectString(auth)
|
||||
token, err := getAuthToken(ctx, authRedir, regOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
regOpts.Token = token
|
||||
if _, err := sectionReader.Seek(0, io.SeekStart); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
continue
|
||||
case resp.StatusCode >= http.StatusBadRequest:
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
offset += sectionReader.Size()
|
||||
requestURL, err = url.Parse(resp.Header.Get("Location"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
values := requestURL.Query()
|
||||
values.Add("digest", layer.Digest)
|
||||
requestURL.RawQuery = values.Encode()
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "application/octet-stream")
|
||||
headers.Set("Content-Length", "0")
|
||||
|
||||
// finish the upload
|
||||
resp, err := makeRequest(ctx, "PUT", requestURL, headers, nil, regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't finish upload: %v", err)
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
|
||||
}
|
||||
return nil
|
||||
}
|
3
version/version.go
Normal file
3
version/version.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package version
|
||||
|
||||
var Version string = "0.0.0"
|
Reference in New Issue
Block a user