Compare commits
17 Commits
ollama.com
...
jmorganca/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e117483ef6 | ||
|
|
62be2050dd | ||
|
|
56f8aa6912 | ||
|
|
e6f9bfc0e8 | ||
|
|
8d1995c625 | ||
|
|
fd01fbf038 | ||
|
|
0408205c1c | ||
|
|
63a7edd771 | ||
|
|
554ffdcce3 | ||
|
|
9850a4ce08 | ||
|
|
fd048f1367 | ||
|
|
8dca03173d | ||
|
|
85bdf14b56 | ||
|
|
da8a0c7657 | ||
|
|
ea4c284a48 | ||
|
|
8aec92fa6d | ||
|
|
70261b9bb6 |
14
Dockerfile
14
Dockerfile
@@ -18,7 +18,7 @@ ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
|
||||
COPY --from=llm-code / /go/src/github.com/ollama/ollama/
|
||||
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
||||
ARG CGO_CFLAGS
|
||||
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
|
||||
FROM --platform=linux/arm64 nvidia/cuda:$CUDA_VERSION-devel-rockylinux8 AS cuda-build-arm64
|
||||
ARG CMAKE_VERSION
|
||||
@@ -28,7 +28,7 @@ ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
|
||||
COPY --from=llm-code / /go/src/github.com/ollama/ollama/
|
||||
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
||||
ARG CGO_CFLAGS
|
||||
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
|
||||
FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCM_VERSION}-complete AS rocm-build-amd64
|
||||
ARG CMAKE_VERSION
|
||||
@@ -40,7 +40,7 @@ COPY --from=llm-code / /go/src/github.com/ollama/ollama/
|
||||
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
||||
ARG CGO_CFLAGS
|
||||
ARG AMDGPU_TARGETS
|
||||
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
RUN mkdir /tmp/scratch && \
|
||||
for dep in $(zcat /go/src/github.com/ollama/ollama/llm/build/linux/x86_64/rocm*/bin/deps.txt.gz) ; do \
|
||||
cp ${dep} /tmp/scratch/ || exit 1 ; \
|
||||
@@ -64,11 +64,11 @@ WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
||||
FROM --platform=linux/amd64 cpu-builder-amd64 AS static-build-amd64
|
||||
RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
|
||||
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu-build-amd64
|
||||
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
|
||||
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
|
||||
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx-build-amd64
|
||||
RUN OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
|
||||
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
|
||||
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
|
||||
RUN OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
|
||||
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
|
||||
|
||||
FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64
|
||||
ARG CMAKE_VERSION
|
||||
@@ -84,7 +84,7 @@ WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
||||
FROM --platform=linux/arm64 cpu-builder-arm64 AS static-build-arm64
|
||||
RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
|
||||
FROM --platform=linux/arm64 cpu-builder-arm64 AS cpu-build-arm64
|
||||
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
|
||||
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
|
||||
|
||||
|
||||
# Intermediate stage used for ./scripts/build_linux.sh
|
||||
|
||||
27
README.md
27
README.md
@@ -35,10 +35,10 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
|
||||
|
||||
## Quickstart
|
||||
|
||||
To run and chat with [Llama 2](https://ollama.com/library/llama2):
|
||||
To run and chat with [Llama 3](https://ollama.com/library/llama3):
|
||||
|
||||
```
|
||||
ollama run llama2
|
||||
ollama run llama3
|
||||
```
|
||||
|
||||
## Model library
|
||||
@@ -49,7 +49,8 @@ Here are some example models that can be downloaded:
|
||||
|
||||
| Model | Parameters | Size | Download |
|
||||
| ------------------ | ---------- | ----- | ------------------------------ |
|
||||
| Llama 2 | 7B | 3.8GB | `ollama run llama2` |
|
||||
| Llama 3 | 8B | 4.7GB | `ollama run llama3` |
|
||||
| Llama 3 | 70B | 40GB | `ollama run llama3:70b` |
|
||||
| Mistral | 7B | 4.1GB | `ollama run mistral` |
|
||||
| Dolphin Phi | 2.7B | 1.6GB | `ollama run dolphin-phi` |
|
||||
| Phi-2 | 2.7B | 1.7GB | `ollama run phi` |
|
||||
@@ -97,16 +98,16 @@ See the [guide](docs/import.md) on importing models for more information.
|
||||
|
||||
### Customize a prompt
|
||||
|
||||
Models from the Ollama library can be customized with a prompt. For example, to customize the `llama2` model:
|
||||
Models from the Ollama library can be customized with a prompt. For example, to customize the `llama3` model:
|
||||
|
||||
```
|
||||
ollama pull llama2
|
||||
ollama pull llama3
|
||||
```
|
||||
|
||||
Create a `Modelfile`:
|
||||
|
||||
```
|
||||
FROM llama2
|
||||
FROM llama3
|
||||
|
||||
# set the temperature to 1 [higher is more creative, lower is more coherent]
|
||||
PARAMETER temperature 1
|
||||
@@ -141,7 +142,7 @@ ollama create mymodel -f ./Modelfile
|
||||
### Pull a model
|
||||
|
||||
```
|
||||
ollama pull llama2
|
||||
ollama pull llama3
|
||||
```
|
||||
|
||||
> This command can also be used to update a local model. Only the diff will be pulled.
|
||||
@@ -149,13 +150,13 @@ ollama pull llama2
|
||||
### Remove a model
|
||||
|
||||
```
|
||||
ollama rm llama2
|
||||
ollama rm llama3
|
||||
```
|
||||
|
||||
### Copy a model
|
||||
|
||||
```
|
||||
ollama cp llama2 my-llama2
|
||||
ollama cp llama3 my-model
|
||||
```
|
||||
|
||||
### Multiline input
|
||||
@@ -179,7 +180,7 @@ The image features a yellow smiley face, which is likely the central focus of th
|
||||
### Pass in prompt as arguments
|
||||
|
||||
```
|
||||
$ ollama run llama2 "Summarize this file: $(cat README.md)"
|
||||
$ ollama run llama3 "Summarize this file: $(cat README.md)"
|
||||
Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
|
||||
```
|
||||
|
||||
@@ -226,7 +227,7 @@ Next, start the server:
|
||||
Finally, in a separate shell, run a model:
|
||||
|
||||
```
|
||||
./ollama run llama2
|
||||
./ollama run llama3
|
||||
```
|
||||
|
||||
## REST API
|
||||
@@ -237,7 +238,7 @@ Ollama has a REST API for running and managing models.
|
||||
|
||||
```
|
||||
curl http://localhost:11434/api/generate -d '{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"prompt":"Why is the sky blue?"
|
||||
}'
|
||||
```
|
||||
@@ -246,7 +247,7 @@ curl http://localhost:11434/api/generate -d '{
|
||||
|
||||
```
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "mistral",
|
||||
"model": "llama3",
|
||||
"messages": [
|
||||
{ "role": "user", "content": "why is the sky blue?" }
|
||||
]
|
||||
|
||||
11
api/types.go
11
api/types.go
@@ -2,6 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
@@ -97,7 +98,8 @@ type ChatResponse struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Message Message `json:"message"`
|
||||
|
||||
Done bool `json:"done"`
|
||||
Done bool `json:"done"`
|
||||
DoneReason string `json:"done_reason,omitempty"`
|
||||
|
||||
Metrics
|
||||
}
|
||||
@@ -264,8 +266,9 @@ type GenerateResponse struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Response string `json:"response"`
|
||||
|
||||
Done bool `json:"done"`
|
||||
Context []int `json:"context,omitempty"`
|
||||
Done bool `json:"done"`
|
||||
DoneReason string `json:"done_reason,omitempty"`
|
||||
Context []int `json:"context,omitempty"`
|
||||
|
||||
Metrics
|
||||
}
|
||||
@@ -307,7 +310,7 @@ func (m *Metrics) Summary() {
|
||||
}
|
||||
}
|
||||
|
||||
var ErrInvalidOpts = fmt.Errorf("invalid options")
|
||||
var ErrInvalidOpts = errors.New("invalid options")
|
||||
|
||||
func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||
|
||||
@@ -90,7 +90,7 @@ The final response in the stream also includes additional data about the generat
|
||||
- `load_duration`: time spent in nanoseconds loading the model
|
||||
- `prompt_eval_count`: number of tokens in the prompt
|
||||
- `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt
|
||||
- `eval_count`: number of tokens the response
|
||||
- `eval_count`: number of tokens in the response
|
||||
- `eval_duration`: time in nanoseconds spent generating the response
|
||||
- `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
|
||||
- `response`: empty if the response was streamed, if not streamed, this will contain the full response
|
||||
|
||||
@@ -1,38 +1,15 @@
|
||||
# Running Ollama on NVIDIA Jetson Devices
|
||||
|
||||
With some minor configuration, Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/). The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack).
|
||||
Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/) and should run out of the box with the standard installation instructions.
|
||||
|
||||
NVIDIA Jetson devices are Linux-based embedded AI computers that are purpose-built for AI applications.
|
||||
|
||||
Jetsons have an integrated GPU that is wired directly to the memory controller of the machine. For this reason, the `nvidia-smi` command is unrecognized, and Ollama proceeds to operate in "CPU only"
|
||||
mode. This can be verified by using a monitoring tool like jtop.
|
||||
|
||||
In order to address this, we simply pass the path to the Jetson's pre-installed CUDA libraries into `ollama serve` (while in a tmux session). We then hardcode the num_gpu parameters into a cloned
|
||||
version of our target model.
|
||||
|
||||
Prerequisites:
|
||||
|
||||
- curl
|
||||
- tmux
|
||||
|
||||
Here are the steps:
|
||||
The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack), but should also work on JetPack 6.0.
|
||||
|
||||
- Install Ollama via standard Linux command (ignore the 404 error): `curl https://ollama.com/install.sh | sh`
|
||||
- Stop the Ollama service: `sudo systemctl stop ollama`
|
||||
- Start Ollama serve in a tmux session called ollama_jetson and reference the CUDA libraries path: `tmux has-session -t ollama_jetson 2>/dev/null || tmux new-session -d -s ollama_jetson
|
||||
'LD_LIBRARY_PATH=/usr/local/cuda/lib64 ollama serve'`
|
||||
- Pull the model you want to use (e.g. mistral): `ollama pull mistral`
|
||||
- Create a new Modelfile specifically for enabling GPU support on the Jetson: `touch ModelfileMistralJetson`
|
||||
- In the ModelfileMistralJetson file, specify the FROM model and the num_gpu PARAMETER as shown below:
|
||||
|
||||
```
|
||||
FROM mistral
|
||||
PARAMETER num_gpu 999
|
||||
```
|
||||
|
||||
- Create a new model from your Modelfile: `ollama create mistral-jetson -f ./ModelfileMistralJetson`
|
||||
- Run the new model: `ollama run mistral-jetson`
|
||||
|
||||
If you run a monitoring tool like jtop you should now see that Ollama is using the Jetson's integrated GPU.
|
||||
- Start an interactive session: `ollama run mistral`
|
||||
|
||||
And that's it!
|
||||
|
||||
# Running Ollama in Docker
|
||||
|
||||
When running GPU accelerated applications in Docker, it is highly recommended to use [dusty-nv jetson-containers repo](https://github.com/dusty-nv/jetson-containers).
|
||||
@@ -57,21 +57,21 @@ init_vars
|
||||
git_module_setup
|
||||
apply_patches
|
||||
|
||||
init_vars
|
||||
if [ -z "${OLLAMA_SKIP_STATIC_GENERATE}" -o "${OLLAMA_CPU_TARGET}" = "static" ]; then
|
||||
# Builds by default, allows skipping, forces build if OLLAMA_CPU_TARGET="static"
|
||||
# Enables optimized Dockerfile builds using a blanket skip and targeted overrides
|
||||
# Static build for linking into the Go binary
|
||||
init_vars
|
||||
CMAKE_TARGETS="--target llama --target ggml"
|
||||
CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
BUILD_DIR="../build/linux/${ARCH}_static"
|
||||
echo "Building static library"
|
||||
build
|
||||
fi
|
||||
|
||||
init_vars
|
||||
if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
|
||||
|
||||
if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "static" ]; then
|
||||
# Static build for linking into the Go binary
|
||||
init_vars
|
||||
CMAKE_TARGETS="--target llama --target ggml"
|
||||
CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
BUILD_DIR="../build/linux/${ARCH}_static"
|
||||
echo "Building static library"
|
||||
build
|
||||
fi
|
||||
|
||||
|
||||
# Users building from source can tune the exact flags we pass to cmake for configuring
|
||||
# llama.cpp, and we'll build only 1 CPU variant in that case as the default.
|
||||
if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
"github.com/ollama/ollama/gpu"
|
||||
)
|
||||
|
||||
var errPayloadMissing = fmt.Errorf("expected payloads not included in this build of ollama")
|
||||
var errPayloadMissing = errors.New("expected payloads not included in this build of ollama")
|
||||
|
||||
func Init() error {
|
||||
payloadsDir, err := gpu.PayloadsDir()
|
||||
|
||||
@@ -509,10 +509,13 @@ type ImageData struct {
|
||||
}
|
||||
|
||||
type completion struct {
|
||||
Content string `json:"content"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Stop bool `json:"stop"`
|
||||
Content string `json:"content"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Stop bool `json:"stop"`
|
||||
StoppedEos bool `json:"stopped_eos"`
|
||||
StoppedWord bool `json:"stopped_word"`
|
||||
StoppedLimit bool `json:"stopped_limit"`
|
||||
|
||||
Timings struct {
|
||||
PredictedN int `json:"predicted_n"`
|
||||
@@ -532,6 +535,7 @@ type CompletionRequest struct {
|
||||
type CompletionResponse struct {
|
||||
Content string
|
||||
Done bool
|
||||
DoneReason string
|
||||
PromptEvalCount int
|
||||
PromptEvalDuration time.Duration
|
||||
EvalCount int
|
||||
@@ -648,6 +652,8 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
|
||||
return fmt.Errorf("error parsing llm response stream: %s", line)
|
||||
}
|
||||
|
||||
fmt.Println("c", string(evt))
|
||||
|
||||
var c completion
|
||||
if err := json.Unmarshal(evt, &c); err != nil {
|
||||
return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
|
||||
@@ -674,8 +680,18 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
|
||||
}
|
||||
|
||||
if c.Stop {
|
||||
var doneReason string
|
||||
switch {
|
||||
case c.StoppedEos:
|
||||
doneReason = "stop"
|
||||
case c.StoppedWord:
|
||||
doneReason = "stop"
|
||||
case c.StoppedLimit:
|
||||
doneReason = "limit"
|
||||
}
|
||||
fn(CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: doneReason,
|
||||
PromptEvalCount: c.Timings.PromptN,
|
||||
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
|
||||
EvalCount: c.Timings.PredictedN,
|
||||
|
||||
@@ -1137,7 +1137,7 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
|
||||
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
|
||||
}
|
||||
|
||||
var errUnauthorized = fmt.Errorf("unauthorized")
|
||||
var errUnauthorized = errors.New("unauthorized")
|
||||
|
||||
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
|
||||
for i := 0; i < 2; i++ {
|
||||
@@ -1255,7 +1255,7 @@ func parseRegistryChallenge(authStr string) registryChallenge {
|
||||
}
|
||||
}
|
||||
|
||||
var errDigestMismatch = fmt.Errorf("digest mismatch, file must be downloaded again")
|
||||
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
|
||||
|
||||
func verifyBlob(digest string) error {
|
||||
fp, err := GetBlobsPath(digest)
|
||||
|
||||
@@ -91,7 +91,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc
|
||||
}
|
||||
|
||||
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
|
||||
func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
|
||||
func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, int, error) {
|
||||
type prompt struct {
|
||||
System string
|
||||
Prompt string
|
||||
@@ -138,7 +138,7 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
|
||||
|
||||
p.Response = msg.Content
|
||||
default:
|
||||
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
||||
return "", 0, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,7 +151,7 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
|
||||
for i, p := range prompts {
|
||||
tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
prompts[i].tokens = tokens + len(prompts[i].images)*768
|
||||
@@ -160,15 +160,17 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
|
||||
// truncate images and prompts starting from the beginning of the list
|
||||
// until either one prompt remains or the total tokens fits the context window
|
||||
// TODO (jmorganca): this doesn't account for the context window room required for the response
|
||||
var required int
|
||||
for {
|
||||
var required int
|
||||
required = 0
|
||||
for _, p := range prompts {
|
||||
required += p.tokens
|
||||
}
|
||||
|
||||
required += 1 // for bos token
|
||||
|
||||
if required <= window {
|
||||
// leave ~1024 tokens for generation
|
||||
if required <= max(1024, window/2) {
|
||||
slog.Debug("prompt now fits in context window", "required", required, "window", window)
|
||||
break
|
||||
}
|
||||
@@ -194,7 +196,7 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
|
||||
|
||||
tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
prompts[0].tokens = tokens + len(prompts[0].images)*768
|
||||
@@ -212,10 +214,10 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
|
||||
// last prompt should leave the response unrendered (for completion)
|
||||
rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", 0, err
|
||||
}
|
||||
sb.WriteString(rendered)
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
return sb.String(), required, nil
|
||||
}
|
||||
|
||||
@@ -192,7 +192,7 @@ func TestChatPrompt(t *testing.T) {
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
|
||||
got, _, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
|
||||
if err != nil {
|
||||
t.Errorf("error = %v", err)
|
||||
}
|
||||
|
||||
@@ -234,9 +234,10 @@ func GenerateHandler(c *gin.Context) {
|
||||
// of `raw` mode so we need to check for it too
|
||||
if req.Prompt == "" && req.Template == "" && req.System == "" {
|
||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Model: req.Model,
|
||||
Done: true,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Model: req.Model,
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -289,6 +290,14 @@ func GenerateHandler(c *gin.Context) {
|
||||
prompt = sb.String()
|
||||
}
|
||||
|
||||
tokens, err := loaded.llama.Tokenize(c.Request.Context(), prompt)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
opts.NumPredict = max(opts.NumCtx-len(tokens), 0)
|
||||
|
||||
slog.Debug("generate handler", "prompt", prompt)
|
||||
|
||||
ch := make(chan any)
|
||||
@@ -307,10 +316,11 @@ func GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
resp := api.GenerateResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Done: r.Done,
|
||||
Response: r.Content,
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Done: r.Done,
|
||||
DoneReason: r.DoneReason,
|
||||
Response: r.Content,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: r.PromptEvalCount,
|
||||
PromptEvalDuration: r.PromptEvalDuration,
|
||||
@@ -1219,17 +1229,17 @@ func streamResponse(c *gin.Context, ch chan any) {
|
||||
}
|
||||
|
||||
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
|
||||
func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
|
||||
func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, int, error) {
|
||||
encode := func(s string) ([]int, error) {
|
||||
return loaded.llama.Tokenize(ctx, s)
|
||||
}
|
||||
|
||||
prompt, err := ChatPrompt(template, messages, numCtx, encode)
|
||||
prompt, tokens, err := ChatPrompt(template, messages, numCtx, encode)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
return prompt, nil
|
||||
return prompt, tokens, nil
|
||||
}
|
||||
|
||||
func ChatHandler(c *gin.Context) {
|
||||
@@ -1309,19 +1319,22 @@ func ChatHandler(c *gin.Context) {
|
||||
}, req.Messages...)
|
||||
}
|
||||
|
||||
prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
|
||||
prompt, tokens, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
opts.NumPredict = max(opts.NumCtx-tokens, 0)
|
||||
|
||||
// an empty request loads the model
|
||||
if len(req.Messages) == 0 || prompt == "" {
|
||||
resp := api.ChatResponse{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Model: req.Model,
|
||||
Done: true,
|
||||
Message: api.Message{Role: "assistant"},
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Model: req.Model,
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
Message: api.Message{Role: "assistant"},
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
return
|
||||
@@ -1356,10 +1369,11 @@ func ChatHandler(c *gin.Context) {
|
||||
loaded.expireTimer.Reset(sessionDuration)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: api.Message{Role: "assistant", Content: r.Content},
|
||||
Done: r.Done,
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: api.Message{Role: "assistant", Content: r.Content},
|
||||
Done: r.Done,
|
||||
DoneReason: r.DoneReason,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: r.PromptEvalCount,
|
||||
PromptEvalDuration: r.PromptEvalDuration,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
@@ -47,8 +48,11 @@ var (
|
||||
// Digest.
|
||||
func ParseDigest(s string) Digest {
|
||||
typ, digest, ok := strings.Cut(s, "-")
|
||||
if !ok {
|
||||
typ, digest, ok = strings.Cut(s, ":")
|
||||
}
|
||||
if ok && isValidDigestType(typ) && isValidHex(digest) {
|
||||
return Digest{s: s}
|
||||
return Digest{s: fmt.Sprintf("%s-%s", typ, digest)}
|
||||
}
|
||||
return Digest{}
|
||||
}
|
||||
|
||||
@@ -156,7 +156,7 @@ func ParseName(s, fill string) Name {
|
||||
r = Name{}
|
||||
return false
|
||||
}
|
||||
if kind == PartExtraneous || !isValidPart(kind, part) {
|
||||
if kind == PartExtraneous || !IsValidNamePart(kind, part) {
|
||||
r = Name{}
|
||||
return false
|
||||
}
|
||||
@@ -176,7 +176,7 @@ func parseMask(s string) Name {
|
||||
// mask part; treat as empty but valid
|
||||
return true
|
||||
}
|
||||
if !isValidPart(kind, part) {
|
||||
if !IsValidNamePart(kind, part) {
|
||||
panic(fmt.Errorf("invalid mask part %s: %q", kind, part))
|
||||
}
|
||||
r.parts[kind] = part
|
||||
@@ -608,7 +608,7 @@ func ParseNameFromFilepath(s, fill string) Name {
|
||||
var r Name
|
||||
for i := range PartBuild + 1 {
|
||||
part, rest, _ := strings.Cut(s, string(filepath.Separator))
|
||||
if !isValidPart(i, part) {
|
||||
if !IsValidNamePart(i, part) {
|
||||
return Name{}
|
||||
}
|
||||
r.parts[i] = part
|
||||
@@ -654,9 +654,12 @@ func (r Name) FilepathNoBuild() string {
|
||||
return filepath.Join(r.parts[:PartBuild]...)
|
||||
}
|
||||
|
||||
// isValidPart reports if s contains all valid characters for the given
|
||||
// part kind.
|
||||
func isValidPart(kind PartKind, s string) bool {
|
||||
// IsValidNamePart reports if s contains all valid characters for the given
|
||||
// part kind and is under MaxNamePartLen bytes.
|
||||
func IsValidNamePart(kind PartKind, s string) bool {
|
||||
if len(s) > MaxNamePartLen {
|
||||
return false
|
||||
}
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -105,6 +105,12 @@ var testNames = map[string]fields{
|
||||
strings.Repeat("a", MaxNamePartLen+1): {},
|
||||
}
|
||||
|
||||
func TestIsValidNameLen(t *testing.T) {
|
||||
if IsValidNamePart(PartNamespace, strings.Repeat("a", MaxNamePartLen+1)) {
|
||||
t.Errorf("unexpectedly valid long name")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConsecutiveDots tests that consecutive dots are not allowed in any
|
||||
// part, to avoid path traversal. There also are some tests in testNames, but
|
||||
// this test is more exhaustive and exists to emphasize the importance of
|
||||
|
||||
Reference in New Issue
Block a user