Compare commits


1 commit

Commit: e2389b63aa "add examples of streaming in python and node"
Author: Matt Williams
Signed-off-by: Matt Williams <m@technovangelist.com>
Date: 2023-09-14 07:12:09 -07:00
36 changed files with 532 additions and 1145 deletions


@@ -1,7 +1,5 @@
 .vscode
 ollama
 app
-dist
-scripts
 llm/llama.cpp/ggml
 llm/llama.cpp/gguf

.gitmodules (1 changed line)

@@ -6,5 +6,4 @@
 [submodule "llm/llama.cpp/gguf"]
 	path = llm/llama.cpp/gguf
 	url = https://github.com/ggerganov/llama.cpp.git
-	ignore = dirty
 	shallow = true


@@ -1,28 +1,18 @@
-ARG CUDA_VERSION=12.2.0
-FROM nvidia/cuda:$CUDA_VERSION-devel-ubuntu22.04
-ARG TARGETARCH
-ARG VERSION=0.0.0
+FROM golang:alpine
 WORKDIR /go/src/github.com/jmorganca/ollama
-RUN apt-get update && apt-get install -y git build-essential cmake
-ADD https://dl.google.com/go/go1.21.1.linux-$TARGETARCH.tar.gz /tmp/go1.21.1.tar.gz
-RUN mkdir -p /usr/local && tar xz -C /usr/local </tmp/go1.21.1.tar.gz
+RUN apk add --no-cache git build-base cmake
 COPY . .
-ENV GOARCH=$TARGETARCH
-RUN /usr/local/go/bin/go generate ./... \
-    && /usr/local/go/bin/go build -ldflags "-linkmode=external -extldflags='-static' -X=github.com/jmorganca/ollama/version.Version=$VERSION -X=github.com/jmorganca/ollama/server.mode=release" .
+RUN go generate ./... && go build -ldflags '-linkmode external -extldflags "-static"' .
 
-FROM ubuntu:22.04
+FROM alpine
 ENV OLLAMA_HOST 0.0.0.0
+RUN apk add --no-cache libstdc++
-RUN apt-get update && apt-get install -y ca-certificates
 ARG USER=ollama
 ARG GROUP=ollama
-RUN groupadd $GROUP && useradd -m -g $GROUP $USER
+RUN addgroup $GROUP && adduser -D -G $GROUP $USER
 COPY --from=0 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama


@@ -1,29 +0,0 @@
ARG VERSION=0.0.0
# centos7 amd64 dependencies
FROM --platform=linux/amd64 nvidia/cuda:11.8.0-devel-centos7 AS base-amd64
RUN yum install -y https://repo.ius.io/ius-release-el7.rpm centos-release-scl && \
yum update -y && \
yum install -y devtoolset-10-gcc devtoolset-10-gcc-c++ git236 wget
RUN wget "https://github.com/Kitware/CMake/releases/download/v3.27.6/cmake-3.27.6-linux-x86_64.sh" -O cmake-installer.sh && chmod +x cmake-installer.sh && ./cmake-installer.sh --skip-license --prefix=/usr/local
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
# centos8 arm64 dependencies
FROM --platform=linux/arm64 nvidia/cuda:11.4.3-devel-centos8 AS base-arm64
RUN sed -i -e 's/mirrorlist/#mirrorlist/g' -e 's|#baseurl=http://mirror.centos.org|baseurl=http://vault.centos.org|g' /etc/yum.repos.d/CentOS-*
RUN yum install -y git cmake
FROM base-${TARGETARCH}
ARG TARGETARCH
# install go
ADD https://dl.google.com/go/go1.21.1.linux-$TARGETARCH.tar.gz /tmp/go1.21.1.tar.gz
RUN mkdir -p /usr/local && tar xz -C /usr/local </tmp/go1.21.1.tar.gz
# build the final binary
WORKDIR /go/src/github.com/jmorganca/ollama
COPY . .
ENV GOARCH=$TARGETARCH
RUN /usr/local/go/bin/go generate ./... && \
/usr/local/go/bin/go build -ldflags "-X=github.com/jmorganca/ollama/version.Version=$VERSION -X=github.com/jmorganca/ollama/server.mode=release" .

Dockerfile.cuda (new file, 22 lines)

@@ -0,0 +1,22 @@
FROM nvidia/cuda:12.2.0-devel-ubuntu22.04
WORKDIR /go/src/github.com/jmorganca/ollama
RUN apt-get update && apt-get install -y git build-essential cmake
ADD https://dl.google.com/go/go1.21.1.linux-amd64.tar.gz /tmp/go1.21.1.tar.gz
RUN mkdir -p /usr/local && tar xz -C /usr/local </tmp/go1.21.1.tar.gz
COPY . .
RUN /usr/local/go/bin/go generate ./... && /usr/local/go/bin/go build -ldflags '-linkmode external -extldflags "-static"' .
FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04
ENV OLLAMA_HOST 0.0.0.0
ARG USER=ollama
ARG GROUP=ollama
RUN groupadd $GROUP && useradd -m -g $GROUP $USER
COPY --from=0 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
USER $USER:$GROUP
ENTRYPOINT ["/bin/ollama"]
CMD ["serve"]


@@ -206,16 +206,10 @@ curl -X POST http://localhost:11434/api/generate -d '{
 ## Community Projects using Ollama
 
-| Project | Description |
-| ------- | ----------- |
-| [LangChain][1] and [LangChain.js][2] | Also, there is a question-answering [example][3]. |
-| [Continue](https://github.com/continuedev/continue) | Embeds Ollama inside Visual Studio Code. The extension lets you highlight code to add to the prompt, ask questions in the sidebar, and generate code inline. |
-| [LiteLLM](https://github.com/BerriAI/litellm) | Lightweight Python package to simplify LLM API calls. |
-| [Discord AI Bot](https://github.com/mekb-turtle/discord-ai-bot) | Interact with Ollama as a chatbot on Discord. |
-| [Raycast Ollama](https://github.com/MassimilianoPasquini97/raycast_ollama) | Raycast extension to use Ollama for local llama inference on Raycast. |
-| [Simple HTML UI](https://github.com/rtcfirefly/ollama-ui) | Also, there is a Chrome extension. |
-| [Emacs client](https://github.com/zweifisch/ollama) | |
-
-[1]: https://python.langchain.com/docs/integrations/llms/ollama
-[2]: https://js.langchain.com/docs/modules/model_io/models/llms/integrations/ollama
-[3]: https://js.langchain.com/docs/use_cases/question_answering/local_retrieval_qa
+- [LangChain](https://python.langchain.com/docs/integrations/llms/ollama) and [LangChain.js](https://js.langchain.com/docs/modules/model_io/models/llms/integrations/ollama) with a question-answering [example](https://js.langchain.com/docs/use_cases/question_answering/local_retrieval_qa).
+- [Continue](https://github.com/continuedev/continue) - embeds Ollama inside Visual Studio Code. The extension lets you highlight code to add to the prompt, ask questions in the sidebar, and generate code inline.
+- [LiteLLM](https://github.com/BerriAI/litellm) a lightweight python package to simplify LLM API calls
+- [Discord AI Bot](https://github.com/mekb-turtle/discord-ai-bot) - interact with Ollama as a chatbot on Discord.
+- [Raycast Ollama](https://github.com/MassimilianoPasquini97/raycast_ollama) - Raycast extension to use Ollama for local llama inference on Raycast.
+- [Simple HTML UI for Ollama](https://github.com/rtcfirefly/ollama-ui)
+- [Emacs client](https://github.com/zweifisch/ollama) for Ollama


@@ -1,225 +0,0 @@
import os
import json
import requests
BASE_URL = os.environ.get('OLLAMA_HOST', 'http://localhost:11434')
# Generate a response for a given prompt with a provided model. This is a streaming endpoint, so will be a series of responses.
# The final response object will include statistics and additional data from the request. Use the callback function to override
# the default handler.
def generate(model_name, prompt, system=None, template=None, context=None, options=None, callback=None):
try:
url = f"{BASE_URL}/api/generate"
payload = {
"model": model_name,
"prompt": prompt,
"system": system,
"template": template,
"context": context,
"options": options
}
# Remove keys with None values
payload = {k: v for k, v in payload.items() if v is not None}
with requests.post(url, json=payload, stream=True) as response:
response.raise_for_status()
# Creating a variable to hold the context history of the final chunk
final_context = None
# Variable to hold concatenated response strings if no callback is provided
full_response = ""
# Iterating over the response line by line and displaying the details
for line in response.iter_lines():
if line:
# Parsing each line (JSON chunk) and extracting the details
chunk = json.loads(line)
# If a callback function is provided, call it with the chunk
if callback:
callback(chunk)
else:
# If this is not the last chunk, add the "response" field value to full_response and print it
if not chunk.get("done"):
response_piece = chunk.get("response", "")
full_response += response_piece
print(response_piece, end="", flush=True)
# Check if it's the last chunk (done is true)
if chunk.get("done"):
final_context = chunk.get("context")
# Return the full response and the final context
return full_response, final_context
except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}")
return None, None
# Create a model from a Modelfile. Use the callback function to override the default handler.
def create(model_name, model_path, callback=None):
try:
url = f"{BASE_URL}/api/create"
payload = {"name": model_name, "path": model_path}
# Making a POST request with the stream parameter set to True to handle streaming responses
with requests.post(url, json=payload, stream=True) as response:
response.raise_for_status()
# Iterating over the response line by line and displaying the status
for line in response.iter_lines():
if line:
# Parsing each line (JSON chunk) and extracting the status
chunk = json.loads(line)
if callback:
callback(chunk)
else:
print(f"Status: {chunk.get('status')}")
except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}")
# Pull a model from a the model registry. Cancelled pulls are resumed from where they left off, and multiple
# calls to will share the same download progress. Use the callback function to override the default handler.
def pull(model_name, insecure=False, callback=None):
try:
url = f"{BASE_URL}/api/pull"
payload = {
"name": model_name,
"insecure": insecure
}
# Making a POST request with the stream parameter set to True to handle streaming responses
with requests.post(url, json=payload, stream=True) as response:
response.raise_for_status()
# Iterating over the response line by line and displaying the details
for line in response.iter_lines():
if line:
# Parsing each line (JSON chunk) and extracting the details
chunk = json.loads(line)
# If a callback function is provided, call it with the chunk
if callback:
callback(chunk)
else:
# Print the status message directly to the console
print(chunk.get('status', ''), end='', flush=True)
# If there's layer data, you might also want to print that (adjust as necessary)
if 'digest' in chunk:
print(f" - Digest: {chunk['digest']}", end='', flush=True)
print(f" - Total: {chunk['total']}", end='', flush=True)
print(f" - Completed: {chunk['completed']}", end='\n', flush=True)
else:
print()
except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}")
# Push a model to the model registry. Use the callback function to override the default handler.
def push(model_name, insecure=False, callback=None):
try:
url = f"{BASE_URL}/api/push"
payload = {
"name": model_name,
"insecure": insecure
}
# Making a POST request with the stream parameter set to True to handle streaming responses
with requests.post(url, json=payload, stream=True) as response:
response.raise_for_status()
# Iterating over the response line by line and displaying the details
for line in response.iter_lines():
if line:
# Parsing each line (JSON chunk) and extracting the details
chunk = json.loads(line)
# If a callback function is provided, call it with the chunk
if callback:
callback(chunk)
else:
# Print the status message directly to the console
print(chunk.get('status', ''), end='', flush=True)
# If there's layer data, you might also want to print that (adjust as necessary)
if 'digest' in chunk:
print(f" - Digest: {chunk['digest']}", end='', flush=True)
print(f" - Total: {chunk['total']}", end='', flush=True)
print(f" - Completed: {chunk['completed']}", end='\n', flush=True)
else:
print()
except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}")
# List models that are available locally.
def list():
try:
response = requests.get(f"{BASE_URL}/api/tags")
response.raise_for_status()
data = response.json()
models = data.get('models', [])
return models
except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}")
return None
# Copy a model. Creates a model with another name from an existing model.
def copy(source, destination):
try:
# Create the JSON payload
payload = {
"source": source,
"destination": destination
}
response = requests.post(f"{BASE_URL}/api/copy", json=payload)
response.raise_for_status()
# If the request was successful, return a message indicating that the copy was successful
return "Copy successful"
except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}")
return None
# Delete a model and its data.
def delete(model_name):
try:
url = f"{BASE_URL}/api/delete"
payload = {"name": model_name}
response = requests.delete(url, json=payload)
response.raise_for_status()
return "Delete successful"
except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}")
return None
# Show info about a model.
def show(model_name):
try:
url = f"{BASE_URL}/api/show"
payload = {"name": model_name}
response = requests.post(url, json=payload)
response.raise_for_status()
# Parse the JSON response and return it
data = response.json()
return data
except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}")
return None
def heartbeat():
try:
url = f"{BASE_URL}/"
response = requests.head(url)
response.raise_for_status()
return "Ollama is running"
except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}")
return "Ollama is not running"


@@ -11,19 +11,20 @@ import (
 	"io"
 	"log"
 	"net"
+	"net/http"
 	"os"
 	"os/exec"
+	"path"
 	"path/filepath"
 	"runtime"
 	"strings"
 	"time"
 
+	"github.com/chzyer/readline"
 	"github.com/dustin/go-humanize"
 	"github.com/olekukonko/tablewriter"
-	"github.com/pdevine/readline"
 	"github.com/spf13/cobra"
 	"golang.org/x/crypto/ssh"
-	"golang.org/x/term"
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/format"
@@ -32,19 +33,6 @@ import (
 	"github.com/jmorganca/ollama/version"
 )
 
-type Painter struct {
-	HideHint bool
-}
-
-func (p Painter) Paint(line []rune, _ int) []rune {
-	termType := os.Getenv("TERM")
-	if termType == "xterm-256color" && len(line) == 0 && !p.HideHint {
-		prompt := "Send a message (/? for help)"
-		return []rune(fmt.Sprintf("\033[38;5;245m%s\033[%dD\033[0m", prompt, len(prompt)))
-	}
-	return line
-}
-
 func CreateHandler(cmd *cobra.Command, args []string) error {
 	filename, _ := cmd.Flags().GetString("file")
 	filename, err := filepath.Abs(filename)
@@ -110,28 +98,39 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
func RunHandler(cmd *cobra.Command, args []string) error { func RunHandler(cmd *cobra.Command, args []string) error {
client, err := api.FromEnv() insecure, err := cmd.Flags().GetBool("insecure")
if err != nil { if err != nil {
return err return err
} }
models, err := client.List(context.Background()) mp := server.ParseModelPath(args[0])
if err != nil { if err != nil {
return err return err
} }
modelName, modelTag, ok := strings.Cut(args[0], ":") if mp.ProtocolScheme == "http" && !insecure {
if !ok { return fmt.Errorf("insecure protocol http")
modelTag = "latest"
} }
for _, model := range models.Models { fp, err := mp.GetManifestPath(false)
if model.Name == strings.Join([]string{modelName, modelTag}, ":") { if err != nil {
return RunGenerate(cmd, args) return err
}
} }
if err := PullHandler(cmd, args); err != nil { _, err = os.Stat(fp)
switch {
case errors.Is(err, os.ErrNotExist):
if err := pull(args[0], insecure); err != nil {
var apiStatusError api.StatusError
if !errors.As(err, &apiStatusError) {
return err
}
if apiStatusError.StatusCode != http.StatusBadGateway {
return err
}
}
case err != nil:
return err return err
} }
@@ -388,6 +387,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
type generateContextKey string type generateContextKey string
func generate(cmd *cobra.Command, model, prompt string) error { func generate(cmd *cobra.Command, model, prompt string) error {
if len(strings.TrimSpace(prompt)) > 0 {
client, err := api.FromEnv() client, err := api.FromEnv()
if err != nil { if err != nil {
return err return err
@@ -403,29 +403,6 @@ func generate(cmd *cobra.Command, model, prompt string) error {
generateContext = []int{} generateContext = []int{}
} }
var wrapTerm bool
termType := os.Getenv("TERM")
if termType == "xterm-256color" {
wrapTerm = true
}
termWidth, _, err := term.GetSize(int(0))
if err != nil {
wrapTerm = false
}
// override wrapping if the user turned it off
nowrap, err := cmd.Flags().GetBool("nowordwrap")
if err != nil {
return err
}
if nowrap {
wrapTerm = false
}
var currentLineLength int
var wordBuffer string
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext} request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
fn := func(response api.GenerateResponse) error { fn := func(response api.GenerateResponse) error {
if !spinner.IsFinished() { if !spinner.IsFinished() {
@@ -434,31 +411,7 @@ func generate(cmd *cobra.Command, model, prompt string) error {
latest = response latest = response
if wrapTerm {
for _, ch := range response.Response {
if currentLineLength+1 > termWidth-5 {
// backtrack the length of the last word and clear to the end of the line
fmt.Printf("\x1b[%dD\x1b[K\n", len(wordBuffer))
fmt.Printf("%s%c", wordBuffer, ch)
currentLineLength = len(wordBuffer) + 1
} else {
fmt.Print(string(ch))
currentLineLength += 1
switch ch {
case ' ':
wordBuffer = ""
case '\n':
currentLineLength = 0
default:
wordBuffer += string(ch)
}
}
}
} else {
fmt.Print(response.Response) fmt.Print(response.Response)
}
return nil return nil
} }
@@ -477,10 +430,9 @@ func generate(cmd *cobra.Command, model, prompt string) error {
} }
return err return err
} }
if prompt != "" {
fmt.Println() fmt.Println()
fmt.Println() fmt.Println()
}
if !latest.Done { if !latest.Done {
return errors.New("unexpected end of response") return errors.New("unexpected end of response")
@@ -498,6 +450,7 @@ func generate(cmd *cobra.Command, model, prompt string) error {
ctx := cmd.Context() ctx := cmd.Context()
ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context) ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
cmd.SetContext(ctx) cmd.SetContext(ctx)
}
return nil return nil
} }
@@ -508,21 +461,19 @@ func generateInteractive(cmd *cobra.Command, model string) error {
return err return err
} }
// load the model
if err := generate(cmd, model, ""); err != nil {
return err
}
completer := readline.NewPrefixCompleter( completer := readline.NewPrefixCompleter(
readline.PcItem("/help"), readline.PcItem("/help"),
readline.PcItem("/list"), readline.PcItem("/list"),
readline.PcItem("/set", readline.PcItem("/set",
readline.PcItem("history"), readline.PcItem("history"),
readline.PcItem("nohistory"), readline.PcItem("nohistory"),
readline.PcItem("wordwrap"),
readline.PcItem("nowordwrap"),
readline.PcItem("verbose"), readline.PcItem("verbose"),
readline.PcItem("quiet"), readline.PcItem("quiet"),
readline.PcItem("mode",
readline.PcItem("vim"),
readline.PcItem("emacs"),
readline.PcItem("default"),
),
), ),
readline.PcItem("/show", readline.PcItem("/show",
readline.PcItem("license"), readline.PcItem("license"),
@@ -540,10 +491,7 @@ func generateInteractive(cmd *cobra.Command, model string) error {
fmt.Fprintln(os.Stderr, completer.Tree(" ")) fmt.Fprintln(os.Stderr, completer.Tree(" "))
} }
var painter Painter
config := readline.Config{ config := readline.Config{
Painter: &painter,
Prompt: ">>> ", Prompt: ">>> ",
HistoryFile: filepath.Join(home, ".ollama", "history"), HistoryFile: filepath.Join(home, ".ollama", "history"),
AutoComplete: completer, AutoComplete: completer,
@@ -579,7 +527,6 @@ func generateInteractive(cmd *cobra.Command, model string) error {
case isMultiLine: case isMultiLine:
if strings.HasSuffix(line, `"""`) { if strings.HasSuffix(line, `"""`) {
isMultiLine = false isMultiLine = false
painter.HideHint = false
multiLineBuffer += strings.TrimSuffix(line, `"""`) multiLineBuffer += strings.TrimSuffix(line, `"""`)
line = multiLineBuffer line = multiLineBuffer
multiLineBuffer = "" multiLineBuffer = ""
@@ -592,49 +539,51 @@ func generateInteractive(cmd *cobra.Command, model string) error {
isMultiLine = true isMultiLine = true
multiLineBuffer = strings.TrimPrefix(line, `"""`) + " " multiLineBuffer = strings.TrimPrefix(line, `"""`) + " "
scanner.SetPrompt("... ") scanner.SetPrompt("... ")
painter.HideHint = true
continue continue
case strings.HasPrefix(line, "/list"): case strings.HasPrefix(line, "/list"):
args := strings.Fields(line) args := strings.Fields(line)
if err := ListHandler(cmd, args[1:]); err != nil { if err := ListHandler(cmd, args[1:]); err != nil {
return err return err
} }
continue
case strings.HasPrefix(line, "/set"): case strings.HasPrefix(line, "/set"):
args := strings.Fields(line) args := strings.Fields(line)
if len(args) > 1 { if len(args) > 1 {
switch args[1] { switch args[1] {
case "history": case "history":
scanner.HistoryEnable() scanner.HistoryEnable()
continue
case "nohistory": case "nohistory":
scanner.HistoryDisable() scanner.HistoryDisable()
case "wordwrap": continue
cmd.Flags().Set("nowordwrap", "false")
fmt.Println("Set 'wordwrap' mode.")
case "nowordwrap":
cmd.Flags().Set("nowordwrap", "true")
fmt.Println("Set 'nowordwrap' mode.")
case "verbose": case "verbose":
cmd.Flags().Set("verbose", "true") cmd.Flags().Set("verbose", "true")
fmt.Println("Set 'verbose' mode.") continue
case "quiet": case "quiet":
cmd.Flags().Set("verbose", "false") cmd.Flags().Set("verbose", "false")
fmt.Println("Set 'quiet' mode.") continue
case "mode": case "mode":
if len(args) > 2 { if len(args) > 2 {
switch args[2] { switch args[2] {
case "vim": case "vim":
scanner.SetVimMode(true) scanner.SetVimMode(true)
continue
case "emacs", "default": case "emacs", "default":
scanner.SetVimMode(false) scanner.SetVimMode(false)
continue
default: default:
usage() usage()
continue
} }
} else { } else {
usage() usage()
continue
} }
} }
} else { } else {
usage() usage()
continue
} }
case strings.HasPrefix(line, "/show"): case strings.HasPrefix(line, "/show"):
args := strings.Fields(line) args := strings.Fields(line)
@@ -642,6 +591,7 @@ func generateInteractive(cmd *cobra.Command, model string) error {
resp, err := server.GetModelInfo(model) resp, err := server.GetModelInfo(model)
if err != nil { if err != nil {
fmt.Println("error: couldn't get model") fmt.Println("error: couldn't get model")
continue
} }
switch args[1] { switch args[1] {
@@ -658,24 +608,23 @@ func generateInteractive(cmd *cobra.Command, model string) error {
default: default:
fmt.Println("error: unknown command") fmt.Println("error: unknown command")
} }
continue
} else { } else {
usage() usage()
continue
} }
case line == "/help", line == "/?": case line == "/help", line == "/?":
usage() usage()
continue
case line == "/exit", line == "/bye": case line == "/exit", line == "/bye":
return nil return nil
case strings.HasPrefix(line, "/"):
args := strings.Fields(line)
fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0])
} }
if len(line) > 0 && line[0] != '/' {
if err := generate(cmd, model, line); err != nil { if err := generate(cmd, model, line); err != nil {
return err return err
} }
} }
}
} }
func generateBatch(cmd *cobra.Command, model string) error { func generateBatch(cmd *cobra.Command, model string) error {
@@ -692,19 +641,28 @@ generateBatch(cmd *cobra.Command, model string) error {
 }
 
 func RunServer(cmd *cobra.Command, _ []string) error {
-	host, port, err := net.SplitHostPort(os.Getenv("OLLAMA_HOST"))
-	if err != nil {
-		host, port = "127.0.0.1", "11434"
-		if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
-			host = ip.String()
-		}
+	host, port := "127.0.0.1", "11434"
+
+	parts := strings.Split(os.Getenv("OLLAMA_HOST"), ":")
+	if ip := net.ParseIP(parts[0]); ip != nil {
+		host = ip.String()
+	}
+
+	if len(parts) > 1 {
+		port = parts[1]
 	}
 
-	if err := initializeKeypair(); err != nil {
+	// deprecated: include port in OLLAMA_HOST
+	if p := os.Getenv("OLLAMA_PORT"); p != "" {
+		port = p
+	}
+
+	err := initializeKeypair()
+	if err != nil {
 		return err
 	}
 
-	ln, err := net.Listen("tcp", net.JoinHostPort(host, port))
+	ln, err := net.Listen("tcp", fmt.Sprintf("%s:%s", host, port))
 	if err != nil {
 		return err
 	}
@@ -745,7 +703,7 @@ func initializeKeypair() error {
 		return err
 	}
 
-	err = os.MkdirAll(filepath.Dir(privKeyPath), 0o755)
+	err = os.MkdirAll(path.Dir(privKeyPath), 0o700)
 	if err != nil {
 		return fmt.Errorf("could not create directory %w", err)
 	}
@@ -873,7 +831,6 @@ func NewCLI() *cobra.Command {
 	runCmd.Flags().Bool("verbose", false, "Show timings for response")
 	runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
-	runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
 
 	serveCmd := &cobra.Command{
 		Use:     "serve",


@@ -3,29 +3,30 @@
 ## Endpoints
 
 - [Generate a completion](#generate-a-completion)
-- [Create a Model](#create-a-model)
-- [List Local Models](#list-local-models)
-- [Show Model Information](#show-model-information)
-- [Copy a Model](#copy-a-model)
-- [Delete a Model](#delete-a-model)
-- [Pull a Model](#pull-a-model)
-- [Push a Model](#push-a-model)
-- [Generate Embeddings](#generate-embeddings)
+- [Create a model](#create-a-model)
+- [List local models](#list-local-models)
+- [Copy a model](#copy-a-model)
+- [Delete a model](#delete-a-model)
+- [Pull a model](#pull-a-model)
+- [Generate embeddings](#generate-embeddings)
 
 ## Conventions
 
 ### Model names
 
-Model names follow a `model:tag` format. Some examples are `orca-mini:3b-q4_1` and `llama2:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
+Model names follow a `model:tag` format. Some examples are `orca-mini:3b-q4_1` and `llama2:70b`. The tag is optional and if not provided will default to `latest`. The tag is used to identify a specific version.
 
 ### Durations
 
 All durations are returned in nanoseconds.
 
+### Streams
+
+Many API responses are streams of JSON objects showing the current status. For examples of working with streams in various languages, see [streaming.md](./streaming.md)
+
 ## Generate a completion
 
-```shell
+```
 POST /api/generate
 ```
@@ -45,7 +46,7 @@ Advanced parameters:
 
 ### Request
 
-```shell
+```
 curl -X POST http://localhost:11434/api/generate -d '{
   "model": "llama2:7b",
   "prompt": "Why is the sky blue?"
@@ -98,7 +99,7 @@ To calculate how fast the response is generated in tokens per second (token/s),
 
 ## Create a Model
 
-```shell
+```
 POST /api/create
 ```
@@ -111,7 +112,7 @@ Create a model from a [`Modelfile`](./modelfile.md)
 
 ### Request
 
-```shell
+```
 curl -X POST http://localhost:11434/api/create -d '{
   "name": "mario",
   "path": "~/Modelfile"
### Response ### Response
A stream of JSON objects. When finished, `status` is `success`. A stream of JSON objects. When finished, `status` is `success`
```json ```json
{ {
@@ -130,7 +131,7 @@ A stream of JSON objects. When finished, `status` is `success`.
 
 ## List Local Models
 
-```shell
+```
 GET /api/tags
 ```
@@ -138,7 +139,7 @@ List models that are available locally.
 
 ### Request
 
-```shell
+```
 curl http://localhost:11434/api/tags
 ```
@@ -161,40 +162,9 @@ curl http://localhost:11434/api/tags
 }
 ```
 
-## Show Model Information
-
-```shell
-POST /api/show
-```
-
-Show details about a model including modelfile, template, parameters, license, and system prompt.
-
-### Parameters
-
-- `name`: name of the model to show
-
-### Request
-
-```shell
-curl http://localhost:11434/api/show -d '{
-  "name": "llama2:7b"
-}'
-```
-
-### Response
-
-```json
-{
-  "license": "<contents of license block>",
-  "modelfile": "# Modelfile generated by \"ollama show\"\n# To build a new Modelfile based on this one, replace the FROM line with:\n# FROM llama2:latest\n\nFROM /Users/username/.ollama/models/blobs/sha256:8daa9615cce30c259a9555b1cc250d461d1bc69980a274b44d7eda0be78076d8\nTEMPLATE \"\"\"[INST] {{ if and .First .System }}<<SYS>>{{ .System }}<</SYS>>\n\n{{ end }}{{ .Prompt }} [/INST] \"\"\"\nSYSTEM \"\"\"\"\"\"\nPARAMETER stop [INST]\nPARAMETER stop [/INST]\nPARAMETER stop <<SYS>>\nPARAMETER stop <</SYS>>\n",
-  "parameters": "stop [INST]\nstop [/INST]\nstop <<SYS>>\nstop <</SYS>>",
-  "template": "[INST] {{ if and .First .System }}<<SYS>>{{ .System }}<</SYS>>\n\n{{ end }}{{ .Prompt }} [/INST] "
-}
-```
-
 ## Copy a Model
 
-```shell
+```
 POST /api/copy
 ```
@@ -202,7 +172,7 @@ Copy a model. Creates a model with another name from an existing model.
 
 ### Request
 
-```shell
+```
 curl http://localhost:11434/api/copy -d '{
   "source": "llama2:7b",
   "destination": "llama2-backup"
## Delete a Model ## Delete a Model
```shell ```
DELETE /api/delete DELETE /api/delete
``` ```
@@ -223,7 +193,7 @@ Delete a model and its data.
 
 ### Request
 
-```shell
+```
 curl -X DELETE http://localhost:11434/api/delete -d '{
   "name": "llama2:13b"
 }'
@@ -231,20 +201,19 @@ curl -X DELETE http://localhost:11434/api/delete -d '{
 
 ## Pull a Model
 
-```shell
+```
 POST /api/pull
 ```
 
-Download a model from the ollama library. Cancelled pulls are resumed from where they left off, and multiple calls will share the same download progress.
+Download a model from a the model registry. Cancelled pulls are resumed from where they left off, and multiple calls to will share the same download progress.
 
 ### Parameters
 
 - `name`: name of the model to pull
-- `insecure`: (optional) allow insecure connections to the library. Only use this if you are pulling from your own library during development.
 
 ### Request
 
-```shell
+```
 curl -X POST http://localhost:11434/api/pull -d '{
   "name": "llama2:7b"
 }'
@@ -260,63 +229,9 @@ curl -X POST http://localhost:11434/api/pull -d '{
 }
 ```
 
-## Push a Model
-
-```shell
-POST /api/push
-```
-
-Upload a model to a model library. Requires registering for ollama.ai and adding a public key first.
-
-### Parameters
-
-- `name`: name of the model to push in the form of `<namespace>/<model>:<tag>`
-- `insecure`: (optional) allow insecure connections to the library. Only use this if you are pushing to your library during development.
-
-### Request
-
-```shell
-curl -X POST http://localhost:11434/api/push -d '{
-  "name": "mattw/pygmalion:latest"
-}'
-```
-
-### Response
-
-Streaming response that starts with:
-
-```json
-{"status":"retrieving manifest"}
-```
-
-and then:
-
-```json
-{
-  "status":"starting upload","digest":"sha256:bc07c81de745696fdf5afca05e065818a8149fb0c77266fb584d9b2cba3711ab",
-  "total":1928429856
-}
-```
-
-Then there is a series of uploading responses:
-
-```json
-{
-  "status":"starting upload",
-  "digest":"sha256:bc07c81de745696fdf5afca05e065818a8149fb0c77266fb584d9b2cba3711ab",
-  "total":1928429856}
-```
-
-Finally, when the upload is complete:
-
-```json
-{"status":"pushing manifest"}
-{"status":"success"}
-```
-
 ## Generate Embeddings
 
-```shell
+```
 POST /api/embeddings
 ```
@@ -333,7 +248,7 @@ Advanced parameters:
 
 ### Request
 
-```shell
+```
 curl -X POST http://localhost:11434/api/embeddings -d '{
   "model": "llama2:7b",
   "prompt": "Here is an article about llamas..."
@@ -348,4 +263,5 @@ curl -X POST http://localhost:11434/api/embeddings -d '{
     0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
     0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
   ]
-}```
+}
+```

docs/streaming.md (new file, 35 lines)

@@ -0,0 +1,35 @@
# Streaming responses in the Ollama Client API
## JavaScript / TypeScript / Deno
```javascript
const pull = async () => {
const request = await fetch("http://localhost:11434/api/pull", {
method: "POST",
body: JSON.stringify({ name: "llama2:7b-q5_0" }),
});
const reader = await request.body?.pipeThrough(new TextDecoderStream());
if (!reader) throw new Error("No reader");
for await (const chunk of reader) {
const out = JSON.parse(chunk);
if (out.status.startsWith("downloading")) {
console.log(`${out.status} - ${(out.completed / out.total) * 100}%`);
}
}
}
pull();
```
## Python
```python
import requests
import json
response = requests.post("http://localhost:11434/api/pull", json={"name": "llama2:7b-q5_0"}, stream=True)
for data in response.iter_lines():
out = json.loads(data)
if "completed" in out:
print(out["completed"] / out["total"] * 100)
```
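
As a point of comparison (not part of the committed file), the same stream can be inspected directly with curl; each line of output is one JSON status object, and the fields mirror the ones used in the examples above:

```shell
# Pull a model and print the raw stream of JSON status lines (model name is illustrative).
curl -X POST http://localhost:11434/api/pull -d '{"name": "llama2:7b-q5_0"}'

# Each emitted line has roughly this shape (values are illustrative):
# {"status":"downloading ...","digest":"sha256:...","total":123456789,"completed":1048576}
```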

go.mod (2 changed lines)

@@ -8,7 +8,6 @@ require (
 	github.com/mattn/go-runewidth v0.0.14
 	github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db
 	github.com/olekukonko/tablewriter v0.0.5
-	github.com/pdevine/readline v1.5.2
 	github.com/spf13/cobra v1.7.0
 )
@@ -17,6 +16,7 @@ require github.com/rivo/uniseg v0.2.0 // indirect
 require (
 	github.com/bytedance/sonic v1.9.1 // indirect
 	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
+	github.com/chzyer/readline v1.5.1
 	github.com/gabriel-vasile/mimetype v1.4.2 // indirect
 	github.com/gin-contrib/cors v1.4.0
 	github.com/gin-contrib/sse v0.1.0 // indirect

go.sum (5 changed lines)

@@ -6,6 +6,8 @@ github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhD
 github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
 github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM=
 github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ=
+github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI=
+github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
 github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
 github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
 github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
@@ -78,8 +80,6 @@ github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N
 github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
 github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0=
 github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhMYhSNPKjeNKa5WY9YCIEBRbNzFFPJbWO6Y=
-github.com/pdevine/readline v1.5.2 h1:oz6Y5GdTmhPG+08hhxcAvtHitSANWuA2100Sppb38xI=
-github.com/pdevine/readline v1.5.2/go.mod h1:na/LbuE5PYwxI7GyopWdIs3U8HVe89lYlNTFTXH3wOw=
 github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
 github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
 github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
@@ -120,6 +120,7 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
 golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
 golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM=
 golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I=
+golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
 golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ=
 golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=


@@ -4,6 +4,8 @@ import (
 	"encoding/binary"
 	"errors"
 	"io"
+	"path"
+	"sync"
 )
 
 type GGML struct {
@@ -164,6 +166,23 @@ func (c *containerLORA) Decode(r io.Reader) (model, error) {
 	return nil, nil
 }
 
+var (
+	ggmlGPU = path.Join("llama.cpp", "ggml", "build", "gpu", "bin")
+	ggmlCPU = path.Join("llama.cpp", "ggml", "build", "cpu", "bin")
+)
+
+var (
+	ggmlInit       sync.Once
+	ggmlRunnerPath string
+)
+
+func ggmlRunner() ModelRunner {
+	ggmlInit.Do(func() {
+		ggmlRunnerPath = chooseRunner(ggmlGPU, ggmlCPU)
+	})
+	return ModelRunner{Path: ggmlRunnerPath}
+}
+
 const (
 	// Magic constant for `ggml` files (unversioned).
 	FILE_MAGIC_GGML = 0x67676d6c


@@ -6,6 +6,8 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"path"
+	"sync"
 )
 
 type containerGGUF struct {
@@ -367,3 +369,21 @@ func (llm *ggufModel) readArray(r io.Reader) (arr []any, err error) {
 	return
 }
+
+var (
+	ggufGPU = path.Join("llama.cpp", "gguf", "build", "gpu", "bin")
+	ggufCPU = path.Join("llama.cpp", "gguf", "build", "cpu", "bin")
+)
+
+var (
+	ggufInit       sync.Once
+	ggufRunnerPath string
+)
+
+func ggufRunner() ModelRunner {
+	ggufInit.Do(func() {
+		ggufRunnerPath = chooseRunner(ggufGPU, ggufCPU)
+	})
+	return ModelRunner{Path: ggufRunnerPath}
+}


@@ -1,14 +1,17 @@
+//go:build !darwin
+// +build !darwin
+
 package llm
 
 //go:generate git submodule init
 //go:generate git submodule update --force ggml
-//go:generate git -C ggml apply ../patches/0001-add-detokenize-endpoint.patch
-//go:generate git -C ggml apply ../patches/0002-34B-model-support.patch
+//go:generate -command git-apply git -C ggml apply
+//go:generate git-apply ../ggml_patch/0001-add-detokenize-endpoint.patch
+//go:generate git-apply ../ggml_patch/0002-34B-model-support.patch
 //go:generate cmake -S ggml -B ggml/build/cpu -DLLAMA_K_QUANTS=on
 //go:generate cmake --build ggml/build/cpu --target server --config Release
 //go:generate git submodule update --force gguf
-//go:generate git -C gguf apply ../patches/0001-remove-warm-up-logging.patch
 //go:generate cmake -S gguf -B gguf/build/cpu -DLLAMA_K_QUANTS=on
 //go:generate cmake --build gguf/build/cpu --target server --config Release


@@ -3,14 +3,14 @@ package llm
 //go:generate git submodule init
 //go:generate git submodule update --force ggml
-//go:generate git -C ggml apply ../patches/0001-add-detokenize-endpoint.patch
-//go:generate git -C ggml apply ../patches/0002-34B-model-support.patch
-//go:generate git -C ggml apply ../patches/0003-metal-fix-synchronization-in-new-matrix-multiplicati.patch
-//go:generate git -C ggml apply ../patches/0004-metal-add-missing-barriers-for-mul-mat-2699.patch
+//go:generate -command git-apply git -C ggml apply
+//go:generate git-apply ../ggml_patch/0001-add-detokenize-endpoint.patch
+//go:generate git-apply ../ggml_patch/0002-34B-model-support.patch
+//go:generate git-apply ../ggml_patch/0003-metal-fix-synchronization-in-new-matrix-multiplicati.patch
+//go:generate git-apply ../ggml_patch/0004-metal-add-missing-barriers-for-mul-mat-2699.patch
 //go:generate cmake -S ggml -B ggml/build/cpu -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DCMAKE_SYSTEM_PROCESSOR=x86_64 -DCMAKE_OSX_ARCHITECTURES=x86_64 -DCMAKE_OSX_DEPLOYMENT_TARGET=11.0
 //go:generate cmake --build ggml/build/cpu --target server --config Release
 //go:generate git submodule update --force gguf
-//go:generate git -C gguf apply ../patches/0001-remove-warm-up-logging.patch
 //go:generate cmake -S gguf -B gguf/build/cpu -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DCMAKE_SYSTEM_PROCESSOR=x86_64 -DCMAKE_OSX_ARCHITECTURES=x86_64 -DCMAKE_OSX_DEPLOYMENT_TARGET=11.0
 //go:generate cmake --build gguf/build/cpu --target server --config Release


@@ -3,14 +3,14 @@ package llm
 //go:generate git submodule init
 //go:generate git submodule update --force ggml
-//go:generate git -C ggml apply ../patches/0001-add-detokenize-endpoint.patch
-//go:generate git -C ggml apply ../patches/0002-34B-model-support.patch
-//go:generate git -C ggml apply ../patches/0003-metal-fix-synchronization-in-new-matrix-multiplicati.patch
-//go:generate git -C ggml apply ../patches/0004-metal-add-missing-barriers-for-mul-mat-2699.patch
-//go:generate cmake -S ggml -B ggml/build/metal -DLLAMA_METAL=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DCMAKE_SYSTEM_PROCESSOR=arm64 -DCMAKE_OSX_ARCHITECTURES=arm64 -DCMAKE_OSX_DEPLOYMENT_TARGET=11.0
-//go:generate cmake --build ggml/build/metal --target server --config Release
+//go:generate -command git-apply git -C ggml apply
+//go:generate git-apply ../ggml_patch/0001-add-detokenize-endpoint.patch
+//go:generate git-apply ../ggml_patch/0002-34B-model-support.patch
+//go:generate git-apply ../ggml_patch/0003-metal-fix-synchronization-in-new-matrix-multiplicati.patch
+//go:generate git-apply ../ggml_patch/0004-metal-add-missing-barriers-for-mul-mat-2699.patch
+//go:generate cmake -S ggml -B ggml/build/gpu -DLLAMA_METAL=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DCMAKE_SYSTEM_PROCESSOR=arm64 -DCMAKE_OSX_ARCHITECTURES=arm64 -DCMAKE_OSX_DEPLOYMENT_TARGET=11.0
+//go:generate cmake --build ggml/build/gpu --target server --config Release
 //go:generate git submodule update --force gguf
-//go:generate git -C gguf apply ../patches/0001-remove-warm-up-logging.patch
-//go:generate cmake -S gguf -B gguf/build/metal -DLLAMA_METAL=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DCMAKE_SYSTEM_PROCESSOR=arm64 -DCMAKE_OSX_ARCHITECTURES=arm64 -DCMAKE_OSX_DEPLOYMENT_TARGET=11.0
-//go:generate cmake --build gguf/build/metal --target server --config Release
+//go:generate cmake -S gguf -B gguf/build/gpu -DLLAMA_METAL=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DCMAKE_SYSTEM_PROCESSOR=arm64 -DCMAKE_OSX_ARCHITECTURES=arm64 -DCMAKE_OSX_DEPLOYMENT_TARGET=11.0
+//go:generate cmake --build gguf/build/gpu --target server --config Release


@@ -3,20 +3,13 @@ package llm
 //go:generate git submodule init
 //go:generate git submodule update --force ggml
-//go:generate git -C ggml apply ../patches/0001-add-detokenize-endpoint.patch
-//go:generate git -C ggml apply ../patches/0002-34B-model-support.patch
-//go:generate git -C ggml apply ../patches/0005-ggml-support-CUDA-s-half-type-for-aarch64-1455-2670.patch
-//go:generate git -C ggml apply ../patches/0001-copy-cuda-runtime-libraries.patch
-//go:generate cmake -S ggml -B ggml/build/cpu -DLLAMA_K_QUANTS=on
-//go:generate cmake --build ggml/build/cpu --target server --config Release
+//go:generate -command git-apply git -C ggml apply
+//go:generate git-apply ../ggml_patch/0001-add-detokenize-endpoint.patch
+//go:generate git-apply ../ggml_patch/0002-34B-model-support.patch
+//go:generate git-apply ../ggml_patch/0005-ggml-support-CUDA-s-half-type-for-aarch64-1455-2670.patch
+//go:generate cmake -S ggml -B ggml/build/gpu -DLLAMA_CUBLAS=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on
+//go:generate cmake --build ggml/build/gpu --target server --config Release
 //go:generate git submodule update --force gguf
-//go:generate git -C gguf apply ../patches/0001-copy-cuda-runtime-libraries.patch
-//go:generate git -C gguf apply ../patches/0001-remove-warm-up-logging.patch
-//go:generate cmake -S gguf -B gguf/build/cpu -DLLAMA_K_QUANTS=on
-//go:generate cmake --build gguf/build/cpu --target server --config Release
-//go:generate cmake -S ggml -B ggml/build/cuda -DLLAMA_CUBLAS=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on
-//go:generate cmake --build ggml/build/cuda --target server --config Release
-//go:generate cmake -S gguf -B gguf/build/cuda -DLLAMA_CUBLAS=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on
-//go:generate cmake --build gguf/build/cuda --target server --config Release
+//go:generate cmake -S gguf -B gguf/build/gpu -DLLAMA_CUBLAS=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on
+//go:generate cmake --build gguf/build/gpu --target server --config Release


@@ -0,0 +1,32 @@
From 8c0ea847ac1460bca534d92266e3471cb31471be Mon Sep 17 00:00:00 2001
From: Bruce MacDonald <brucewmacdonald@gmail.com>
Date: Tue, 5 Sep 2023 16:05:08 -0400
Subject: [PATCH] metal: add missing barriers for mul-mat #2699
---
ggml-metal.metal | 2 ++
1 file changed, 2 insertions(+)
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 3f31252..ce3541f 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -1850,6 +1850,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
//load data and store to threadgroup memory
half4x4 temp_a;
dequantize_func(x, il, temp_a);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
#pragma unroll(16)
for (int i = 0; i < 16; i++) {
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
@@ -1895,6 +1896,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
}
} else {
// block is smaller than 64x32, we should avoid writing data outside of the matrix
+ threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
for (int i = 0; i < 8; i++) {
--
2.39.2 (Apple Git-143)


@@ -1,27 +0,0 @@
From 5dd02993e8cc2ce309157736b95bb572f274a3fd Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Wed, 20 Sep 2023 14:19:52 -0700
Subject: [PATCH] copy cuda runtime libraries
---
CMakeLists.txt | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 824d9f2..dd24137 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -274,6 +274,10 @@ if (LLAMA_CUBLAS)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
endif()
+ configure_file(${CUDAToolkit_LIBRARY_DIR}/libcudart.so ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/libcudart.so.${CUDAToolkit_VERSION_MAJOR}.0 COPYONLY)
+ configure_file(${CUDAToolkit_LIBRARY_DIR}/libcublas.so ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/libcublas.so.${CUDAToolkit_VERSION_MAJOR} COPYONLY)
+ configure_file(${CUDAToolkit_LIBRARY_DIR}/libcublasLt.so ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/libcublasLt.so.${CUDAToolkit_VERSION_MAJOR} COPYONLY)
+
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
# 52 == lowest CUDA 12 standard
# 60 == f16 CUDA intrinsics
--
2.42.0


@@ -1,25 +0,0 @@
From 07993bdc35345b67b27aa649a7c099ad42d80c4c Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Thu, 21 Sep 2023 14:43:21 -0700
Subject: [PATCH] remove warm up logging
---
common/common.cpp | 2 --
1 file changed, 2 deletions(-)
diff --git a/common/common.cpp b/common/common.cpp
index 2597ba0..b56549b 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -780,8 +780,6 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
}
{
- LOG("warming up the model with an empty run\n");
-
const std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads);
llama_reset_timings(lctx);
--
2.42.0


@@ -28,93 +28,71 @@ import (
//go:embed llama.cpp/*/build/*/bin/* //go:embed llama.cpp/*/build/*/bin/*
var llamaCppEmbed embed.FS var llamaCppEmbed embed.FS
type ModelRunner struct { func osPath(llamaPath string) string {
Path string // path to the model runner executable if runtime.GOOS == "windows" {
return path.Join(llamaPath, "Release")
}
return llamaPath
} }
func chooseRunners(workDir, runnerType string) []ModelRunner { func chooseRunner(gpuPath, cpuPath string) string {
buildPath := path.Join("llama.cpp", runnerType, "build") tmpDir, err := os.MkdirTemp("", "llama-*")
var runners []string if err != nil {
log.Fatalf("llama.cpp: failed to create temp dir: %v", err)
}
// set the runners based on the OS llamaPath := osPath(gpuPath)
// IMPORTANT: the order of the runners in the array is the priority order if _, err := fs.Stat(llamaCppEmbed, llamaPath); err != nil {
llamaPath = osPath(cpuPath)
if _, err := fs.Stat(llamaCppEmbed, llamaPath); err != nil {
log.Fatalf("llama.cpp executable not found")
}
}
files := []string{"server"}
switch runtime.GOOS { switch runtime.GOOS {
case "windows":
files = []string{"server.exe"}
case "darwin": case "darwin":
runners = []string{ if llamaPath == osPath(gpuPath) {
path.Join(buildPath, "metal", "bin", "server"), files = append(files, "ggml-metal.metal")
path.Join(buildPath, "cpu", "bin", "server"),
} }
case "linux": case "linux":
runners = []string{ // check if there is a GPU available
path.Join(buildPath, "cuda", "bin", "server"), if _, err := CheckVRAM(); errors.Is(err, errNoGPU) {
path.Join(buildPath, "cpu", "bin", "server"), // this error was logged on start-up, so we don't need to log it again
} llamaPath = osPath(cpuPath)
case "windows":
// TODO: select windows GPU runner here when available
runners = []string{
path.Join(buildPath, "cpu", "bin", "Release", "server.exe"),
}
default:
log.Printf("unknown OS, running on CPU: %s", runtime.GOOS)
runners = []string{
path.Join(buildPath, "cpu", "bin", "server"),
} }
} }
runnerAvailable := false // if no runner files are found in the embed, this flag will cause a fast fail
for _, r := range runners {
// find all the files in the runner's bin directory
files, err := fs.Glob(llamaCppEmbed, filepath.Join(filepath.Dir(r), "*"))
if err != nil {
// this is expected, ollama may be compiled without all runners packed in
log.Printf("%s runner not found: %v", r, err)
continue
}
runnerAvailable = true
for _, f := range files { for _, f := range files {
srcFile, err := llamaCppEmbed.Open(f) srcPath := path.Join(llamaPath, f)
destPath := filepath.Join(tmpDir, f)
srcFile, err := llamaCppEmbed.Open(srcPath)
if err != nil { if err != nil {
log.Fatalf("read llama runner %s: %v", f, err) log.Fatalf("read llama.cpp %s: %v", f, err)
} }
defer srcFile.Close() defer srcFile.Close()
// create the directory in case it does not exist destFile, err := os.OpenFile(destPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
destPath := filepath.Join(workDir, filepath.Dir(f))
if err := os.MkdirAll(destPath, 0o755); err != nil {
log.Fatalf("create runner temp dir %s: %v", filepath.Dir(f), err)
}
destFile := filepath.Join(destPath, filepath.Base(f))
_, err = os.Stat(destFile)
switch {
case errors.Is(err, os.ErrNotExist):
destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
if err != nil { if err != nil {
log.Fatalf("write llama runner %s: %v", f, err) log.Fatalf("write llama.cpp %s: %v", f, err)
} }
defer destFile.Close() defer destFile.Close()
if _, err := io.Copy(destFile, srcFile); err != nil { if _, err := io.Copy(destFile, srcFile); err != nil {
log.Fatalf("copy llama runner %s: %v", f, err) log.Fatalf("copy llama.cpp %s: %v", f, err)
} }
case err != nil:
log.Fatalf("stat llama runner %s: %v", f, err)
}
}
}
if !runnerAvailable {
log.Fatalf("%s runner not found", runnerType)
} }
// return the runners to try in priority order runPath := filepath.Join(tmpDir, "server")
localRunnersByPriority := []ModelRunner{} if runtime.GOOS == "windows" {
for _, r := range runners { runPath = filepath.Join(tmpDir, "server.exe")
localRunnersByPriority = append(localRunnersByPriority, ModelRunner{Path: path.Join(workDir, r)})
} }
return localRunnersByPriority return runPath
} }
type llamaModel struct { type llamaModel struct {
@@ -175,6 +153,10 @@ type Running struct {
 	Cancel context.CancelFunc
 }
 
+type ModelRunner struct {
+	Path string // path to the model runner executable
+}
+
 type llama struct {
 	api.Options
 	Running
@@ -246,11 +228,15 @@ func NumGPU(opts api.Options) int {
 	return n
 }
 
-func newLlama(model string, adapters []string, runners []ModelRunner, opts api.Options) (*llama, error) {
+func newLlama(model string, adapters []string, runner ModelRunner, opts api.Options) (*llama, error) {
 	if _, err := os.Stat(model); err != nil {
 		return nil, err
 	}
 
+	if _, err := os.Stat(runner.Path); err != nil {
+		return nil, err
+	}
+
 	if len(adapters) > 1 {
 		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
 	}
@@ -292,12 +278,7 @@ func newLlama(model string, adapters []string, runners []ModelRunner, opts api.O
}
// start the llama.cpp server with a retry in case the port is already in use
-for _, runner := range runners {
-if _, err := os.Stat(runner.Path); err != nil {
-log.Printf("llama runner not found: %v", err)
-continue
-}
+for try := 0; try < 3; try++ {
port := rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
ctx, cancel := context.WithCancel(context.Background())
cmd := exec.CommandContext(
@@ -305,30 +286,20 @@ func newLlama(model string, adapters []string, runners []ModelRunner, opts api.O
runner.Path,
append(params, "--port", strconv.Itoa(port))...,
)
-cmd.Env = append(os.Environ(), fmt.Sprintf("LD_LIBRARY_PATH=%s", filepath.Dir(runner.Path)))
cmd.Stdout = os.Stderr
cmd.Stderr = os.Stderr
llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel}}
-log.Print("starting llama runner")
+log.Print("starting llama.cpp server")
if err := llm.Cmd.Start(); err != nil {
-log.Printf("error starting the external llama runner: %v", err)
+log.Printf("error starting the external llama.cpp server: %v", err)
continue
}
// monitor the command, it is blocking, so if it exits we need to capture that
go func() {
err := llm.Cmd.Wait() // this will block until the command exits
if err != nil {
log.Printf("llama runner exited with error: %v", err)
} else {
log.Printf("llama runner exited")
}
}()
if err := waitForServer(llm); err != nil {
-log.Printf("error starting llama runner: %v", err)
+log.Printf("error starting llama.cpp server: %v", err)
llm.Close()
// try again
continue
@@ -338,24 +309,19 @@ func newLlama(model string, adapters []string, runners []ModelRunner, opts api.O
return llm, nil
}
-return nil, fmt.Errorf("failed to start a llama runner")
+return nil, fmt.Errorf("max retry exceeded starting llama.cpp")
}
func waitForServer(llm *llama) error {
// wait for the server to start responding
start := time.Now()
-expiresAt := time.Now().Add(2 * time.Minute) // be generous with timeout, large models can take a while to load
+expiresAt := time.Now().Add(45 * time.Second)
ticker := time.NewTicker(200 * time.Millisecond)
-log.Print("waiting for llama runner to start responding")
+log.Print("waiting for llama.cpp server to start responding")
for range ticker.C {
if time.Now().After(expiresAt) {
-return fmt.Errorf("llama runner did not start within alloted time, retrying")
+return fmt.Errorf("llama.cpp server did not start within alloted time, retrying")
}
-// check if the server process has terminated
-if llm.Cmd.ProcessState != nil && llm.Cmd.ProcessState.Exited() {
-return fmt.Errorf("llama runner process has terminated")
}
if err := llm.Ping(context.Background()); err == nil {
@@ -363,12 +329,15 @@ func waitForServer(llm *llama) error {
}
}
-log.Printf("llama runner started in %f seconds", time.Since(start).Seconds())
+log.Printf("llama.cpp server started in %f seconds", time.Since(start).Seconds())
return nil
}
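The waitForServer change above swaps a generous two-minute deadline (with a process-exit check) for a 45-second one. The polling pattern itself is simple; here is a hedged, generic Go sketch of it against an arbitrary health URL. The URL, interval, and timeout are placeholders, not values from this repository.

package main

import (
    "fmt"
    "net/http"
    "time"
)

// waitForServer polls url on a fixed interval until it answers 200 or the
// deadline passes.
func waitForServer(url string, timeout time.Duration) error {
    deadline := time.Now().Add(timeout)
    ticker := time.NewTicker(200 * time.Millisecond)
    defer ticker.Stop()

    for range ticker.C {
        if time.Now().After(deadline) {
            return fmt.Errorf("server did not start within %s", timeout)
        }
        resp, err := http.Get(url)
        if err != nil {
            continue // not listening yet, keep polling
        }
        resp.Body.Close()
        if resp.StatusCode == http.StatusOK {
            return nil
        }
    }
    return nil
}

func main() {
    if err := waitForServer("http://127.0.0.1:8080/", 30*time.Second); err != nil {
        fmt.Println(err)
    }
}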
func (llm *llama) Close() {
llm.Cancel()
+if err := llm.Cmd.Wait(); err != nil {
+log.Printf("llama.cpp server exited with error: %v", err)
+}
}
func (llm *llama) SetOptions(opts api.Options) {

View File

@@ -21,7 +21,7 @@ type LLM interface {
Ping(context.Context) error
}
-func New(workDir, model string, adapters []string, opts api.Options) (LLM, error) {
+func New(model string, adapters []string, opts api.Options) (LLM, error) {
if _, err := os.Stat(model); err != nil {
return nil, err
}
@@ -91,9 +91,9 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error
switch ggml.Name() {
case "gguf":
opts.NumGQA = 0 // TODO: remove this when llama.cpp runners differ enough to need separate newLlama functions
-return newLlama(model, adapters, chooseRunners(workDir, "gguf"), opts)
+return newLlama(model, adapters, ggufRunner(), opts)
case "ggml", "ggmf", "ggjt", "ggla":
-return newLlama(model, adapters, chooseRunners(workDir, "ggml"), opts)
+return newLlama(model, adapters, ggmlRunner(), opts)
default:
return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily())
}

View File

@@ -1,12 +0,0 @@
#!/bin/bash
set -e
mkdir -p dist
for ARCH in arm64 amd64; do
docker buildx build --platform=linux/$ARCH -f Dockerfile.build . -t builder:$ARCH --load
docker create --platform linux/$ARCH --name builder builder:$ARCH
docker cp builder:/go/src/github.com/jmorganca/ollama/ollama ./dist/ollama-linux-$ARCH
docker rm builder
done

View File

@@ -1,160 +0,0 @@
#!/bin/sh
# This script installs Ollama on Linux.
# It detects the current operating system architecture and installs the appropriate version of Ollama.
set -eu
check_os() {
if [ "$(uname -s)" != "Linux" ]; then
echo "This script is intended to run on Linux only."
exit 1
fi
}
determine_architecture() {
ARCH=$(uname -m)
case $ARCH in
x86_64)
ARCH_SUFFIX="amd64"
;;
aarch64|arm64)
ARCH_SUFFIX="arm64"
;;
*)
echo "Unsupported architecture: $ARCH"
exit 1
;;
esac
}
check_sudo() {
if [ "$(id -u)" -ne 0 ]; then
if command -v sudo >/dev/null 2>&1; then
SUDO_CMD="sudo"
echo "Downloading the ollama executable to the PATH, this will require sudo permissions."
else
echo "Error: sudo is not available. Please run as root or install sudo."
exit 1
fi
else
SUDO_CMD=""
fi
}
install_cuda_drivers() {
local os_name os_version
if [ -f "/etc/os-release" ]; then
. /etc/os-release
os_name=$ID
os_version=$VERSION_ID
else
echo "Unable to detect operating system. Skipping CUDA installation."
return 1
fi
# based on https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#package-manager-installation
case $os_name in
CentOS)
$SUDO_CMD yum install yum-utils
$SUDO_CMD yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo
$SUDO_CMD yum clean all
$SUDO_CMD yum -y install nvidia-driver-latest-dkms
$SUDO_CMD yum -y install cuda-driver
$SUDO_CMD yum install kernel-devel-$(uname -r) kernel-headers-$(uname -r)
$SUDO_CMD dkms status | awk -F: '/added/ { print $1 }' | xargs -n1 $SUDO_CMD dkms install
$SUDO_CMD modprobe nvidia
;;
ubuntu)
case $os_version in
20.04)
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.1-1_all.deb
;;
22.04)
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
;;
*)
echo "Skipping automatic CUDA installation, not supported for Ubuntu ($os_version)."
return
;;
esac
$SUDO_CMD dpkg -i cuda-keyring_1.1-1_all.deb
$SUDO_CMD apt-get update
$SUDO_CMD apt-get -y install cuda-drivers
;;
RedHatEnterprise*|Kylin|Fedora|SLES|openSUSE*|Microsoft|Debian)
echo "NVIDIA CUDA drivers may not be installed, you can install them from: https://developer.nvidia.com/cuda-downloads"
;;
*)
echo "Unsupported or unknown distribution, skipping GPU CUDA driver install: $os_name"
;;
esac
}
check_install_cuda_drivers() {
if lspci -d '10de:' | grep 'NVIDIA' >/dev/null; then
# NVIDIA Corporation [10de] device is available
if command -v nvidia-smi >/dev/null 2>&1; then
CUDA_VERSION=$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")
if [ -z "$CUDA_VERSION" ]; then
echo "Warning: NVIDIA-SMI is available, but the CUDA version cannot be detected. Installing CUDA drivers..."
install_cuda_drivers
else
echo "Detected CUDA version $CUDA_VERSION"
fi
else
echo "Warning: NVIDIA GPU detected but NVIDIA-SMI is not available. Installing CUDA drivers..."
install_cuda_drivers
fi
else
echo "No NVIDIA GPU detected. Skipping driver installation."
fi
}
download_ollama() {
$SUDO_CMD mkdir -p /usr/bin
$SUDO_CMD curl -fsSL -o /usr/bin/ollama "https://ollama.ai/download/latest/ollama-linux-$ARCH_SUFFIX"
}
configure_systemd() {
if command -v systemctl >/dev/null 2>&1; then
$SUDO_CMD useradd -r -s /bin/false -m -d /home/ollama ollama 2>/dev/null
echo "Creating systemd service file for ollama..."
cat <<EOF | $SUDO_CMD tee /etc/systemd/system/ollama.service >/dev/null
[Unit]
Description=Ollama Service
After=network-online.target
[Service]
ExecStart=/usr/bin/ollama serve
User=ollama
Group=ollama
Restart=always
RestartSec=3
Environment="HOME=/home/ollama"
[Install]
WantedBy=default.target
EOF
echo "Reloading systemd and enabling ollama service..."
if [ "$(systemctl is-system-running || echo 'not running')" = 'running' ]; then
$SUDO_CMD systemctl daemon-reload
$SUDO_CMD systemctl enable ollama
$SUDO_CMD systemctl restart ollama
fi
else
echo "Run 'ollama serve' from the command line to start the service."
fi
}
main() {
check_os
determine_architecture
check_sudo
download_ollama
configure_systemd
check_install_cuda_drivers
echo "Installation complete. You can now run 'ollama' from the command line."
}
main

View File

@@ -14,7 +14,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"path/filepath" "path"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -71,7 +71,7 @@ func (r AuthRedirect) URL() (*url.URL, error) {
return redirectURL, nil return redirectURL, nil
} }
func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) { func getAuthToken(ctx context.Context, redirData AuthRedirect, regOpts *RegistryOptions) (string, error) {
redirectURL, err := redirData.URL() redirectURL, err := redirData.URL()
if err != nil { if err != nil {
return "", err return "", err
@@ -82,7 +82,7 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
return "", err return "", err
} }
keyPath := filepath.Join(home, ".ollama", "id_ed25519") keyPath := path.Join(home, ".ollama", "id_ed25519")
rawKey, err := os.ReadFile(keyPath) rawKey, err := os.ReadFile(keyPath)
if err != nil { if err != nil {

View File

@@ -8,7 +8,7 @@ import (
"log" "log"
"net/http" "net/http"
"os" "os"
"path/filepath" "path"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@@ -173,7 +173,7 @@ func doDownload(ctx context.Context, opts downloadOpts, f *FileDownload) error {
return fmt.Errorf("%w: on download registry responded with code %d: %v", errDownload, resp.StatusCode, string(body)) return fmt.Errorf("%w: on download registry responded with code %d: %v", errDownload, resp.StatusCode, string(body))
} }
err = os.MkdirAll(filepath.Dir(f.FilePath), 0o700) err = os.MkdirAll(path.Dir(f.FilePath), 0o700)
if err != nil { if err != nil {
return fmt.Errorf("make blobs directory: %w", err) return fmt.Errorf("make blobs directory: %w", err)
} }

View File

@@ -14,6 +14,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"path"
"path/filepath" "path/filepath"
"reflect" "reflect"
"runtime" "runtime"
@@ -267,7 +268,7 @@ func filenameWithPath(path, f string) (string, error) {
return f, nil return f, nil
} }
func CreateModel(ctx context.Context, workDir, name string, path string, fn func(resp api.ProgressResponse)) error { func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
mp := ParseModelPath(name) mp := ParseModelPath(name)
var manifest *ManifestV2 var manifest *ManifestV2
@@ -390,7 +391,7 @@ func CreateModel(ctx context.Context, workDir, name string, path string, fn func
return err return err
} }
// copy the model metadata // copie the model metadata
config.ModelFamily = source.ModelFamily config.ModelFamily = source.ModelFamily
config.ModelType = source.ModelType config.ModelType = source.ModelType
config.ModelFormat = source.ModelFormat config.ModelFormat = source.ModelFormat
@@ -460,10 +461,8 @@ func CreateModel(ctx context.Context, workDir, name string, path string, fn func
return err
}
-if layer.Size > 0 {
layer.MediaType = mediaType
layers = append(layers, layer)
-}
case "template", "system", "prompt":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
// remove the layer if one exists
@@ -475,10 +474,8 @@ func CreateModel(ctx context.Context, workDir, name string, path string, fn func
return err
}
-if layer.Size > 0 {
layer.MediaType = mediaType
layers = append(layers, layer)
-}
default: default:
// runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop sequences) // runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop sequences)
params[c.Name] = append(params[c.Name], c.Args) params[c.Name] = append(params[c.Name], c.Args)
@@ -524,7 +521,7 @@ func CreateModel(ctx context.Context, workDir, name string, path string, fn func
}
// generate the embedding layers
-embeddingLayers, err := embeddingLayers(workDir, embed)
+embeddingLayers, err := embeddingLayers(embed)
if err != nil {
return err
}
@@ -581,7 +578,7 @@ type EmbeddingParams struct {
}
// embeddingLayers loads the associated LLM and generates the embeddings to be stored from an input file
-func embeddingLayers(workDir string, e EmbeddingParams) ([]*LayerReader, error) {
+func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
layers := []*LayerReader{}
if len(e.files) > 0 {
// check if the model is a file path or a model name
@@ -594,7 +591,7 @@ func embeddingLayers(workDir string, e EmbeddingParams) ([]*LayerReader, error)
model = &Model{ModelPath: e.model}
}
-if err := load(context.Background(), workDir, model, e.opts, defaultSessionDuration); err != nil {
+if err := load(context.Background(), model, e.opts, defaultSessionDuration); err != nil {
return nil, fmt.Errorf("load model to generate embeddings: %v", err)
}
@@ -1154,14 +1151,14 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
Total: layer.Size,
})
-location, chunkSize, err := startUpload(ctx, mp, layer, regOpts)
+location, err := startUpload(ctx, mp, layer, regOpts)
if err != nil {
log.Printf("couldn't start upload: %v", err)
return err
}
-if strings.HasPrefix(filepath.Base(location.Path), "sha256:") {
-layer.Digest = filepath.Base(location.Path)
+if strings.HasPrefix(path.Base(location.Path), "sha256:") {
+layer.Digest = path.Base(location.Path)
fn(api.ProgressResponse{
Status: "using existing layer",
Digest: layer.Digest,
@@ -1171,7 +1168,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
continue
}
-if err := uploadBlob(ctx, location, layer, chunkSize, regOpts, fn); err != nil {
+if err := uploadBlobChunked(ctx, location, layer, regOpts, fn); err != nil {
log.Printf("error uploading blob: %v", err)
return err
}
@@ -1397,7 +1394,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
case resp.StatusCode == http.StatusUnauthorized:
auth := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth)
-token, err := getAuthToken(ctx, authRedir)
+token, err := getAuthToken(ctx, authRedir, regOpts)
if err != nil {
return nil, err
}
@@ -1445,15 +1442,6 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
if s := req.Header.Get("Content-Length"); s != "" {
contentLength, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return nil, err
}
req.ContentLength = contentLength
}
client := &http.Client{ client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error { CheckRedirect: func(req *http.Request, via []*http.Request) error {
if len(via) >= 10 { if len(via) >= 10 {
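The regOpts parameter threaded into getAuthToken above serves the registry token dance: a 401 response carrying a www-authenticate challenge is answered by fetching a token and retrying the request. A generic Go sketch of that flow follows; fetchToken and the registry URL are placeholders, not APIs from this codebase.

package main

import (
    "fmt"
    "net/http"
)

// fetchToken is a placeholder: a real implementation would parse the
// challenge (realm/service/scope) and exchange credentials for a token.
func fetchToken(challenge string) (string, error) {
    return "example-token", nil
}

// doWithAuthRetry issues the request once, and on a 401 retries it with a
// Bearer token obtained from the www-authenticate challenge.
func doWithAuthRetry(client *http.Client, req *http.Request) (*http.Response, error) {
    resp, err := client.Do(req)
    if err != nil {
        return nil, err
    }
    if resp.StatusCode != http.StatusUnauthorized {
        return resp, nil
    }
    challenge := resp.Header.Get("www-authenticate")
    resp.Body.Close()

    token, err := fetchToken(challenge)
    if err != nil {
        return nil, err
    }
    req.Header.Set("Authorization", "Bearer "+token)
    return client.Do(req)
}

func main() {
    req, _ := http.NewRequest(http.MethodGet, "https://registry.example.com/v2/", nil)
    resp, err := doWithAuthRetry(http.DefaultClient, req)
    if err != nil {
        fmt.Println(err)
        return
    }
    defer resp.Body.Close()
    fmt.Println("status:", resp.Status)
}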

View File

@@ -58,7 +58,7 @@ var loaded struct {
var defaultSessionDuration = 5 * time.Minute
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
-func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
+func load(ctx context.Context, model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
log.Printf("could not load model options: %v", err)
@@ -94,7 +94,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
loaded.Embeddings = model.Embeddings
}
-llmModel, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts)
+llmModel, err := llm.New(model.ModelPath, model.AdapterPaths, opts)
if err != nil {
return err
}
@@ -130,7 +130,6 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
llmModel.SetOptions(opts)
}
}
loaded.expireAt = time.Now().Add(sessionDuration)
if loaded.expireTimer == nil {
@@ -151,7 +150,6 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
loaded.digest = ""
})
}
loaded.expireTimer.Reset(sessionDuration)
return nil
}
@@ -174,11 +172,8 @@ func GenerateHandler(c *gin.Context) {
return
}
-workDir := c.GetString("workDir")
-// TODO: set this duration from the request if specified
-sessionDuration := defaultSessionDuration
-if err := load(c.Request.Context(), workDir, model, req.Options, sessionDuration); err != nil {
+sessionDuration := defaultSessionDuration // TODO: set this duration from the request if specified
+if err := load(c.Request.Context(), model, req.Options, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -223,14 +218,9 @@ func GenerateHandler(c *gin.Context) {
ch <- r
}
-// an empty request loads the model
-if req.Prompt == "" && req.Template == "" && req.System == "" {
-ch <- api.GenerateResponse{Model: req.Model, Done: true}
-} else {
if err := loaded.llm.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
-}
}()
streamResponse(c, ch)
@@ -251,9 +241,7 @@ func EmbeddingHandler(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
-workDir := c.GetString("workDir")
-if err := load(c.Request.Context(), workDir, model, req.Options, 5*time.Minute); err != nil {
+if err := load(c.Request.Context(), model, req.Options, 5*time.Minute); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -343,8 +331,6 @@ func CreateModelHandler(c *gin.Context) {
return
}
-workDir := c.GetString("workDir")
ch := make(chan any)
go func() {
defer close(ch)
@@ -355,7 +341,7 @@ func CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
-if err := CreateModel(ctx, workDir, req.Name, req.Path, fn); err != nil {
+if err := CreateModel(ctx, req.Name, req.Path, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@@ -509,40 +495,33 @@ func CopyModelHandler(c *gin.Context) {
} }
} }
-var defaultAllowOrigins = []string{
-"localhost",
-"127.0.0.1",
-"0.0.0.0",
-}
-func Serve(ln net.Listener, allowOrigins []string) error {
-config := cors.DefaultConfig()
-config.AllowWildcard = true
-config.AllowOrigins = allowOrigins
-for _, allowOrigin := range defaultAllowOrigins {
-config.AllowOrigins = append(config.AllowOrigins,
-fmt.Sprintf("http://%s", allowOrigin),
-fmt.Sprintf("https://%s", allowOrigin),
-fmt.Sprintf("http://%s:*", allowOrigin),
-fmt.Sprintf("https://%s:*", allowOrigin),
-)
-}
-workDir, err := os.MkdirTemp("", "ollama")
-if err != nil {
-return err
-}
-defer os.RemoveAll(workDir)
-r := gin.Default()
-r.Use(
-cors.New(config),
-func(c *gin.Context) {
-c.Set("workDir", workDir)
-c.Next()
-},
-)
+func Serve(ln net.Listener, origins []string) error {
+config := cors.DefaultConfig()
+config.AllowWildcard = true
+config.AllowOrigins = append(origins, []string{
+"http://localhost",
+"http://localhost:*",
+"https://localhost",
+"https://localhost:*",
+"http://127.0.0.1",
+"http://127.0.0.1:*",
+"https://127.0.0.1",
+"https://127.0.0.1:*",
+"http://0.0.0.0",
+"http://0.0.0.0:*",
+"https://0.0.0.0",
+"https://0.0.0.0:*",
+}...)
+r := gin.Default()
+r.Use(cors.New(config))
+r.GET("/", func(c *gin.Context) {
+c.String(http.StatusOK, "Ollama is running")
+})
+r.HEAD("/", func(c *gin.Context) {
+c.Status(http.StatusOK)
+})
r.POST("/api/pull", PullModelHandler) r.POST("/api/pull", PullModelHandler)
r.POST("/api/generate", GenerateHandler) r.POST("/api/generate", GenerateHandler)
@@ -550,17 +529,10 @@ func Serve(ln net.Listener, allowOrigins []string) error {
r.POST("/api/create", CreateModelHandler) r.POST("/api/create", CreateModelHandler)
r.POST("/api/push", PushModelHandler) r.POST("/api/push", PushModelHandler)
r.POST("/api/copy", CopyModelHandler) r.POST("/api/copy", CopyModelHandler)
r.GET("/api/tags", ListModelsHandler)
r.DELETE("/api/delete", DeleteModelHandler) r.DELETE("/api/delete", DeleteModelHandler)
r.POST("/api/show", ShowModelHandler) r.POST("/api/show", ShowModelHandler)
for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) {
c.String(http.StatusOK, "Ollama is running")
})
r.Handle(method, "/api/tags", ListModelsHandler)
}
log.Printf("Listening on %s", ln.Addr()) log.Printf("Listening on %s", ln.Addr())
s := &http.Server{ s := &http.Server{
Handler: r, Handler: r,
@@ -568,20 +540,19 @@ func Serve(ln net.Listener, allowOrigins []string) error {
// listen for a ctrl+c and stop any loaded llm // listen for a ctrl+c and stop any loaded llm
signals := make(chan os.Signal, 1) signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) signal.Notify(signals, syscall.SIGINT)
go func() { go func() {
<-signals <-signals
if loaded.llm != nil { if loaded.llm != nil {
loaded.llm.Close() loaded.llm.Close()
} }
os.RemoveAll(workDir)
os.Exit(0) os.Exit(0)
}() }()
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" {
// check compatibility to log warnings // check compatibility to log warnings
if _, err := llm.CheckVRAM(); err != nil { if _, err := llm.CheckVRAM(); err != nil {
log.Printf("Warning: GPU support may not enabled, check you have installed install GPU drivers: %v", err) log.Printf("Warning: GPU support not enabled, you may need to install GPU drivers: %v", err)
} }
} }
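Both versions of Serve above whitelist localhost-style origins for CORS; the removed variant derives the http/https and port-wildcard entries from a small default host list rather than spelling each one out. A rough Go sketch of that expansion, for illustration only (the host list and extra origin are examples, not project configuration):

package main

import "fmt"

// allowedOrigins expands a few default hosts into http/https entries, with
// and without a port wildcard, then appends any user-supplied origins.
func allowedOrigins(extra []string) []string {
    defaults := []string{"localhost", "127.0.0.1", "0.0.0.0"}
    origins := append([]string{}, extra...)
    for _, host := range defaults {
        origins = append(origins,
            fmt.Sprintf("http://%s", host),
            fmt.Sprintf("https://%s", host),
            fmt.Sprintf("http://%s:*", host),
            fmt.Sprintf("https://%s:*", host),
        )
    }
    return origins
}

func main() {
    fmt.Println(allowedOrigins([]string{"https://example.com"}))
}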

View File

@@ -14,12 +14,7 @@ import (
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
) )
-const (
-redirectChunkSize = 1024 * 1024 * 1024
-regularChunkSize = 95 * 1024 * 1024
-)
-func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, int64, error) {
+func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, error) {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
if layer.From != "" {
@@ -32,26 +27,20 @@ func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *Regis
resp, err := makeRequestWithRetry(ctx, "POST", requestURL, nil, nil, regOpts)
if err != nil {
log.Printf("couldn't start upload: %v", err)
-return nil, 0, err
+return nil, err
}
defer resp.Body.Close()
-location := resp.Header.Get("Docker-Upload-Location")
-chunkSize := redirectChunkSize
-if location == "" {
-location = resp.Header.Get("Location")
-chunkSize = regularChunkSize
-}
-locationURL, err := url.Parse(location)
-if err != nil {
-return nil, 0, err
-}
-return locationURL, int64(chunkSize), nil
+// Extract UUID location from header
+location := resp.Header.Get("Location")
+if location == "" {
+return nil, fmt.Errorf("location header is missing in response")
+}
+return url.Parse(location)
}
-func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSize int64, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
+func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
// TODO allow resumability
// TODO allow canceling uploads via DELETE
@@ -66,12 +55,8 @@ func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSiz
}
defer f.Close()
-pw := ProgressWriter{
-status: fmt.Sprintf("uploading %s", layer.Digest),
-digest: layer.Digest,
-total: layer.Size,
-fn: fn,
-}
+// 95MB chunk size
+chunkSize := 95 * 1024 * 1024
for offset := int64(0); offset < int64(layer.Size); {
chunk := int64(layer.Size) - offset
@@ -79,8 +64,50 @@ func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSiz
chunk = int64(chunkSize)
}
-resp, err := uploadBlobChunk(ctx, http.MethodPatch, requestURL, f, offset, chunk, regOpts, &pw)
-if err != nil {
+sectionReader := io.NewSectionReader(f, int64(offset), chunk)
+for try := 0; try < MaxRetries; try++ {
ch := make(chan error, 1)
r, w := io.Pipe()
defer r.Close()
go func() {
defer w.Close()
for chunked := int64(0); chunked < chunk; {
select {
case err := <-ch:
log.Printf("chunk interrupted: %v", err)
return
default:
n, err := io.CopyN(w, sectionReader, 1024*1024)
if err != nil && !errors.Is(err, io.EOF) {
fn(api.ProgressResponse{
Status: fmt.Sprintf("error reading chunk: %v", err),
Digest: layer.Digest,
Total: layer.Size,
Completed: int(offset),
})
return
}
chunked += n
fn(api.ProgressResponse{
Status: fmt.Sprintf("uploading %s", layer.Digest),
Digest: layer.Digest,
Total: layer.Size,
Completed: int(offset) + int(chunked),
})
}
}
}()
headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", strconv.Itoa(int(chunk)))
headers.Set("Content-Range", fmt.Sprintf("%d-%d", offset, offset+sectionReader.Size()-1))
resp, err := makeRequest(ctx, "PATCH", requestURL, headers, r, regOpts)
if err != nil && !errors.Is(err, io.EOF) {
fn(api.ProgressResponse{
Status: fmt.Sprintf("error uploading chunk: %v", err),
Digest: layer.Digest,
@@ -90,17 +117,35 @@ func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSiz
return err
}
-offset += chunk
-location := resp.Header.Get("Docker-Upload-Location")
-if location == "" {
-location = resp.Header.Get("Location")
-}
-requestURL, err = url.Parse(location)
+defer resp.Body.Close()
+switch {
+case resp.StatusCode == http.StatusUnauthorized:
+ch <- errors.New("unauthorized")
+auth := resp.Header.Get("www-authenticate")
+authRedir := ParseAuthRedirectString(auth)
+token, err := getAuthToken(ctx, authRedir, regOpts)
if err != nil {
return err
}
regOpts.Token = token
sectionReader = io.NewSectionReader(f, int64(offset), chunk)
continue
case resp.StatusCode >= http.StatusBadRequest:
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
}
offset += sectionReader.Size()
requestURL, err = url.Parse(resp.Header.Get("Location"))
if err != nil {
return err
}
break
}
} }
values := requestURL.Query() values := requestURL.Query()
@@ -125,90 +170,3 @@ func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSiz
} }
return nil return nil
} }
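The uploadBlobChunked loop above walks the blob in fixed-size chunks and sets a Content-Range header on each PATCH. Below is a small Go sketch of just the offset and Content-Range bookkeeping; the 95 MiB figure mirrors the comment in the diff, while the function and sizes are illustrative rather than the repository's API.

package main

import "fmt"

// contentRanges returns the inclusive "<first byte>-<last byte>" ranges used
// when uploading a blob of the given total size in chunkSize pieces.
func contentRanges(total, chunkSize int64) []string {
    var ranges []string
    for offset := int64(0); offset < total; {
        chunk := total - offset
        if chunk > chunkSize {
            chunk = chunkSize
        }
        ranges = append(ranges, fmt.Sprintf("%d-%d", offset, offset+chunk-1))
        offset += chunk
    }
    return ranges
}

func main() {
    const chunkSize = 95 * 1024 * 1024 // 95MB, as in the comment above
    for _, r := range contentRanges(200*1024*1024, chunkSize) {
        fmt.Println("Content-Range:", r)
    }
}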
func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r io.ReaderAt, offset, limit int64, opts *RegistryOptions, pw *ProgressWriter) (*http.Response, error) {
sectionReader := io.NewSectionReader(r, int64(offset), limit)
headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", strconv.Itoa(int(limit)))
headers.Set("X-Redirect-Uploads", "1")
if method == http.MethodPatch {
headers.Set("Content-Range", fmt.Sprintf("%d-%d", offset, offset+sectionReader.Size()-1))
}
for try := 0; try < MaxRetries; try++ {
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sectionReader, pw), opts)
if err != nil && !errors.Is(err, io.EOF) {
return nil, err
}
defer resp.Body.Close()
switch {
case resp.StatusCode == http.StatusTemporaryRedirect:
location, err := resp.Location()
if err != nil {
return nil, err
}
pw.completed = int(offset)
if _, err := uploadBlobChunk(ctx, http.MethodPut, location, r, offset, limit, nil, pw); err != nil {
// retry
log.Printf("retrying redirected upload: %v", err)
continue
}
return resp, nil
case resp.StatusCode == http.StatusUnauthorized:
auth := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth)
token, err := getAuthToken(ctx, authRedir)
if err != nil {
return nil, err
}
opts.Token = token
pw.completed = int(offset)
sectionReader = io.NewSectionReader(r, offset, limit)
continue
case resp.StatusCode >= http.StatusBadRequest:
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
}
return resp, nil
}
return nil, fmt.Errorf("max retries exceeded")
}
type ProgressWriter struct {
status string
digest string
bucket int
completed int
total int
fn func(api.ProgressResponse)
}
func (pw *ProgressWriter) Write(b []byte) (int, error) {
n := len(b)
pw.bucket += n
pw.completed += n
// throttle status updates to not spam the client
if pw.bucket >= 1024*1024 || pw.completed >= pw.total {
pw.fn(api.ProgressResponse{
Status: pw.status,
Digest: pw.digest,
Total: pw.total,
Completed: pw.completed,
})
pw.bucket = 0
}
return n, nil
}