Compare commits: v0.4.3...parth/open (52 commits)

2536ffe0ab
97abd7bfea
c6509bf76e
aed1419c64
c6c526275d
630e7dc6ff
eb8366d658
4456012956
539be43640
1bdab9fdb1
2b82c5a8a1
55c3efa900
1aedffad93
ff6c2d6dc8
d543b282a7
5f8051180e
39e29ae5dd
30a9f063c9
ce7455a8e1
e3936d4fb3
940e62772e
71e6a0d0d1
2cd11ae365
52bbad12f9
30e88d7f31
2b7ed61ca2
647513a7d4
a210ec74d2
cfb1ddd6fc
3987acd7ec
fda1e6b563
3440ffb37b
a820d2b267
2ebdb54fb3
bb52abfa55
31cb1ca9e5
78f779a323
3478b2cf14
7b5585b9cb
f0a351810c
b85520bfb9
d88972ea48
25c9339e2d
597072ef1b
84b3e07f1b
422d52858c
723f285813
eaaf5d309d
27d9c749d5
7355ab3703
7ed81437fe
220108d3f4
.github/workflows/test.yaml (7 changes, vendored)
```diff
@@ -243,7 +243,7 @@ jobs:
         $env:PATH="$gopath;$gccpath;$env:PATH"
         echo $env:PATH
         if (!(gcc --version | select-string -quiet clang)) { throw "wrong gcc compiler detected - must be clang" }
-        make -j 4
+        make -j 4
     - name: 'Build Unix Go Runners'
       if: ${{ ! startsWith(matrix.os, 'windows-') }}
       run: make -j 4
@@ -310,8 +310,7 @@ jobs:
            arm64) echo ARCH=arm64 ;;
          esac >>$GITHUB_ENV
        shell: bash
-     - run: go build
-     - run: go test -v ./...
+     - run: go test ./...

  patches:
    needs: [changes]
@@ -323,4 +322,4 @@ jobs:
          submodules: recursive
      - name: Verify patches carry all the changes
        run: |
-          make apply-patches sync && git diff --compact-summary --exit-code llama
+          make apply-patches sync && git diff --compact-summary --exit-code llama
```
README.md (37 changes)
```diff
@@ -298,7 +298,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
 - [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
 - [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
-- [QA-Pilot](https://github.com/reid41/QA-Pilot) (Chat with Code Repository)
+- [QA-Pilot](https://github.com/reid41/QA-Pilot) (Interactive chat tool that can leverage Ollama models for rapid understanding and navigation of GitHub code repositories)
 - [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open Source Chatbot based on Ollama with Knowledge Bases)
 - [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple Web Search with Corrective RAG)
 - [RAGFlow](https://github.com/infiniflow/ragflow) (Open-source Retrieval-Augmented Generation engine based on deep document understanding)
@@ -308,6 +308,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Ollama RAG Chatbot](https://github.com/datvodinh/rag-chatbot.git) (Local Chat with multiple PDFs using Ollama and RAG)
 - [BrainSoup](https://www.nurgo-software.com/products/brainsoup) (Flexible native client with RAG & multi-agent automation)
 - [macai](https://github.com/Renset/macai) (macOS client for Ollama, ChatGPT, and other compatible API back-ends)
 - [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama)
 - [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models)
 - [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
 - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
@@ -316,8 +317,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [AiLama](https://github.com/zeyoyt/ailama) (A Discord User App that allows you to interact with Ollama anywhere in discord )
 - [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
 - [R2R](https://github.com/SciPhi-AI/R2R) (Open-source RAG engine)
-- [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple easy to use GUI with sample custom LLM for Drivers Education)
-- [OpenGPA](https://opengpa.org) (Open-source offline-first Enterprise Agentic Application)
+- [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple easy to use GUI with sample custom LLM for Drivers Education)
+- [OpenGPA](https://opengpa.org) (Open-source offline-first Enterprise Agentic Application)
 - [Painting Droid](https://github.com/mateuszmigas/painting-droid) (Painting app with AI integrations)
 - [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
 - [AI Studio](https://github.com/MindWorkAI/AI-Studio)
@@ -338,20 +339,33 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [crewAI with Mesop](https://github.com/rapidarchitect/ollama-crew-mesop) (Mesop Web Interface to run crewAI with Ollama)
 - [Tkinter-based client](https://github.com/chyok/ollama-gui) (Python tkinter-based Client for Ollama)
 - [LLMChat](https://github.com/trendy-design/llmchat) (Privacy focused, 100% local, intuitive all-in-one chat interface)
 - [Local Multimodal AI Chat](https://github.com/Leon-Sander/Local-Multimodal-AI-Chat) (Ollama-based LLM Chat with support for multiple features, including PDF RAG, voice chat, image-based interactions, and integration with OpenAI.)
 - [ARGO](https://github.com/xark-argo/argo) (Locally download and run Ollama and Huggingface models with RAG on Mac/Windows/Linux)
 - [OrionChat](https://github.com/EliasPereirah/OrionChat) - OrionChat is a web interface for chatting with different AI providers
 - [G1](https://github.com/bklieger-groq/g1) (Prototype of using prompting strategies to improve the LLM's reasoning through o1-like reasoning chains.)
 - [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
 - [Promptery](https://github.com/promptery/promptery) (desktop client for Ollama.)
 - [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
 - [SpaceLlama](https://github.com/tcsenpai/spacellama) (Firefox and Chrome extension to quickly summarize web pages with ollama in a sidebar)
 - [YouLama](https://github.com/tcsenpai/youlama) (Webapp to quickly summarize any YouTube video, supporting Invidious as well)
 - [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
 - [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
 - [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
 - [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings)
 - [Hexabot](https://github.com/hexastack/hexabot) (A conversational AI builder)
 - [Reddit Rate](https://github.com/rapidarchitect/reddit_analyzer) (Search and Rate Reddit topics with a weighted summation)
-- [OpenTalkGpt](https://github.com/adarshM84/OpenTalkGpt)
+- [OpenTalkGpt](https://github.com/adarshM84/OpenTalkGpt) (Chrome Extension to manage open-source models supported by Ollama, create custom models, and chat with models from a user-friendly UI)
 - [VT](https://github.com/vinhnx/vt.ai) (A minimal multimodal AI chat app, with dynamic conversation routing. Supports local models via Ollama)
 - [Nosia](https://github.com/nosia-ai/nosia) (Easy to install and use RAG platform based on Ollama)
-- [Witsy](https://github.com/nbonamy/witsy) (An AI Desktop application avaiable for Mac/Windows/Linux)
+- [Witsy](https://github.com/nbonamy/witsy) (An AI Desktop application avaiable for Mac/Windows/Linux)
 - [Abbey](https://github.com/US-Artificial-Intelligence/abbey) (A configurable AI interface server with notebooks, document storage, and YouTube support)
 - [Minima](https://github.com/dmayboroda/minima) (RAG with on-premises or fully local workflow)

 ### Cloud

 - [Google Cloud](https://cloud.google.com/run/docs/tutorials/gpu-gemma2-with-ollama)
 - [Fly.io](https://fly.io/docs/python/do-more/add-ollama/)
 - [Koyeb](https://www.koyeb.com/deploy/ollama)

 ### Terminal

@@ -367,7 +381,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Oatmeal](https://github.com/dustinblackman/oatmeal)
 - [cmdh](https://github.com/pgibler/cmdh)
 - [ooo](https://github.com/npahlfer/ooo)
-- [shell-pilot](https://github.com/reid41/shell-pilot)
+- [shell-pilot](https://github.com/reid41/shell-pilot)(Interact with models via pure shell scripts on Linux or macOS)
 - [tenere](https://github.com/pythops/tenere)
 - [llm-ollama](https://github.com/taketwo/llm-ollama) for [Datasette's LLM CLI](https://llm.datasette.io/en/stable/).
 - [typechat-cli](https://github.com/anaisbetts/typechat-cli)
@@ -379,12 +393,15 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Ollama eBook Summary](https://github.com/cognitivetech/ollama-ebook-summary/)
 - [Ollama Mixture of Experts (MOE) in 50 lines of code](https://github.com/rapidarchitect/ollama_moe)
 - [vim-intelligence-bridge](https://github.com/pepo-ec/vim-intelligence-bridge) Simple interaction of "Ollama" with the Vim editor
 - [x-cmd ollama](https://x-cmd.com/mod/ollama)
 - [bb7](https://github.com/drunkwcodes/bb7)
 - [SwollamaCLI](https://github.com/marcusziade/Swollama) bundled with the Swollama Swift package. [Demo](https://github.com/marcusziade/Swollama?tab=readme-ov-file#cli-usage)
 - [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more.
 - [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
 - [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.

 ### Apple Vision Pro

 - [Enchanted](https://github.com/AugustDev/enchanted)

 ### Database
@@ -485,12 +502,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [AI Telegram Bot](https://github.com/tusharhero/aitelegrambot) (Telegram bot using Ollama in backend)
 - [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support)
 - [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
 - [ChatGPTBox: All in one browser extension](https://github.com/josStorer/chatGPTBox) with [Integrating Tutorial](https://github.com/josStorer/chatGPTBox/issues/616#issuecomment-1975186467)
 - [Discord AI chat/moderation bot](https://github.com/rapmd73/Companion) Chat/moderation bot written in python. Uses Ollama to create personalities.
 - [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depends on ollama server)
 - [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front end Open WebUI service.)
 - [node-red-contrib-ollama](https://github.com/jakubburkiewicz/node-red-contrib-ollama)
 - [Local AI Helper](https://github.com/ivostoykov/localAI) (Chrome and Firefox extensions that enable interactions with the active tab and customisable API endpoints. Includes secure storage for user prompts.)
-- [vnc-lm](https://github.com/jk011ru/vnc-lm) (A containerized Discord bot with support for attachments and web links)
+- [vnc-lm](https://github.com/jake83741/vnc-lm) (Discord bot for messaging with LLMs through Ollama and LiteLLM. Seamlessly move between local and flagship models.)
 - [LSP-AI](https://github.com/SilasMarvin/lsp-ai) (Open-source language server for AI-powered functionality)
 - [QodeAssist](https://github.com/Palm1r/QodeAssist) (AI-powered coding assistant plugin for Qt Creator)
 - [Obsidian Quiz Generator plugin](https://github.com/ECuiDev/obsidian-quiz-generator)
@@ -500,3 +518,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
 ### Supported backends

 - [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.

 ### Observability

 - [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
 - [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.
```
```diff
@@ -67,7 +67,7 @@ type GenerateRequest struct {
 	Raw bool `json:"raw,omitempty"`

 	// Format specifies the format to return a response in.
-	Format string `json:"format"`
+	Format json.RawMessage `json:"format,omitempty"`

 	// KeepAlive controls how long the model will stay loaded in memory following
 	// this request.
@@ -94,7 +94,7 @@ type ChatRequest struct {
 	Stream *bool `json:"stream,omitempty"`

 	// Format is the format to return the response in (e.g. "json").
-	Format string `json:"format"`
+	Format json.RawMessage `json:"format,omitempty"`

 	// KeepAlive controls how long the model will stay loaded into memory
 	// following the request.
@@ -146,6 +146,7 @@ type ToolCall struct {
 }

 type ToolCallFunction struct {
+	Index     int                       `json:"index,omitempty"`
 	Name      string                    `json:"name"`
 	Arguments ToolCallFunctionArguments `json:"arguments"`
 }
```
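The two `Format` hunks above widen the request field from a plain string to `json.RawMessage`, so a request can carry either the legacy `"json"` literal or a complete JSON schema for structured output. Below is a minimal client-side sketch of what the widened field accepts, assuming the post-change `api` package; the model name and schema are illustrative assumptions, not taken from the diff:

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	// The old string value still works, now wrapped as raw JSON.
	legacy := api.GenerateRequest{
		Model:  "llama3.2",
		Prompt: "List three colors.",
		Format: json.RawMessage(`"json"`),
	}

	// After the change, Format can also carry an arbitrary JSON schema.
	// This schema is a hypothetical example for illustration.
	schema := json.RawMessage(`{
		"type": "object",
		"properties": {"colors": {"type": "array", "items": {"type": "string"}}},
		"required": ["colors"]
	}`)
	structured := api.GenerateRequest{
		Model:  "llama3.2",
		Prompt: "List three colors.",
		Format: schema,
	}

	fmt.Println(string(legacy.Format), string(structured.Format))
}
```

Note that with `omitempty` on the new tag, an unset format is omitted from the request body entirely rather than being serialized as `"format": ""`.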
```diff
@@ -64,7 +64,7 @@ func initStore() {
 		slog.Debug(fmt.Sprintf("unexpected error searching for store: %s", err))
 	}
 	slog.Debug("initializing new store")
-	store.ID = uuid.New().String()
+	store.ID = uuid.NewString()
 	writeStore(getStorePath())
 }
```
cmd/cmd.go (73 changes)
```diff
@@ -8,6 +8,7 @@ import (
 	"crypto/ed25519"
 	"crypto/rand"
 	"crypto/sha256"
+	"encoding/json"
 	"encoding/pem"
 	"errors"
 	"fmt"
@@ -19,7 +20,6 @@ import (
 	"os"
 	"os/signal"
 	"path/filepath"
-	"regexp"
 	"runtime"
 	"strconv"
 	"strings"
@@ -35,13 +35,11 @@ import (
 	"golang.org/x/term"

 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/auth"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/server"
-	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 )
@@ -456,6 +454,10 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	if len(prompts) > 0 {
 		interactive = false
 	}
+	// Be quiet if we're redirecting to a pipe or file
+	if !term.IsTerminal(int(os.Stdout.Fd())) {
+		interactive = false
+	}

 	nowrap, err := cmd.Flags().GetBool("nowordwrap")
 	if err != nil {
@@ -512,47 +514,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	return generate(cmd, opts)
 }

-func errFromUnknownKey(unknownKeyErr error) error {
-	// find SSH public key in the error message
-	sshKeyPattern := `ssh-\w+ [^\s"]+`
-	re := regexp.MustCompile(sshKeyPattern)
-	matches := re.FindStringSubmatch(unknownKeyErr.Error())
-
-	if len(matches) > 0 {
-		serverPubKey := matches[0]
-
-		localPubKey, err := auth.GetPublicKey()
-		if err != nil {
-			return unknownKeyErr
-		}
-
-		if runtime.GOOS == "linux" && serverPubKey != localPubKey {
-			// try the ollama service public key
-			svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
-			if err != nil {
-				return unknownKeyErr
-			}
-			localPubKey = strings.TrimSpace(string(svcPubKey))
-		}
-
-		// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
-		if serverPubKey != localPubKey {
-			return unknownKeyErr
-		}
-
-		var msg strings.Builder
-		msg.WriteString(unknownKeyErr.Error())
-		msg.WriteString("\n\nYour ollama key is:\n")
-		msg.WriteString(localPubKey)
-		msg.WriteString("\nAdd your key at:\n")
-		msg.WriteString("https://ollama.com/settings/keys")
-
-		return errors.New(msg.String())
-	}
-
-	return unknownKeyErr
-}
-
 func PushHandler(cmd *cobra.Command, args []string) error {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
@@ -599,6 +560,8 @@ func PushHandler(cmd *cobra.Command, args []string) error {
 	}

 	request := api.PushRequest{Name: args[0], Insecure: insecure}
+
+	n := model.ParseName(args[0])
 	if err := client.Push(cmd.Context(), &request, fn); err != nil {
 		if spinner != nil {
 			spinner.Stop()
@@ -606,18 +569,19 @@
 		if strings.Contains(err.Error(), "access denied") {
 			return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
 		}
-		host := model.ParseName(args[0]).Host
-		isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com")
-		if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
-			// the user has not added their ollama key to ollama.com
-			// re-throw an error with a more user-friendly message
-			return errFromUnknownKey(err)
-		}
-
 		return err
 	}

-	p.Stop()
+	spinner.Stop()
+
+	destination := n.String()
+	if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") {
+		destination = "https://ollama.com/" + strings.TrimSuffix(n.DisplayShortest(), ":latest")
+	}
+	fmt.Printf("\nYou can find your model at:\n\n")
+	fmt.Printf("\t%s\n", destination)

 	return nil
 }
@@ -1075,7 +1039,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 	req := &api.ChatRequest{
 		Model:    opts.Model,
 		Messages: opts.Messages,
-		Format:   opts.Format,
+		Format:   json.RawMessage(opts.Format),
 		Options:  opts.Options,
 	}
@@ -1162,7 +1126,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 		Prompt:    opts.Prompt,
 		Context:   generateContext,
 		Images:    opts.Images,
-		Format:    opts.Format,
+		Format:    json.RawMessage(opts.Format),
 		System:    opts.System,
 		Options:   opts.Options,
 		KeepAlive: opts.KeepAlive,
@@ -1482,6 +1446,7 @@ func NewCLI() *cobra.Command {
 				envVars["OLLAMA_SCHED_SPREAD"],
 				envVars["OLLAMA_TMPDIR"],
 				envVars["OLLAMA_FLASH_ATTENTION"],
+				envVars["OLLAMA_KV_CACHE_TYPE"],
 				envVars["OLLAMA_LLM_LIBRARY"],
 				envVars["OLLAMA_GPU_OVERHEAD"],
 				envVars["OLLAMA_LOAD_TIMEOUT"],
```
cmd/cmd_test.go (134 changes)
```diff
@@ -4,10 +4,10 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
 	"io"
 	"net/http"
 	"net/http/httptest"
 	"os"
 	"path/filepath"
 	"strings"
 	"testing"
@@ -179,18 +179,14 @@ Weigh anchor!
 	t.Run("license", func(t *testing.T) {
 		var b bytes.Buffer
-		license, err := os.ReadFile(filepath.Join("..", "LICENSE"))
-		if err != nil {
-			t.Fatal(err)
-		}
-
+		license := "MIT License\nCopyright (c) Ollama\n"
 		if err := showInfo(&api.ShowResponse{
 			Details: api.ModelDetails{
 				Family:            "test",
 				ParameterSize:     "7B",
 				QuantizationLevel: "FP16",
 			},
-			License: string(license),
+			License: license,
 		}, &b); err != nil {
 			t.Fatal(err)
 		}
@@ -369,3 +365,127 @@ func TestGetModelfileName(t *testing.T) {
 		})
 	}
 }
+
+func TestPushHandler(t *testing.T) {
+	tests := []struct {
+		name           string
+		modelName      string
+		serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
+		expectedError  string
+		expectedOutput string
+	}{
+		{
+			name:      "successful push",
+			modelName: "test-model",
+			serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
+				"/api/push": func(w http.ResponseWriter, r *http.Request) {
+					if r.Method != http.MethodPost {
+						t.Errorf("expected POST request, got %s", r.Method)
+					}
+
+					var req api.PushRequest
+					if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+						http.Error(w, err.Error(), http.StatusBadRequest)
+						return
+					}
+
+					if req.Name != "test-model" {
+						t.Errorf("expected model name 'test-model', got %s", req.Name)
+					}
+
+					// Simulate progress updates
+					responses := []api.ProgressResponse{
+						{Status: "preparing manifest"},
+						{Digest: "sha256:abc123456789", Total: 100, Completed: 50},
+						{Digest: "sha256:abc123456789", Total: 100, Completed: 100},
+					}
+
+					for _, resp := range responses {
+						if err := json.NewEncoder(w).Encode(resp); err != nil {
+							http.Error(w, err.Error(), http.StatusInternalServerError)
+							return
+						}
+						w.(http.Flusher).Flush()
+					}
+				},
+			},
+			expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
+		},
+		{
+			name:      "unauthorized push",
+			modelName: "unauthorized-model",
+			serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
+				"/api/push": func(w http.ResponseWriter, r *http.Request) {
+					w.Header().Set("Content-Type", "application/json")
+					w.WriteHeader(http.StatusUnauthorized)
+					err := json.NewEncoder(w).Encode(map[string]string{
+						"error": "access denied",
+					})
+					if err != nil {
+						t.Fatal(err)
+					}
+				},
+			},
+			expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				if handler, ok := tt.serverResponse[r.URL.Path]; ok {
+					handler(w, r)
+					return
+				}
+				http.Error(w, "not found", http.StatusNotFound)
+			}))
+			defer mockServer.Close()
+
+			t.Setenv("OLLAMA_HOST", mockServer.URL)
+
+			cmd := &cobra.Command{}
+			cmd.Flags().Bool("insecure", false, "")
+			cmd.SetContext(context.TODO())
+
+			// Redirect stderr to capture progress output
+			oldStderr := os.Stderr
+			r, w, _ := os.Pipe()
+			os.Stderr = w
+
+			// Capture stdout for the "Model pushed" message
+			oldStdout := os.Stdout
+			outR, outW, _ := os.Pipe()
+			os.Stdout = outW
+
+			err := PushHandler(cmd, []string{tt.modelName})
+
+			// Restore stderr
+			w.Close()
+			os.Stderr = oldStderr
+			// drain the pipe
+			if _, err := io.ReadAll(r); err != nil {
+				t.Fatal(err)
+			}
+
+			// Restore stdout and get output
+			outW.Close()
+			os.Stdout = oldStdout
+			stdout, _ := io.ReadAll(outR)
+
+			if tt.expectedError == "" {
+				if err != nil {
+					t.Errorf("expected no error, got %v", err)
+				}
+				if tt.expectedOutput != "" {
+					if got := string(stdout); got != tt.expectedOutput {
+						t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
+					}
+				}
+			} else {
+				if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
+					t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
+				}
+			}
+		})
+	}
+}
```
```diff
@@ -319,8 +319,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 				opts.Messages = append(opts.Messages, newMessage)
 			}
 			fmt.Println("Set system message.")
 			sb.Reset()
-
-			sb.Reset()
 			continue
 		default:
@@ -516,7 +514,7 @@ func extractFileNames(input string) []string {
 	// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
 	// and followed by more characters and a file extension
 	// This will capture non filename strings, but we'll check for file existence to remove mismatches
-	regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b`
+	regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png)\b`
 	re := regexp.MustCompile(regexPattern)

 	return re.FindAllString(input, -1)
```
```diff
@@ -12,44 +12,45 @@ import (
 func TestExtractFilenames(t *testing.T) {
 	// Unix style paths
 	input := ` some preamble
-	./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2
-	/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.svg`
+	./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./1.svg
+	/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG`
 	res := extractFileNames(input)
 	assert.Len(t, res, 5)
 	assert.Contains(t, res[0], "one.png")
 	assert.Contains(t, res[1], "two.jpg")
 	assert.Contains(t, res[2], "three.jpeg")
 	assert.Contains(t, res[3], "four.png")
-	assert.Contains(t, res[4], "five.svg")
+	assert.Contains(t, res[4], "five.JPG")
 	assert.NotContains(t, res[4], '"')
-	assert.NotContains(t, res, "inbtween")
+	assert.NotContains(t, res, "inbetween1")
+	assert.NotContains(t, res, "./1.svg")

 	// Windows style paths
 	input = ` some preamble
 	c:/users/jdoe/one.png inbetween1 c:/program files/someplace/two.jpg inbetween2
 	/absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4
-	./relative\ path/five.svg inbetween5 "./relative with/spaces/six.png inbetween6
-	d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8
-	d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.svg some ending
+	./relative\ path/five.JPG inbetween5 "./relative with/spaces/six.png inbetween6
+	d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8
+	d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG some ending
 `
 	res = extractFileNames(input)
 	assert.Len(t, res, 10)
-	assert.NotContains(t, res, "inbtween")
+	assert.NotContains(t, res, "inbetween2")
 	assert.Contains(t, res[0], "one.png")
 	assert.Contains(t, res[0], "c:")
 	assert.Contains(t, res[1], "two.jpg")
 	assert.Contains(t, res[1], "c:")
 	assert.Contains(t, res[2], "three.jpeg")
 	assert.Contains(t, res[3], "four.png")
-	assert.Contains(t, res[4], "five.svg")
+	assert.Contains(t, res[4], "five.JPG")
 	assert.Contains(t, res[5], "six.png")
-	assert.Contains(t, res[6], "seven.svg")
+	assert.Contains(t, res[6], "seven.JPEG")
 	assert.Contains(t, res[6], "d:")
 	assert.Contains(t, res[7], "eight.png")
 	assert.Contains(t, res[7], "c:")
 	assert.Contains(t, res[8], "nine.png")
 	assert.Contains(t, res[8], "d:")
-	assert.Contains(t, res[9], "ten.svg")
+	assert.Contains(t, res[9], "ten.PNG")
 	assert.Contains(t, res[9], "E:")
 }
```
```diff
@@ -10,6 +10,7 @@ import (
 	"log/slog"
 	"os"
 	"slices"
+	"strings"

 	"golang.org/x/exp/maps"
 )
@@ -60,7 +61,25 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
 		addedTokens[t.Content] = t
 	}

-	t.Merges = tt.Model.Merges
+	if len(tt.Model.Merges) == 0 {
+		// noop; merges is empty
+	} else if err := json.Unmarshal(tt.Model.Merges, &t.Merges); err == nil {
+		// noop; merges is []string
+	} else if merges, err := func() ([][]string, error) {
+		var merges [][]string
+		if err := json.Unmarshal(tt.Model.Merges, &merges); err != nil {
+			return nil, err
+		}
+
+		return merges, nil
+	}(); err == nil {
+		t.Merges = make([]string, len(merges))
+		for i := range merges {
+			t.Merges[i] = strings.Join(merges[i], " ")
+		}
+	} else {
+		return nil, fmt.Errorf("could not parse tokenizer merges. expected []string or [][]string: %w", err)
+	}

 	sha256sum := sha256.New()
 	for _, pt := range tt.PreTokenizer.PreTokenizers {
@@ -156,9 +175,9 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
 type tokenizer struct {
 	AddedTokens []token `json:"added_tokens"`
 	Model struct {
-		Type   string          `json:"type"`
-		Vocab  map[string]int  `json:"vocab"`
-		Merges []string        `json:"merges"`
+		Type   string          `json:"type"`
+		Vocab  map[string]int  `json:"vocab"`
+		Merges json.RawMessage `json:"merges"`
 	} `json:"model"`

 	PreTokenizer struct {
```
```diff
@@ -191,6 +191,62 @@ func TestParseTokenizer(t *testing.T) {
 				Pre: "default",
 			},
 		},
+		{
+			name: "list string merges",
+			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
+				"tokenizer.json": strings.NewReader(`{
+					"model": {
+						"merges": [
+							"a b",
+							"c d",
+							"e f"
+						]
+					}
+				}`),
+			}),
+			want: &Tokenizer{
+				Vocabulary: &Vocabulary{
+					Model: "gpt2",
+				},
+				Merges: []string{
+					"a b",
+					"c d",
+					"e f",
+				},
+				Pre: "default",
+			},
+		},
+		{
+			name: "list list string merges",
+			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
+				"tokenizer.json": strings.NewReader(`{
+					"model": {
+						"merges": [
+							[
+								"a", "b"
+							],
+							[
+								"c", "d"
+							],
+							[
+								"e", "f"
+							]
+						]
+					}
+				}`),
+			}),
+			want: &Tokenizer{
+				Vocabulary: &Vocabulary{
+					Model: "gpt2",
+				},
+				Merges: []string{
+					"a b",
+					"c d",
+					"e f",
+				},
+				Pre: "default",
+			},
+		},
 	}

 	for _, tt := range cases {
```
```diff
@@ -183,3 +183,17 @@ func (si SystemInfo) GetOptimalThreadCount() int {

 	return coreCount
 }
+
+// For each GPU, check if it does NOT support flash attention
+func (l GpuInfoList) FlashAttentionSupported() bool {
+	for _, gpu := range l {
+		supportsFA := gpu.Library == "metal" ||
+			(gpu.Library == "cuda" && gpu.DriverMajor >= 7) ||
+			gpu.Library == "rocm"
+
+		if !supportsFA {
+			return false
+		}
+	}
+	return true
+}
```
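`FlashAttentionSupported` reports whether every GPU in the list can use flash attention, so a single older device disables the feature fleet-wide. Here is a self-contained sketch of that behavior; the `GpuInfo` struct is trimmed to the two fields the diff touches, and the concrete device values are hypothetical:

```go
package main

import "fmt"

// GpuInfo mirrors only the fields used by the check above.
type GpuInfo struct {
	Library     string
	DriverMajor int
}

type GpuInfoList []GpuInfo

// FlashAttentionSupported returns true only if all GPUs support flash attention.
func (l GpuInfoList) FlashAttentionSupported() bool {
	for _, gpu := range l {
		supportsFA := gpu.Library == "metal" ||
			(gpu.Library == "cuda" && gpu.DriverMajor >= 7) ||
			gpu.Library == "rocm"
		if !supportsFA {
			return false // one unsupported GPU disables it for the whole list
		}
	}
	return true
}

func main() {
	mixed := GpuInfoList{
		{Library: "cuda", DriverMajor: 8}, // supported: CUDA driver major >= 7
		{Library: "cuda", DriverMajor: 6}, // unsupported: driver too old
	}
	fmt.Println(mixed.FlashAttentionSupported()) // false
}
```

The all-or-nothing check is conservative: on a mixed fleet, a partially offloaded model could otherwise land on the one device that cannot run flash attention.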
```diff
@@ -49,10 +49,10 @@ Advanced parameters (optional):

 - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
 - `system`: system message to (overrides what is defined in the `Modelfile`)
 - `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
-- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
 - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
 - `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API
 - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
+- `context` (deprecated): the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory

 #### JSON mode
```
docs/faq.md (28 changes)

```diff
@@ -151,7 +151,7 @@ Refer to the section [above](#how-do-i-configure-ollama-server) for how to set e

 Ollama runs an HTTP server and can be exposed using a proxy server such as Nginx. To do so, configure the proxy to forward requests and optionally set required headers (if not exposing Ollama on the network). For example, with Nginx:

-```
+```nginx
 server {
     listen 80;
     server_name example.com; # Replace with your domain or IP
@@ -285,4 +285,28 @@ Note: Windows with Radeon GPUs currently default to 1 model maximum due to limit

 ## How does Ollama load models on multiple GPUs?

-Installing multiple GPUs of the same brand can be a great way to increase your available VRAM to load larger models. When you load a new model, Ollama evaluates the required VRAM for the model against what is currently available. If the model will entirely fit on any single GPU, Ollama will load the model on that GPU. This typically provides the best performance as it reduces the amount of data transfering across the PCI bus during inference. If the model does not fit entirely on one GPU, then it will be spread across all the available GPUs.
+When loading a new model, Ollama evaluates the required VRAM for the model against what is currently available. If the model will entirely fit on any single GPU, Ollama will load the model on that GPU. This typically provides the best performance as it reduces the amount of data transferring across the PCI bus during inference. If the model does not fit entirely on one GPU, then it will be spread across all the available GPUs.
+
+## How can I enable Flash Attention?
+
+Flash Attention is a feature of most modern models that can significantly reduce memory usage as the context size grows. To enable Flash Attention, set the `OLLAMA_FLASH_ATTENTION` environment variable to `1` when starting the Ollama server.
+
+## How can I set the quantization type for the K/V cache?
+
+The K/V context cache can be quantized to significantly reduce memory usage when Flash Attention is enabled.
+
+To use quantized K/V cache with Ollama you can set the following environment variable:
+
+- `OLLAMA_KV_CACHE_TYPE` - The quantization type for the K/V cache. Default is `f16`.
+
+> Note: Currently this is a global option - meaning all models will run with the specified quantization type.
+
+The currently available K/V cache quantization types are:
+
+- `f16` - high precision and memory usage (default).
+- `q8_0` - 8-bit quantization, uses approximately 1/2 the memory of `f16` with a very small loss in precision, this usually has no noticeable impact on the model's quality (recommended if not using f16).
+- `q4_0` - 4-bit quantization, uses approximately 1/4 the memory of `f16` with a small-medium loss in precision that may be more noticeable at higher context sizes.
+
+How much the cache quantization impacts the model's response quality will depend on the model and the task. Models that have a high GQA count (e.g. Qwen2) may see a larger impact on precision from quantization than models with a low GQA count.
+
+You may need to experiment with different quantization types to find the best balance between memory usage and quality.
```
```diff
@@ -63,7 +63,7 @@ SYSTEM You are Mario from super mario bros, acting as an assistant.

 To use this:

 1. Save it as a file (e.g. `Modelfile`)
-2. `ollama create choose-a-model-name -f <location of the file e.g. ./Modelfile>'`
+2. `ollama create choose-a-model-name -f <location of the file e.g. ./Modelfile>`
 3. `ollama run choose-a-model-name`
 4. Start using the model!
@@ -156,7 +156,7 @@ PARAMETER <parameter> <parametervalue>
 | seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
 | stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |
 | tfs_z | Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1) | float | tfs_z 1 |
-| num_predict | Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context) | int | num_predict 42 |
+| num_predict | Maximum number of tokens to predict when generating text. (Default: -1, infinite generation) | int | num_predict 42 |
 | top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |
 | top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 |
 | min_p | Alternative to the top_p, and aims to ensure a balance of quality and variety. The parameter *p* represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with *p*=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out. (Default: 0.0) | float | min_p 0.05 |
```
```diff
@@ -199,6 +199,8 @@ curl http://localhost:11434/v1/embeddings \
 - [x] `seed`
 - [x] `stop`
 - [x] `stream`
+- [x] `stream_options`
+  - [x] `include_usage`
 - [x] `temperature`
 - [x] `top_p`
 - [x] `max_tokens`
@@ -227,6 +229,8 @@ curl http://localhost:11434/v1/embeddings \
 - [x] `seed`
 - [x] `stop`
 - [x] `stream`
+- [x] `stream_options`
+  - [x] `include_usage`
 - [x] `temperature`
 - [x] `top_p`
 - [x] `max_tokens`
```
```diff
@@ -1,83 +0,0 @@
-# Running Ollama on Fly.io GPU Instances
-
-Ollama runs with little to no configuration on [Fly.io GPU instances](https://fly.io/docs/gpus/gpu-quickstart/). If you don't have access to GPUs yet, you'll need to [apply for access](https://fly.io/gpu/) on the waitlist. Once you're accepted, you'll get an email with instructions on how to get started.
-
-Create a new app with `fly apps create`:
-
-```bash
-fly apps create
-```
-
-Then create a `fly.toml` file in a new folder that looks like this:
-
-```toml
-app = "sparkling-violet-709"
-primary_region = "ord"
-vm.size = "a100-40gb" # see https://fly.io/docs/gpus/gpu-quickstart/ for more info
-
-[build]
-  image = "ollama/ollama"
-
-[http_service]
-  internal_port = 11434
-  force_https = false
-  auto_stop_machines = true
-  auto_start_machines = true
-  min_machines_running = 0
-  processes = ["app"]
-
-[mounts]
-  source = "models"
-  destination = "/root/.ollama"
-  initial_size = "100gb"
-```
-
-Then create a [new private IPv6 address](https://fly.io/docs/reference/private-networking/#flycast-private-load-balancing) for your app:
-
-```bash
-fly ips allocate-v6 --private
-```
-
-Then deploy your app:
-
-```bash
-fly deploy
-```
-
-And finally you can access it interactively with a new Fly.io Machine:
-
-```
-fly machine run -e OLLAMA_HOST=http://your-app-name.flycast --shell ollama/ollama
-```
-
-```bash
-$ ollama run openchat:7b-v3.5-fp16
->>> How do I bake chocolate chip cookies?
- To bake chocolate chip cookies, follow these steps:
-
-1. Preheat the oven to 375°F (190°C) and line a baking sheet with parchment paper or silicone baking mat.
-
-2. In a large bowl, mix together 1 cup of unsalted butter (softened), 3/4 cup granulated sugar, and 3/4
-cup packed brown sugar until light and fluffy.
-
-3. Add 2 large eggs, one at a time, to the butter mixture, beating well after each addition. Stir in 1
-teaspoon of pure vanilla extract.
-
-4. In a separate bowl, whisk together 2 cups all-purpose flour, 1/2 teaspoon baking soda, and 1/2 teaspoon
-salt. Gradually add the dry ingredients to the wet ingredients, stirring until just combined.
-
-5. Fold in 2 cups of chocolate chips (or chunks) into the dough.
-
-6. Drop rounded tablespoons of dough onto the prepared baking sheet, spacing them about 2 inches apart.
-
-7. Bake for 10-12 minutes, or until the edges are golden brown. The centers should still be slightly soft.
-
-8. Allow the cookies to cool on the baking sheet for a few minutes before transferring them to a wire rack
-to cool completely.
-
-Enjoy your homemade chocolate chip cookies!
-```
-
-When you set it up like this, it will automatically turn off when you're done using it. Then when you access it again, it will automatically turn back on. This is a great way to save money on GPU instances when you're not using them. If you want a persistent wake-on-use connection to your Ollama instance, you can set up a [connection to your Fly network using WireGuard](https://fly.io/docs/reference/private-networking/#discovering-apps-through-dns-on-a-wireguard-connection). Then you can access your Ollama instance at `http://your-app-name.flycast`.
-
-And that's it!
```
```diff
@@ -1,77 +0,0 @@
-# Using LangChain with Ollama using JavaScript
-
-In this tutorial, we are going to use JavaScript with LangChain and Ollama to learn about something just a touch more recent. In August 2023, there was a series of wildfires on Maui. There is no way an LLM trained before that time can know about this, since their training data would not include anything as recent as that. So we can find the [Wikipedia article about the fires](https://en.wikipedia.org/wiki/2023_Hawaii_wildfires) and ask questions about the contents.
-
-To get started, let's just use **LangChain** to ask a simple question to a model. To do this with JavaScript, we need to install **LangChain**:
-
-```bash
-npm install @langchain/community
-```
-
-Now we can start building out our JavaScript:
-
-```javascript
-import { Ollama } from "@langchain/community/llms/ollama";
-
-const ollama = new Ollama({
-  baseUrl: "http://localhost:11434",
-  model: "llama3.2",
-});
-
-const answer = await ollama.invoke(`why is the sky blue?`);
-
-console.log(answer);
-```
-
-That will get us the same thing as if we ran `ollama run llama3.2 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's install **Cheerio** and build that part of the app.
-
-```bash
-npm install cheerio
-```
-
-```javascript
-import { CheerioWebBaseLoader } from "langchain/document_loaders/web/cheerio";
-
-const loader = new CheerioWebBaseLoader("https://en.wikipedia.org/wiki/2023_Hawaii_wildfires");
-const data = await loader.load();
-```
-
-That will load the document. Although this page is smaller than the Odyssey, it is certainly bigger than the context size for most LLMs. So we are going to need to split into smaller pieces, and then select just the pieces relevant to our question. This is a great use for a vector datastore. In this example, we will use the **MemoryVectorStore** that is part of **LangChain**. But there is one more thing we need to get the content into the datastore. We have to run an embeddings process that converts the tokens in the text into a series of vectors. And for that, we are going to use **Tensorflow**. There is a lot of stuff going on in this one. First, install the **Tensorflow** components that we need.
-
-```javascript
-npm install @tensorflow/tfjs-core@3.6.0 @tensorflow/tfjs-converter@3.6.0 @tensorflow-models/universal-sentence-encoder@1.3.3 @tensorflow/tfjs-node@4.10.0
-```
-
-If you just install those components without the version numbers, it will install the latest versions, but there are conflicts within **Tensorflow**, so you need to install the compatible versions.
-
-```javascript
-import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
-import { MemoryVectorStore } from "langchain/vectorstores/memory";
-import "@tensorflow/tfjs-node";
-import { TensorFlowEmbeddings } from "langchain/embeddings/tensorflow";
-
-// Split the text into 500 character chunks. And overlap each chunk by 20 characters
-const textSplitter = new RecursiveCharacterTextSplitter({
-  chunkSize: 500,
-  chunkOverlap: 20
-});
-const splitDocs = await textSplitter.splitDocuments(data);
-
-// Then use the TensorFlow Embedding to store these chunks in the datastore
-const vectorStore = await MemoryVectorStore.fromDocuments(splitDocs, new TensorFlowEmbeddings());
-```
-
-To connect the datastore to a question asked to a LLM, we need to use the concept at the heart of **LangChain**: the chain. Chains are a way to connect a number of activities together to accomplish a particular tasks. There are a number of chain types available, but for this tutorial we are using the **RetrievalQAChain**.
-
-```javascript
-import { RetrievalQAChain } from "langchain/chains";
-
-const retriever = vectorStore.asRetriever();
-const chain = RetrievalQAChain.fromLLM(ollama, retriever);
-const result = await chain.call({query: "When was Hawaii's request for a major disaster declaration approved?"});
-console.log(result.text)
-```
-
-So we created a retriever, which is a way to return the chunks that match a query from a datastore. And then connect the retriever and the model via a chain. Finally, we send a query to the chain, which results in an answer using our document as a source. The answer it returned was correct, August 10, 2023.
-
-And that is a simple introduction to what you can do with **LangChain** and **Ollama.**
```
```diff
@@ -1,85 +0,0 @@
-# Using LangChain with Ollama in Python
-
-Let's imagine we are studying the classics, such as **the Odyssey** by **Homer**. We might have a question about Neleus and his family. If you ask llama2 for that info, you may get something like:
-
-> I apologize, but I'm a large language model, I cannot provide information on individuals or families that do not exist in reality. Neleus is not a real person or character, and therefore does not have a family or any other personal details. My apologies for any confusion. Is there anything else I can help you with?
-
-This sounds like a typical censored response, but even llama2-uncensored gives a mediocre answer:
-
-> Neleus was a legendary king of Pylos and the father of Nestor, one of the Argonauts. His mother was Clymene, a sea nymph, while his father was Neptune, the god of the sea.
-
-So let's figure out how we can use **LangChain** with Ollama to ask our question to the actual document, the Odyssey by Homer, using Python.
-
-Let's start by asking a simple question that we can get an answer to from the **Llama3** model using **Ollama**. First, we need to install the **LangChain** package:
-
-`pip install langchain_community`
-
-Then we can create a model and ask the question:
-
-```python
-from langchain_community.llms import Ollama
-ollama = Ollama(
-    base_url='http://localhost:11434',
-    model="llama3"
-)
-print(ollama.invoke("why is the sky blue"))
-```
-
-Notice that we are defining the model and the base URL for Ollama.
-
-Now let's load a document to ask questions against. I'll load up the Odyssey by Homer, which you can find at Project Gutenberg. We will need **WebBaseLoader** which is part of **LangChain** and loads text from any webpage. On my machine, I also needed to install **bs4** to get that to work, so run `pip install bs4`.
-
-```python
-from langchain.document_loaders import WebBaseLoader
-loader = WebBaseLoader("https://www.gutenberg.org/files/1727/1727-h/1727-h.htm")
-data = loader.load()
-```
-
-This file is pretty big. Just the preface is 3000 tokens. Which means the full document won't fit into the context for the model. So we need to split it up into smaller pieces.
-
-```python
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-
-text_splitter=RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
-all_splits = text_splitter.split_documents(data)
-```
-
-It's split up, but we have to find the relevant splits and then submit those to the model. We can do this by creating embeddings and storing them in a vector database. We can use Ollama directly to instantiate an embedding model. We will use ChromaDB in this example for a vector database. `pip install chromadb`
-We also need to pull embedding model: `ollama pull nomic-embed-text`
-```python
-from langchain.embeddings import OllamaEmbeddings
-from langchain.vectorstores import Chroma
-oembed = OllamaEmbeddings(base_url="http://localhost:11434", model="nomic-embed-text")
-vectorstore = Chroma.from_documents(documents=all_splits, embedding=oembed)
-```
-
-Now let's ask a question from the document. **Who was Neleus, and who is in his family?** Neleus is a character in the Odyssey, and the answer can be found in our text.
-
-```python
-question="Who is Neleus and who is in Neleus' family?"
-docs = vectorstore.similarity_search(question)
-len(docs)
-```
-
-This will output the number of matches for chunks of data similar to the search.
-
-The next thing is to send the question and the relevant parts of the docs to the model to see if we can get a good answer. But we are stitching two parts of the process together, and that is called a chain. This means we need to define a chain:
-
-```python
-from langchain.chains import RetrievalQA
-qachain=RetrievalQA.from_chain_type(ollama, retriever=vectorstore.as_retriever())
-res = qachain.invoke({"query": question})
-print(res['result'])
-```
-
-The answer received from this chain was:
-
-> Neleus is a character in Homer's "Odyssey" and is mentioned in the context of Penelope's suitors. Neleus is the father of Chloris, who is married to Neleus and bears him several children, including Nestor, Chromius, Periclymenus, and Pero. Amphinomus, the son of Nisus, is also mentioned as a suitor of Penelope and is known for his good natural disposition and agreeable conversation.
-
-It's not a perfect answer, as it implies Neleus married his daughter when actually Chloris "was the youngest daughter to Amphion son of Iasus and king of Minyan Orchomenus, and was Queen in Pylos".
-
-I updated the chunk_overlap for the text splitter to 20 and tried again and got a much better answer:
-
-> Neleus is a character in Homer's epic poem "The Odyssey." He is the husband of Chloris, who is the youngest daughter of Amphion son of Iasus and king of Minyan Orchomenus. Neleus has several children with Chloris, including Nestor, Chromius, Periclymenus, and Pero.
-
-And that is a much better answer.
```
```diff
@@ -1,15 +0,0 @@
-# Running Ollama on NVIDIA Jetson Devices
-
-Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/) and should run out of the box with the standard installation instructions.
-
-The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack), but should also work on JetPack 6.0.
-
-- Install Ollama via standard Linux command (ignore the 404 error): `curl https://ollama.com/install.sh | sh`
-- Pull the model you want to use (e.g. mistral): `ollama pull mistral`
-- Start an interactive session: `ollama run mistral`
-
-And that's it!
-
-# Running Ollama in Docker
-
-When running GPU accelerated applications in Docker, it is highly recommended to use [dusty-nv jetson-containers repo](https://github.com/dusty-nv/jetson-containers).
```
```diff
@@ -153,6 +153,8 @@ var (
 	Debug = Bool("OLLAMA_DEBUG")
 	// FlashAttention enables the experimental flash attention feature.
 	FlashAttention = Bool("OLLAMA_FLASH_ATTENTION")
+	// KvCacheType is the quantization type for the K/V cache.
+	KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
 	// NoHistory disables readline history.
 	NoHistory = Bool("OLLAMA_NOHISTORY")
 	// NoPrune disables pruning of model blobs on startup.
@@ -234,6 +236,7 @@ func AsMap() map[string]EnvVar {
 	ret := map[string]EnvVar{
 		"OLLAMA_DEBUG":           {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
 		"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
+		"OLLAMA_KV_CACHE_TYPE":   {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
 		"OLLAMA_GPU_OVERHEAD":    {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
 		"OLLAMA_HOST":            {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
 		"OLLAMA_KEEP_ALIVE":      {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
```
```diff
@@ -1,6 +1,6 @@
 from langchain.llms import Ollama

-input = input("What is your question?")
+input = input("What is your question?\n> ")
 llm = Ollama(model="llama3.2")
-res = llm.predict(input)
+res = llm.invoke(input)
 print (res)
```
go.mod (4 changes)
```diff
@@ -7,7 +7,7 @@ require (
 	github.com/emirpasic/gods v1.18.1
 	github.com/gin-gonic/gin v1.10.0
 	github.com/golang/protobuf v1.5.4 // indirect
-	github.com/google/uuid v1.1.2
+	github.com/google/uuid v1.6.0
 	github.com/olekukonko/tablewriter v0.0.5
 	github.com/spf13/cobra v1.7.0
 	github.com/stretchr/testify v1.9.0
@@ -29,7 +29,7 @@ require (
 	github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
 	github.com/bytedance/sonic/loader v0.1.1 // indirect
 	github.com/chewxy/hm v1.0.0 // indirect
-	github.com/chewxy/math32 v1.10.1 // indirect
+	github.com/chewxy/math32 v1.11.0 // indirect
 	github.com/cloudwego/base64x v0.1.4 // indirect
 	github.com/cloudwego/iasm v0.2.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
```
13
go.sum
@@ -21,8 +21,8 @@ github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA
github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k=
github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0=
github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0=
github.com/chewxy/math32 v1.10.1 h1:LFpeY0SLJXeaiej/eIp2L40VYfscTvKh/FSEZ68uMkU=
github.com/chewxy/math32 v1.10.1/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs=
github.com/chewxy/math32 v1.11.0 h1:8sek2JWqeaKkVnHa7bPVqCEOUPbARo4SGxs6toKyAOo=
github.com/chewxy/math32 v1.11.0/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
@@ -113,8 +113,9 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
@@ -230,8 +231,6 @@ golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+o
golang.org/x/image v0.0.0-20200618115811-c13761719519/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/image v0.0.0-20201208152932-35266b937fa6/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/image v0.0.0-20210216034530-4410531fe030/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/image v0.22.0 h1:UtK5yLUzilVrkjMAZAZ34DXGpASN8i8pj8g+O+yd10g=
golang.org/x/image v0.22.0/go.mod h1:9hPFhljd4zZ1GNSIZJ49sqbp45GKK9t6w+iXvGqZUz4=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
@@ -267,8 +266,6 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ=
golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -295,8 +292,6 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -16,7 +16,6 @@ import (
	"github.com/stretchr/testify/require"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
)

func TestMaxQueue(t *testing.T) {
@@ -27,12 +26,8 @@ func TestMaxQueue(t *testing.T) {

	// Note: This test can be quite slow when running in CPU mode, so keep the threadCount low unless you're on GPU
	// Also note that by default Darwin can't sustain > ~128 connections without adjusting limits
	threadCount := 32
	if maxQueue := envconfig.MaxQueue(); maxQueue != 0 {
		threadCount = int(maxQueue)
	} else {
		t.Setenv("OLLAMA_MAX_QUEUE", strconv.Itoa(threadCount))
	}
	threadCount := 16
	t.Setenv("OLLAMA_MAX_QUEUE", strconv.Itoa(threadCount))

	req := api.GenerateRequest{
		Model: "orca-mini",
@@ -93,7 +93,7 @@ make -j

## Vendoring

Ollama currently vendors [llama.cpp](https://github.com/ggerganov/llama.cpp/) and [ggml](https://github.com/ggerganov/ggml) through a vendoring model. While we generally strive to contribute changes back upstream to avoid drift, we carry a small set of patches which are applied to the tracking commit. A set of make targets are available to aid developers in updating to a newer tracking commit, or to work on changes.

If you update the vendoring code, start by running the following command to establish the tracking llama.cpp repo in the `./vendor/` directory.

@@ -105,35 +105,35 @@ make apply-patches

**Pin to new base commit**

To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring.env`
To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring`

#### Applying patches

When updating to a newer base commit, the existing patches may not apply cleanly and require manual merge resolution.

Start by applying the patches. If any of the patches have conflicts, `git am` will stop at the first failure.

```
make apply-patches
```

If you see an error message about a conflict, go into the `./vendor/` directory and perform merge resolution using your preferred tool on the patch commit which failed. Save the file(s) and continue the patch series with `git am --continue`. If any additional patches fail, follow the same pattern until the full patch series is applied. Once finished, run a final `create-patches` and `sync` target to ensure everything is updated.

```
make create-patches sync
```

Build and test Ollama, and make any necessary changes to the Go code based on the new base commit. Submit your PR to the Ollama repo.

### Generating Patches

When working on new fixes or features that impact vendored code, use the following model. First get a clean tracking repo with all current patches applied:

```
make apply-patches
```

Now edit the upstream native code in the `./vendor/` directory. You do not need to commit every change in order to build; a dirty working tree in the tracking repo is OK while developing. Simply save in your editor, and run the following to refresh the vendored code with your changes, build the backend(s) and build ollama:

```
make sync
@@ -142,9 +142,9 @@ go build .
```

> [!IMPORTANT]
> Do **NOT** run `apply-patches` while you're iterating, as that will reset the tracking repo. It will detect a dirty tree and abort, but if your tree is clean and you accidentally ran this target, use `git reflog` to recover your commit(s).

Iterate until you're ready to submit PRs. Once your code is ready, commit a change in the `./vendor/` directory, then generate the patches for ollama with

```
make create-patches
@@ -157,4 +157,4 @@ In your `./vendor/` directory, create a branch, and cherry-pick the new commit t

Commit the changes in the ollama repo and submit a PR to Ollama, which will include the vendored code update with your change, along with the patches.

After your PR upstream is merged, follow the **Updating Base Commit** instructions above; however, first remove your patch before running `apply-patches`, since the new base commit contains your change already.
@@ -85,9 +85,12 @@ COMPILER inline get_compiler() {
import "C"

import (
	"bytes"
	_ "embed"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"runtime"
	"runtime/cgo"
	"slices"
@@ -140,7 +143,7 @@ type ContextParams struct {
	c C.struct_llama_context_params
}

func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool) ContextParams {
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
	params := C.llama_context_default_params()
	params.n_ctx = C.uint(numCtx)
	params.n_batch = C.uint(batchSize)
@@ -149,9 +152,28 @@ func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, fla
	params.n_threads_batch = params.n_threads
	params.embeddings = C.bool(true)
	params.flash_attn = C.bool(flashAttention)
	params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
	params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType))

	return ContextParams{c: params}
}

// kvCacheTypeFromStr converts a string cache type to the corresponding GGML type value
func kvCacheTypeFromStr(s string) C.enum_ggml_type {
	if s == "" {
		return C.GGML_TYPE_F16
	}

	switch s {
	case "q8_0":
		return C.GGML_TYPE_Q8_0
	case "q4_0":
		return C.GGML_TYPE_Q4_0
	default:
		return C.GGML_TYPE_F16
	}
}

type Context struct {
	c          *C.struct_llama_context
	numThreads int
@@ -680,3 +702,33 @@ func (s *SamplingContext) Sample(llamaContext *Context, idx int) int {
func (s *SamplingContext) Accept(id int, applyGrammar bool) {
	C.gpt_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
}

type JsonSchema struct {
	Defs       map[string]any `json:"$defs,omitempty"`
	Properties map[string]any `json:"properties,omitempty"`
	Required   []string       `json:"required,omitempty"`
	Title      string         `json:"title,omitempty"`
	Type       string         `json:"type,omitempty"`
}

func (js JsonSchema) AsGrammar() string {
	var b bytes.Buffer
	if err := json.NewEncoder(&b).Encode(js); err != nil {
		return ""
	}

	cStr := C.CString(b.String())
	defer C.free(unsafe.Pointer(cStr))

	// Allocate buffer for grammar output with reasonable size
	const maxLen = 32768 // 32KB
	buf := make([]byte, maxLen)

	// Call C function to convert schema to grammar
	length := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
	if length == 0 {
		slog.Warn("unable to convert schema to grammar")
	}

	return string(buf[:length])
}
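
To see how the new `JsonSchema`/`AsGrammar` API is meant to be used, a sketch. The schema is illustrative, and this requires the cgo build of the `llama` package:

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/llama"
)

func main() {
	// Illustrative schema: an object with a required "name" string.
	schema := llama.JsonSchema{
		Type:     "object",
		Required: []string{"name"},
		Properties: map[string]any{
			"name": map[string]any{"type": "string"},
		},
	}

	grammar := schema.AsGrammar() // GBNF text, or "" if conversion failed
	if grammar == "" {
		fmt.Println("schema could not be converted; falling back to unconstrained sampling")
		return
	}
	fmt.Println(grammar)
}
```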
@@ -1 +1,70 @@
package llama

import (
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
)

func TestJsonSchema(t *testing.T) {
	testCases := []struct {
		name     string
		schema   JsonSchema
		expected string
	}{
		{
			name: "empty schema",
			schema: JsonSchema{
				Type: "object",
			},
			expected: `array ::= "[" space ( value ("," space value)* )? "]" space
boolean ::= ("true" | "false") space
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
null ::= "null" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
root ::= object
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
value ::= object | array | string | number | boolean | null`,
		},
		{
			name: "invalid schema with circular reference",
			schema: JsonSchema{
				Type: "object",
				Properties: map[string]any{
					"self": map[string]any{
						"$ref": "#", // Self reference
					},
				},
			},
			expected: "", // Should return empty string for invalid schema
		},
		{
			name: "schema with invalid type",
			schema: JsonSchema{
				Type: "invalid_type", // Invalid type
				Properties: map[string]any{
					"foo": map[string]any{
						"type": "string",
					},
				},
			},
			expected: "", // Should return empty string for invalid schema
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			result := tc.schema.AsGrammar()
			if !strings.EqualFold(strings.TrimSpace(result), strings.TrimSpace(tc.expected)) {
				if diff := cmp.Diff(tc.expected, result); diff != "" {
					t.Fatalf("grammar mismatch (-want +got):\n%s", diff)
				}
			}
		})
	}
}
@@ -199,6 +199,20 @@ func countCommonPrefix(a []input, b []input) int {
	return count
}

func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
	targetFree := (c.numCtx - numKeep) / 2
	targetFree = max(targetFree, 1)

	currentFree := c.numCtx - inputLen
	discard := targetFree - currentFree

	if discard < 0 {
		discard = 0
	}

	return discard
}

// Frees up space in the KV cache by deleting the oldest half of history and shifting
// the newest half into that space (saving numKeep inputs at the beginning).
//
@@ -208,11 +222,7 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
		return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
	}

	targetFree := (c.numCtx - numKeep) / 2
	targetFree = max(targetFree, 1)

	currentFree := c.numCtx - len(slot.Inputs)
	discard := targetFree - currentFree
	discard := c.ShiftDiscard(len(slot.Inputs), numKeep)

	if discard <= 0 {
		return nil
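
Worked by hand, the `Shift` case from the test table that follows (a runnable sketch; the values come straight from that table):

```go
package main

import "fmt"

func main() {
	// ShiftDiscard for numCtx=2048, numKeep=5, inputLen=2048 (Go 1.21+ for builtin max).
	numCtx, numKeep, inputLen := 2048, 5, 2048
	targetFree := max((numCtx-numKeep)/2, 1)  // (2048-5)/2 = 1021
	currentFree := numCtx - inputLen          // 2048-2048 = 0
	discard := max(targetFree-currentFree, 0) // 1021
	fmt.Println(discard)                      // matches the "Shift" test case
}
```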
@@ -227,3 +227,66 @@ func TestFindCacheSlot(t *testing.T) {
		})
	}
}

func TestShiftDiscard(t *testing.T) {
	tests := []struct {
		name     string
		numCtx   int
		numKeep  int
		inputLen int
		expected int
	}{
		{
			name:     "Shift",
			numCtx:   2048,
			numKeep:  5,
			inputLen: 2048,
			expected: 1021,
		},
		{
			name:     "Max Keep",
			numCtx:   2048,
			numKeep:  2047,
			inputLen: 2048,
			expected: 1,
		},
		{
			name:     "No Keep",
			numCtx:   2048,
			numKeep:  0,
			inputLen: 2048,
			expected: 1024,
		},
		{
			name:     "Truncate",
			numCtx:   2048,
			numKeep:  5,
			inputLen: 5000,
			expected: 3973,
		},
		{
			name:     "Truncate Keep",
			numCtx:   2048,
			numKeep:  2047,
			inputLen: 5000,
			expected: 2953,
		},
		{
			name:     "No Op",
			numCtx:   2048,
			numKeep:  5,
			inputLen: 512,
			expected: 0,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := InputCache{numCtx: tt.numCtx}
			result := c.ShiftDiscard(tt.inputLen, tt.numKeep)
			if result != tt.expected {
				t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected)
			}
		})
	}
}
@@ -122,9 +122,11 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
	params.numKeep = min(params.numKeep, s.cache.numCtx-1)

	if len(inputs) > s.cache.numCtx {
		slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "numKeep", params.numKeep)
		discard := len(inputs) - s.cache.numCtx
		newInputs := inputs[:params.numKeep]
		newInputs = append(newInputs, inputs[len(inputs)-s.cache.numCtx+params.numKeep:]...)
		newInputs = append(newInputs, inputs[params.numKeep+discard:]...)

		slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
		inputs = newInputs
	}

@@ -162,10 +164,16 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// generating image embeddings for each image
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
	var inputs []input
	var parts []string
	var matches [][]string

	re := regexp.MustCompile(`\[img-(\d+)\]`)
	parts := re.Split(prompt, -1)
	matches := re.FindAllStringSubmatch(prompt, -1)
	if s.image != nil {
		re := regexp.MustCompile(`\[img-(\d+)\]`)
		parts = re.Split(prompt, -1)
		matches = re.FindAllStringSubmatch(prompt, -1)
	} else {
		parts = []string{prompt}
	}

	for i, part := range parts {
		// text - tokenize
@@ -300,6 +308,7 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
	close(seq.embedding)
	seq.cache.InUse = false
	s.seqs[seqIndex] = nil
	s.seqsSem.Release(1)
}

func (s *Server) run(ctx context.Context) {
@@ -649,14 +658,18 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
		return
	}

	// Ensure that a place to put the sequence is available
	// Ensure there is a place to put the sequence, released when removed from s.seqs
	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
		slog.Error("Failed to acquire semaphore", "error", err)
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting completion request due to client closing the connection")
		} else {
			slog.Error("Failed to acquire semaphore", "error", err)
		}
		return
	}
	defer s.seqsSem.Release(1)

	s.mu.Lock()
	found := false
	for i, sq := range s.seqs {
		if sq == nil {
			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
@@ -670,11 +683,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {

			s.seqs[i] = seq
			s.cond.Signal()
			found = true
			break
		}
	}
	s.mu.Unlock()

	if !found {
		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
		return
	}

	for {
		select {
		case <-r.Context().Done():
@@ -738,14 +757,18 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
		return
	}

	// Ensure that a place to put the sequence is available
	// Ensure there is a place to put the sequence, released when removed from s.seqs
	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
		slog.Error("Failed to acquire semaphore", "error", err)
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting embeddings request due to client closing the connection")
		} else {
			slog.Error("Failed to acquire semaphore", "error", err)
		}
		return
	}
	defer s.seqsSem.Release(1)

	s.mu.Lock()
	found := false
	for i, sq := range s.seqs {
		if sq == nil {
			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
@@ -756,11 +779,17 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
			}
			s.seqs[i] = seq
			s.cond.Signal()
			found = true
			break
		}
	}
	s.mu.Unlock()

	if !found {
		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
		return
	}

	embedding := <-seq.embedding

	if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
@@ -804,12 +833,24 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
	}
}

type multiLPath []string

func (m *multiLPath) Set(value string) error {
	*m = append(*m, value)
	return nil
}

func (m *multiLPath) String() string {
	return strings.Join(*m, ", ")
}

func (s *Server) loadModel(
	params llama.ModelParams,
	mpath string,
	lpath string,
	lpath multiLPath,
	ppath string,
	kvSize int,
	kvCacheType string,
	flashAttention bool,
	threads int,
	multiUserCache bool,
@@ -822,16 +863,18 @@ func (s *Server) loadModel(
		panic(err)
	}

	ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention)
	ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, kvCacheType)
	s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
	if err != nil {
		panic(err)
	}

	if lpath != "" {
		err := s.model.ApplyLoraFromFile(s.lc, lpath, 1.0, threads)
		if err != nil {
			panic(err)
	if lpath.String() != "" {
		for _, path := range lpath {
			err := s.model.ApplyLoraFromFile(s.lc, path, 1.0, threads)
			if err != nil {
				panic(err)
			}
		}
	}

@@ -861,7 +904,7 @@ func main() {
	mainGpu := flag.Int("main-gpu", 0, "Main GPU")
	flashAttention := flag.Bool("flash-attn", false, "Enable flash attention")
	kvSize := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
	lpath := flag.String("lora", "", "Path to lora layer file")
	kvCacheType := flag.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
	port := flag.Int("port", 8080, "Port to expose the server on")
	threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
	verbose := flag.Bool("verbose", false, "verbose output (default: disabled)")
@@ -871,6 +914,9 @@ func main() {
	multiUserCache := flag.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
	requirements := flag.Bool("requirements", false, "print json requirement information")

	var lpaths multiLPath
	flag.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)")

	flag.Parse()
	if *requirements {
		printRequirements(os.Stdout)
@@ -917,7 +963,7 @@ func main() {
	params := llama.ModelParams{
		NumGpuLayers: *nGpuLayers,
		MainGpu:      *mainGpu,
		UseMmap:      !*noMmap && *lpath == "",
		UseMmap:      !*noMmap && lpaths.String() == "",
		UseMlock:     *mlock,
		TensorSplit:  tensorSplitFloats,
		Progress: func(progress float32) {
@@ -926,7 +972,7 @@ func main() {
	}

	server.ready.Add(1)
	go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
	go server.loadModel(params, *mpath, lpaths, *ppath, *kvSize, *kvCacheType, *flashAttention, *threads, *multiUserCache)

	server.cond = sync.NewCond(&server.mu)
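
The `multiLPath` type above is the standard `flag.Value` mechanism for a repeatable flag. A minimal self-contained sketch of the same pattern:

```go
package main

import (
	"flag"
	"fmt"
	"strings"
)

// multiPath collects every occurrence of a repeated flag.
type multiPath []string

func (m *multiPath) Set(v string) error { *m = append(*m, v); return nil }
func (m *multiPath) String() string     { return strings.Join(*m, ", ") }

func main() {
	var lpaths multiPath
	flag.Var(&lpaths, "lora", "Path to lora layer file (repeatable)")
	flag.Parse() // e.g. -lora a.gguf -lora b.gguf
	fmt.Println(lpaths)
}
```

Invoked as `prog -lora a.gguf -lora b.gguf`, both paths accumulate in order, which is what lets `loadModel` apply several adapters in sequence.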
29
llama/sampling_ext.cpp
vendored
@@ -1,11 +1,13 @@
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
#include "sampling.h"
#include "sampling_ext.h"
#include "json-schema-to-grammar.h"

struct gpt_sampler *gpt_sampler_cinit(
    const struct llama_model *model, struct gpt_sampler_cparams *params)
{
    try {
    try
    {
        gpt_sampler_params sparams;
        sparams.top_k = params->top_k;
        sparams.top_p = params->top_p;
@@ -24,7 +26,9 @@ struct gpt_sampler *gpt_sampler_cinit(
        sparams.seed = params->seed;
        sparams.grammar = params->grammar;
        return gpt_sampler_init(model, sparams);
    } catch (const std::exception & err) {
    }
    catch (const std::exception &err)
    {
        return nullptr;
    }
}
@@ -54,3 +58,24 @@ void gpt_sampler_caccept(
{
    gpt_sampler_accept(sampler, id, apply_grammar);
}

int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len)
{
    try
    {
        nlohmann::json schema = nlohmann::json::parse(json_schema);
        std::string grammar_str = json_schema_to_grammar(schema);
        size_t len = grammar_str.length();
        if (len >= max_len)
        {
            len = max_len - 1;
        }
        strncpy(grammar, grammar_str.c_str(), len);
        return len;
    }
    catch (const std::exception &e)
    {
        strncpy(grammar, "", max_len - 1);
        return 0;
    }
}
2
llama/sampling_ext.h
vendored
@@ -47,6 +47,8 @@ extern "C"
        llama_token id,
        bool apply_grammar);

    int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);

#ifdef __cplusplus
}
#endif
@@ -32,9 +32,10 @@ const (
	fileTypeIQ1_S
	fileTypeIQ4_NL
	fileTypeIQ3_S
	fileTypeIQ3_M
	fileTypeIQ2_S
	fileTypeIQ4_XS
	fileTypeIQ2_M
	fileTypeIQ4_XS
	fileTypeIQ1_M
	fileTypeBF16

@@ -93,6 +94,8 @@ func ParseFileType(s string) (fileType, error) {
		return fileTypeIQ4_NL, nil
	case "IQ3_S":
		return fileTypeIQ3_S, nil
	case "IQ3_M":
		return fileTypeIQ3_M, nil
	case "IQ2_S":
		return fileTypeIQ2_S, nil
	case "IQ4_XS":
@@ -160,6 +163,8 @@ func (t fileType) String() string {
		return "IQ4_NL"
	case fileTypeIQ3_S:
		return "IQ3_S"
	case fileTypeIQ3_M:
		return "IQ3_M"
	case fileTypeIQ2_S:
		return "IQ2_S"
	case fileTypeIQ4_XS:
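
Context on why the ordering above matters: these constants are `iota`-based, so each name's value is its position in the list. Adding `fileTypeIQ3_M` mid-list and reordering `fileTypeIQ2_M`/`fileTypeIQ4_XS` keeps the numeric values aligned (presumably with the upstream llama.cpp/GGUF file-type IDs; that mapping is not shown in this diff). A tiny sketch of the hazard:

```go
package main

import "fmt"

const (
	a = iota // 0
	b        // 1
	c        // 2
)

const (
	a2       = iota // 0
	inserted        // 1: inserting here...
	b2              // 2: ...silently renumbers everything below
	c2              // 3
)

func main() {
	fmt.Println(a, b, c)    // 0 1 2
	fmt.Println(a2, b2, c2) // 0 2 3
}
```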
36
llm/ggml.go
@@ -360,7 +360,7 @@ func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
	}, offset, nil
}

func (llm GGML) GraphSize(context, batch uint64) (kv, partialOffload, fullOffload uint64) {
func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
	embedding := llm.KV().EmbeddingLength()
	heads := llm.KV().HeadCount()
	headsKV := llm.KV().HeadCountKV()
@@ -372,7 +372,8 @@ func (llm GGML) GraphSize(context, batch uint64) (kv, partialOffload, fullOffloa

	layers := llm.Tensors().Layers()

	kv = 2 * context * llm.KV().BlockCount() * (embeddingHeadsK + embeddingHeadsV) * headsKV
	bytesPerElement := kvCacheBytesPerElement(kvCacheType)
	kv = uint64(float64(context*llm.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)

	switch llm.KV().Architecture() {
	case "llama":
@@ -527,3 +528,34 @@ func (llm GGML) GraphSize(context, batch uint64) (kv, partialOffload, fullOffloa

	return
}

// SupportsKVCacheType checks if the requested cache type is supported
func (ggml GGML) SupportsKVCacheType(cacheType string) bool {
	validKVCacheTypes := []string{"f16", "q8_0", "q4_0"}
	return slices.Contains(validKVCacheTypes, cacheType)
}

// SupportsFlashAttention checks if the model supports flash attention
func (ggml GGML) SupportsFlashAttention() bool {
	_, isEmbedding := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]
	if isEmbedding {
		return false
	}

	// Check head counts match and are non-zero
	headCountK := ggml.KV().EmbeddingHeadCountK()
	headCountV := ggml.KV().EmbeddingHeadCountV()
	return headCountK != 0 && headCountV != 0 && headCountK == headCountV
}

// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
func kvCacheBytesPerElement(cacheType string) float64 {
	switch cacheType {
	case "q8_0":
		return 1 // 1/2 of fp16
	case "q4_0":
		return 0.5 // 1/4 of fp16
	default:
		return 2 // f16 (default)
	}
}
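
What `kvCacheBytesPerElement` buys in practice, worked through for a hypothetical model. The dimensions below are illustrative assumptions, not values from this diff:

```go
package main

import "fmt"

func main() {
	// Assumed model: 32 blocks, 8 KV heads, 128-dim K and V heads, 2048 context.
	context, blocks, headsKV := 2048.0, 32.0, 8.0
	embeddingHeadsK, embeddingHeadsV := 128.0, 128.0
	elements := context * blocks * (embeddingHeadsK + embeddingHeadsV) * headsKV

	for _, t := range []struct {
		name  string
		bytes float64 // kvCacheBytesPerElement
	}{{"f16", 2}, {"q8_0", 1}, {"q4_0", 0.5}} {
		fmt.Printf("%s: %.0f MiB\n", t.name, elements*t.bytes/(1<<20))
	}
	// f16: 256 MiB, q8_0: 128 MiB, q4_0: 64 MiB
}
```

So on this assumed model, q8_0 halves and q4_0 quarters the KV cache footprint relative to the f16 default, which is exactly the budget `EstimateGPULayers` now accounts for.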
@@ -123,7 +123,23 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string,
		slog.Warn("model missing blk.0 layer size")
	}

	kv, graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
	fa := envconfig.FlashAttention() &&
		discover.GetGPUInfo().FlashAttentionSupported() &&
		ggml.SupportsFlashAttention()

	var kvct string
	if fa {
		requested := strings.ToLower(envconfig.KvCacheType())
		if requested != "" && ggml.SupportsKVCacheType(requested) {
			kvct = requested
		}
	}

	kv, graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), kvct)

	// KV is proportional to the number of layers
	layerSize += kv / ggml.KV().BlockCount()

	if graphPartialOffload == 0 {
		graphPartialOffload = ggml.KV().GQA() * kv / 6
	}
@@ -131,9 +147,6 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string,
		graphFullOffload = graphPartialOffload
	}

	// KV is proportional to the number of layers
	layerSize += kv / ggml.KV().BlockCount()

	// on metal there's no partial offload overhead
	if gpus[0].Library == "metal" {
		graphPartialOffload = graphFullOffload

@@ -15,6 +15,7 @@ import (

func TestEstimateGPULayers(t *testing.T) {
	t.Setenv("OLLAMA_DEBUG", "1")
	t.Setenv("OLLAMA_KV_CACHE_TYPE", "") // Ensure default f16

	modelName := "dummy"
	f, err := os.CreateTemp(t.TempDir(), modelName)
@@ -144,10 +144,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
	// Loop through potential servers
	finalErr := errors.New("no suitable llama servers found")

	if len(adapters) > 1 {
		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
	}

	rDir, err := runners.Refresh(build.EmbedFS)
	if err != nil {
		return nil, err
@@ -201,8 +197,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
	}

	if len(adapters) > 0 {
		// TODO: applying multiple adapters is not supported by the llama.cpp server yet
		params = append(params, "--lora", adapters[0])
		for _, adapter := range adapters {
			params = append(params, "--lora", adapter)
		}
	}

	if len(projectors) > 0 {
@@ -217,15 +214,36 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
		params = append(params, "--threads", strconv.Itoa(defaultThreads))
	}

	flashAttnEnabled := envconfig.FlashAttention()
	fa := envconfig.FlashAttention()
	if fa && !gpus.FlashAttentionSupported() {
		slog.Warn("flash attention enabled but not supported by gpu")
		fa = false
	}

	for _, g := range gpus {
		// only cuda (compute capability 7+) and metal support flash attention
		if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
			flashAttnEnabled = false
	if fa && !ggml.SupportsFlashAttention() {
		slog.Warn("flash attention enabled but not supported by model")
		fa = false
	}

	kvct := strings.ToLower(envconfig.KvCacheType())

	if fa {
		slog.Info("enabling flash attention")
		params = append(params, "--flash-attn")

		// Flash Attention also supports kv cache quantization
		// Enable it if the requested kv cache type is supported by the model
		if kvct != "" && ggml.SupportsKVCacheType(kvct) {
			params = append(params, "--kv-cache-type", kvct)
		} else {
			slog.Warn("kv cache type not supported by model", "type", kvct)
		}
	} else if kvct != "" && kvct != "f16" {
		slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
	}

	// mmap has issues with partial offloading on metal
	for _, g := range gpus {
		if g.Library == "metal" &&
			uint64(opts.NumGPU) > 0 &&
			uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
@@ -234,10 +252,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
		}
	}

	if flashAttnEnabled {
		params = append(params, "--flash-attn")
	}

	// Windows CUDA should not use mmap for best performance
	// Linux with a model larger than free space, mmap leads to thrashing
	// For CPU loads we want the memory to be allocated, not FS cache
@@ -620,27 +634,22 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
const jsonGrammar = `
root   ::= object
value  ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
    string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array ::=
  "[" ws (
    value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\\x7F\x00-\x1F] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
`
@@ -670,7 +679,7 @@ type completion struct {

type CompletionRequest struct {
	Prompt  string
	Format  string
	Format  json.RawMessage
	Images  []ImageData
	Options *api.Options
}
@@ -687,7 +696,11 @@ type CompletionResponse struct {

func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
	if err := s.sem.Acquire(ctx, 1); err != nil {
		slog.Error("Failed to acquire semaphore", "error", err)
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting completion request due to client closing the connection")
		} else {
			slog.Error("Failed to acquire semaphore", "error", err)
		}
		return err
	}
	defer s.sem.Release(1)
@@ -731,10 +744,22 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
		return fmt.Errorf("unexpected server status: %s", status.ToString())
	}

	if req.Format == "json" {
		request["grammar"] = jsonGrammar
		if !strings.Contains(strings.ToLower(req.Prompt), "json") {
			slog.Warn("Prompt does not specify that the LLM should respond in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
	// TODO (parthsareen): Move conversion to grammar with sampling logic
	// API should do error handling for invalid formats
	if req.Format != nil {
		if strings.ToLower(strings.TrimSpace(string(req.Format))) == `"json"` {
			request["grammar"] = jsonGrammar
			if !strings.Contains(strings.ToLower(req.Prompt), "json") {
				slog.Warn("prompt does not specify that the LLM should respond in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
			}
		} else if schema, err := func() (llama.JsonSchema, error) {
			var schema llama.JsonSchema
			err := json.Unmarshal(req.Format, &schema)
			return schema, err
		}(); err == nil {
			request["grammar"] = schema.AsGrammar()
		} else {
			slog.Warn(`format is neither a schema nor "json"`, "format", req.Format)
		}
	}

@@ -865,7 +890,11 @@ type EmbeddingResponse struct {

func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
	if err := s.sem.Acquire(ctx, 1); err != nil {
		slog.Error("Failed to acquire semaphore", "error", err)
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting embedding request due to client closing the connection")
		} else {
			slog.Error("Failed to acquire semaphore", "error", err)
		}
		return nil, err
	}
	defer s.sem.Release(1)
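
The `Format json.RawMessage` change means the field now carries three possible shapes; a self-contained sketch (the schema payload is illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	var unconstrained json.RawMessage // nil: no grammar is applied

	// Legacy mode: the literal string "json" selects the generic jsonGrammar.
	jsonMode := json.RawMessage(`"json"`)

	// Schema mode: a JSON schema object is converted to a grammar via AsGrammar.
	schema := json.RawMessage(`{"type":"object","required":["name"],"properties":{"name":{"type":"string"}}}`)

	fmt.Println(unconstrained == nil, string(jsonMode), string(schema))
}
```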
168
openai/openai.go
@@ -62,7 +62,12 @@ type Usage struct {
}

type ResponseFormat struct {
	Type       string      `json:"type"`
	JsonSchema *JsonSchema `json:"json_schema,omitempty"`
}

type JsonSchema struct {
	Schema map[string]any `json:"schema"`
}

type EmbedRequest struct {
@@ -70,10 +75,15 @@ type EmbedRequest struct {
	Model string `json:"model"`
}

type StreamOptions struct {
	IncludeUsage bool `json:"include_usage"`
}

type ChatCompletionRequest struct {
	Model         string         `json:"model"`
	Messages      []Message      `json:"messages"`
	Stream        bool           `json:"stream"`
	StreamOptions *StreamOptions `json:"stream_options"`
	MaxTokens     *int           `json:"max_tokens"`
	Seed          *int           `json:"seed"`
	Stop          any            `json:"stop"`
@@ -102,21 +112,23 @@ type ChatCompletionChunk struct {
	Model             string        `json:"model"`
	SystemFingerprint string        `json:"system_fingerprint"`
	Choices           []ChunkChoice `json:"choices"`
	Usage             *Usage        `json:"usage,omitempty"`
}

// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
type CompletionRequest struct {
	Model            string   `json:"model"`
	Prompt           string   `json:"prompt"`
	FrequencyPenalty float32  `json:"frequency_penalty"`
	MaxTokens        *int     `json:"max_tokens"`
	PresencePenalty  float32  `json:"presence_penalty"`
	Seed             *int     `json:"seed"`
	Stop             any      `json:"stop"`
	Stream           bool     `json:"stream"`
	Temperature      *float32 `json:"temperature"`
	TopP             float32  `json:"top_p"`
	Suffix           string   `json:"suffix"`
	Model            string         `json:"model"`
	Prompt           string         `json:"prompt"`
	FrequencyPenalty float32        `json:"frequency_penalty"`
	MaxTokens        *int           `json:"max_tokens"`
	PresencePenalty  float32        `json:"presence_penalty"`
	Seed             *int           `json:"seed"`
	Stop             any            `json:"stop"`
	Stream           bool           `json:"stream"`
	StreamOptions    *StreamOptions `json:"stream_options"`
	Temperature      *float32       `json:"temperature"`
	TopP             float32        `json:"top_p"`
	Suffix           string         `json:"suffix"`
}

type Completion struct {
@@ -136,10 +148,12 @@ type CompletionChunk struct {
	Choices           []CompleteChunkChoice `json:"choices"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Usage             *Usage                `json:"usage,omitempty"`
}

type ToolCall struct {
	ID       string `json:"id"`
	Index    int    `json:"index"`
	Type     string `json:"type"`
	Function struct {
		Name string `json:"name"`
@@ -191,6 +205,14 @@ func NewError(code int, message string) ErrorResponse {
	return ErrorResponse{Error{Type: etype, Message: message}}
}

func toUsage(r api.ChatResponse) Usage {
	return Usage{
		PromptTokens:     r.PromptEvalCount,
		CompletionTokens: r.EvalCount,
		TotalTokens:      r.PromptEvalCount + r.EvalCount,
	}
}

func toolCallId() string {
	const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
	b := make([]byte, 8)
@@ -200,12 +222,13 @@ func toolCallId() string {
	return "call_" + strings.ToLower(string(b))
}

func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
	toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
	for i, tc := range r.Message.ToolCalls {
func toToolCalls(tc []api.ToolCall) []ToolCall {
	toolCalls := make([]ToolCall, len(tc))
	for i, tc := range tc {
		toolCalls[i].ID = toolCallId()
		toolCalls[i].Type = "function"
		toolCalls[i].Function.Name = tc.Function.Name
		toolCalls[i].Index = tc.Function.Index

		args, err := json.Marshal(tc.Function.Arguments)
		if err != nil {
@@ -215,7 +238,11 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {

		toolCalls[i].Function.Arguments = string(args)
	}
	return toolCalls
}

func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
	toolCalls := toToolCalls(r.Message.ToolCalls)
	return ChatCompletion{
		Id:     id,
		Object: "chat.completion",
@@ -235,15 +262,12 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
			return nil
		}(r.DoneReason),
		}},
		Usage: Usage{
			PromptTokens:     r.PromptEvalCount,
			CompletionTokens: r.EvalCount,
			TotalTokens:      r.PromptEvalCount + r.EvalCount,
		},
		Usage: toUsage(r),
	}
}

func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
	toolCalls := toToolCalls(r.Message.ToolCalls)
	return ChatCompletionChunk{
		Id:     id,
		Object: "chat.completion.chunk",
@@ -252,7 +276,7 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
		SystemFingerprint: "fp_ollama",
		Choices: []ChunkChoice{{
			Index: 0,
			Delta: Message{Role: "assistant", Content: r.Message.Content},
			Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					return &reason
@@ -263,6 +287,14 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
	}
}

func toUsageGenerate(r api.GenerateResponse) Usage {
	return Usage{
		PromptTokens:     r.PromptEvalCount,
		CompletionTokens: r.EvalCount,
		TotalTokens:      r.PromptEvalCount + r.EvalCount,
	}
}

func toCompletion(id string, r api.GenerateResponse) Completion {
	return Completion{
		Id: id,
@@ -280,11 +312,7 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
			return nil
		}(r.DoneReason),
		}},
		Usage: Usage{
			PromptTokens:     r.PromptEvalCount,
			CompletionTokens: r.EvalCount,
			TotalTokens:      r.PromptEvalCount + r.EvalCount,
		},
		Usage: toUsageGenerate(r),
	}
}

@@ -475,9 +503,21 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
		options["top_p"] = 1.0
	}

	var format string
	if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
		format = "json"
	var format json.RawMessage
	if r.ResponseFormat != nil {
		switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
		// Support the old "json_object" type for OpenAI compatibility
		case "json_object":
			format = json.RawMessage(`"json"`)
		case "json_schema":
			if r.ResponseFormat.JsonSchema != nil {
				schema, err := json.Marshal(r.ResponseFormat.JsonSchema.Schema)
				if err != nil {
					return nil, fmt.Errorf("failed to marshal json schema: %w", err)
				}
				format = schema
			}
		}
	}

	return &api.ChatRequest{
@@ -546,14 +586,16 @@ type BaseWriter struct {
}

type ChatWriter struct {
	stream bool
	id     string
	stream        bool
	streamOptions *StreamOptions
	id            string
	BaseWriter
}

type CompleteWriter struct {
	stream bool
	id     string
	stream        bool
	streamOptions *StreamOptions
	id            string
	BaseWriter
}

@@ -571,7 +613,7 @@ type EmbedWriter struct {
	model string
}

func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
func (w *BaseWriter) writeError(data []byte) (int, error) {
	var serr api.StatusError
	err := json.Unmarshal(data, &serr)
	if err != nil {
@@ -596,7 +638,11 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {

	// chat chunk
	if w.stream {
		d, err := json.Marshal(toChunk(w.id, chatResponse))
		c := toChunk(w.id, chatResponse)
		if w.streamOptions != nil && w.streamOptions.IncludeUsage {
			c.Usage = &Usage{}
		}
		d, err := json.Marshal(c)
		if err != nil {
			return 0, err
		}
@@ -608,6 +654,17 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
		}

		if chatResponse.Done {
			if w.streamOptions != nil && w.streamOptions.IncludeUsage {
				u := toUsage(chatResponse)
				d, err := json.Marshal(ChatCompletionChunk{Choices: []ChunkChoice{}, Usage: &u})
				if err != nil {
					return 0, err
				}
				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
				if err != nil {
					return 0, err
				}
			}
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
@@ -630,7 +687,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
func (w *ChatWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(code, data)
		return w.writeError(data)
	}

	return w.writeResponse(data)
@@ -645,7 +702,11 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {

	// completion chunk
	if w.stream {
		d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
		c := toCompleteChunk(w.id, generateResponse)
		if w.streamOptions != nil && w.streamOptions.IncludeUsage {
			c.Usage = &Usage{}
		}
		d, err := json.Marshal(c)
		if err != nil {
			return 0, err
		}
@@ -657,6 +718,17 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
		}

		if generateResponse.Done {
			if w.streamOptions != nil && w.streamOptions.IncludeUsage {
				u := toUsageGenerate(generateResponse)
				d, err := json.Marshal(CompletionChunk{Choices: []CompleteChunkChoice{}, Usage: &u})
				if err != nil {
					return 0, err
				}
				_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
				if err != nil {
					return 0, err
				}
			}
			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			if err != nil {
				return 0, err
@@ -679,7 +751,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
func (w *CompleteWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(code, data)
		return w.writeError(data)
	}

	return w.writeResponse(data)
@@ -704,7 +776,7 @@ func (w *ListWriter) writeResponse(data []byte) (int, error) {
func (w *ListWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(code, data)
		return w.writeError(data)
	}

	return w.writeResponse(data)
@@ -730,7 +802,7 @@ func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
func (w *RetrieveWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(code, data)
		return w.writeError(data)
	}

	return w.writeResponse(data)
@@ -755,7 +827,7 @@ func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
func (w *EmbedWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(code, data)
		return w.writeError(data)
	}

	return w.writeResponse(data)
@@ -819,9 +891,10 @@ func CompletionsMiddleware() gin.HandlerFunc {
			c.Request.Body = io.NopCloser(&b)

			w := &CompleteWriter{
				BaseWriter:    BaseWriter{ResponseWriter: c.Writer},
				stream:        req.Stream,
				id:            fmt.Sprintf("cmpl-%d", rand.Intn(999)),
				streamOptions: req.StreamOptions,
			}

			c.Writer = w
@@ -901,9 +974,10 @@ func ChatMiddleware() gin.HandlerFunc {
			c.Request.Body = io.NopCloser(&b)

			w := &ChatWriter{
				BaseWriter:    BaseWriter{ResponseWriter: c.Writer},
				stream:        req.Stream,
				id:            fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
				streamOptions: req.StreamOptions,
			}

			c.Writer = w
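
How the `stream_options` plumbing above plays out on the wire, as a self-contained sketch. The `Usage` JSON tags are assumed to follow the OpenAI field names, which this diff does not show:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local stand-ins mirroring the structs in this diff (sketch only).
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage"`
}

type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

func main() {
	opts := &StreamOptions{IncludeUsage: true} // from "stream_options":{"include_usage":true}

	if opts != nil && opts.IncludeUsage {
		// After the last content chunk, the writer emits one usage-only chunk
		// (empty choices) before the terminating [DONE] line.
		u := Usage{PromptTokens: 10, CompletionTokens: 20, TotalTokens: 30}
		tail, _ := json.Marshal(map[string]any{"choices": []any{}, "usage": u})
		fmt.Printf("data: %s\n\n", tail)
	}
	fmt.Print("data: [DONE]\n\n")
}
```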
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
@@ -107,7 +108,46 @@ func TestChatMiddleware(t *testing.T) {
|
||||
"presence_penalty": 5.0,
|
||||
"top_p": 6.0,
|
||||
},
|
||||
Format: "json",
|
||||
Format: json.RawMessage(`"json"`),
|
||||
Stream: &True,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with streaming usage",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
"stream": true,
|
||||
"stream_options": {"include_usage": true},
|
||||
"max_tokens": 999,
|
||||
"seed": 123,
|
||||
"stop": ["\n", "stop"],
|
||||
"temperature": 3.0,
|
||||
"frequency_penalty": 4.0,
|
||||
"presence_penalty": 5.0,
|
||||
"top_p": 6.0,
|
||||
"response_format": {"type": "json_object"}
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Hello",
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"num_predict": 999.0, // float because JSON doesn't distinguish between float and int
|
||||
"seed": 123.0,
|
||||
"stop": []any{"\n", "stop"},
|
||||
"temperature": 3.0,
|
||||
"frequency_penalty": 4.0,
|
||||
"presence_penalty": 5.0,
|
||||
"top_p": 6.0,
|
||||
},
|
||||
Format: json.RawMessage(`"json"`),
|
||||
Stream: &True,
|
||||
},
|
||||
},
|
||||
@@ -195,7 +235,86 @@ func TestChatMiddleware(t *testing.T) {
 			Stream: &False,
 		},
 	},
 
+	{
+		name: "chat handler with streaming tools",
+		body: `{
+			"model": "test-model",
+			"messages": [
+				{"role": "user", "content": "What's the weather like in Paris?"}
+			],
+			"stream": true,
+			"tools": [{
+				"type": "function",
+				"function": {
+					"name": "get_weather",
+					"description": "Get the current weather",
+					"parameters": {
+						"type": "object",
+						"required": ["location"],
+						"properties": {
+							"location": {
+								"type": "string",
+								"description": "The city and state"
+							},
+							"unit": {
+								"type": "string",
+								"enum": ["celsius", "fahrenheit"]
+							}
+						}
+					}
+				}
+			}]
+		}`,
+		req: api.ChatRequest{
+			Model: "test-model",
+			Messages: []api.Message{
+				{
+					Role:    "user",
+					Content: "What's the weather like in Paris?",
+				},
+			},
+			Tools: []api.Tool{
+				{
+					Type: "function",
+					Function: api.ToolFunction{
+						Name:        "get_weather",
+						Description: "Get the current weather",
+						Parameters: struct {
+							Type       string   `json:"type"`
+							Required   []string `json:"required"`
+							Properties map[string]struct {
+								Type        string   `json:"type"`
+								Description string   `json:"description"`
+								Enum        []string `json:"enum,omitempty"`
+							} `json:"properties"`
+						}{
+							Type:     "object",
+							Required: []string{"location"},
+							Properties: map[string]struct {
+								Type        string   `json:"type"`
+								Description string   `json:"description"`
+								Enum        []string `json:"enum,omitempty"`
+							}{
+								"location": {
+									Type:        "string",
+									Description: "The city and state",
+								},
+								"unit": {
+									Type: "string",
+									Enum: []string{"celsius", "fahrenheit"},
+								},
+							},
+						},
+					},
+				},
+			},
+			Options: map[string]any{
+				"temperature": 1.0,
+				"top_p":       1.0,
+			},
+			Stream: &True,
+		},
+	},
 	{
 		name: "chat handler error forwarding",
 		body: `{
@@ -237,13 +356,13 @@ func TestChatMiddleware(t *testing.T) {
 			if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
 				t.Fatal(err)
 			}
+			return
 		}
-		if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
-			t.Fatal("requests did not match")
+		if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
+			t.Fatalf("requests did not match: %+v", diff)
 		}
-
-		if !reflect.DeepEqual(tc.err, errResp) {
-			t.Fatal("errors did not match")
+		if diff := cmp.Diff(tc.err, errResp); diff != "" {
+			t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
 		}
 	})
 }
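The assertion rewrite above swaps `reflect.DeepEqual` plus an opaque failure message for `cmp.Diff`, which reports exactly which fields diverge. A self-contained example of the pattern (the `request` type here is invented for illustration):

```go
package main

import (
	"fmt"

	"github.com/google/go-cmp/cmp"
)

type request struct {
	Model  string
	Prompt string
}

func main() {
	want := request{Model: "test-model", Prompt: "Hello"}
	got := request{Model: "test-model", Prompt: "Hi"}

	// cmp.Diff returns "" when the values are equal, and a readable
	// field-by-field report otherwise; in a test this would feed t.Fatalf.
	if diff := cmp.Diff(want, got); diff != "" {
		fmt.Printf("requests did not match:\n%s", diff)
	}
}
```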
@@ -283,6 +402,55 @@ func TestCompletionsMiddleware(t *testing.T) {
 			Stream: &False,
 		},
 	},
+	{
+		name: "completions handler stream",
+		body: `{
+			"model": "test-model",
+			"prompt": "Hello",
+			"stream": true,
+			"temperature": 0.8,
+			"stop": ["\n", "stop"],
+			"suffix": "suffix"
+		}`,
+		req: api.GenerateRequest{
+			Model:  "test-model",
+			Prompt: "Hello",
+			Options: map[string]any{
+				"frequency_penalty": 0.0,
+				"presence_penalty":  0.0,
+				"temperature":       0.8,
+				"top_p":             1.0,
+				"stop":              []any{"\n", "stop"},
+			},
+			Suffix: "suffix",
+			Stream: &True,
+		},
+	},
+	{
+		name: "completions handler stream with usage",
+		body: `{
+			"model": "test-model",
+			"prompt": "Hello",
+			"stream": true,
+			"stream_options": {"include_usage": true},
+			"temperature": 0.8,
+			"stop": ["\n", "stop"],
+			"suffix": "suffix"
+		}`,
+		req: api.GenerateRequest{
+			Model:  "test-model",
+			Prompt: "Hello",
+			Options: map[string]any{
+				"frequency_penalty": 0.0,
+				"presence_penalty":  0.0,
+				"temperature":       0.8,
+				"top_p":             1.0,
+				"stop":              []any{"\n", "stop"},
+			},
+			Suffix: "suffix",
+			Stream: &True,
+		},
+	},
 	{
 		name: "completions handler error forwarding",
 		body: `{
@@ -5,7 +5,6 @@ import (
 	"cmp"
 	"context"
 	"crypto/sha256"
-	"encoding/base64"
 	"encoding/hex"
 	"encoding/json"
 	"errors"
@@ -24,14 +23,12 @@ import (
 	"strings"
 
 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/auth"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/llama"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/template"
-	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 )
@@ -985,37 +982,7 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
 
 var errUnauthorized = errors.New("unauthorized: access denied")
 
-// getTokenSubject returns the subject of a JWT token, it does not validate the token
-func getTokenSubject(token string) string {
-	parts := strings.Split(token, ".")
-	if len(parts) != 3 {
-		return ""
-	}
-
-	payload := parts[1]
-	payloadBytes, err := base64.RawURLEncoding.DecodeString(payload)
-	if err != nil {
-		slog.Error(fmt.Sprintf("failed to decode jwt payload: %v", err))
-		return ""
-	}
-
-	var payloadMap map[string]interface{}
-	if err := json.Unmarshal(payloadBytes, &payloadMap); err != nil {
-		slog.Error(fmt.Sprintf("failed to unmarshal payload JSON: %v", err))
-		return ""
-	}
-
-	sub, ok := payloadMap["sub"]
-	if !ok {
-		slog.Error("jwt does not contain 'sub' field")
-		return ""
-	}
-
-	return fmt.Sprintf("%s", sub)
-}
-
 func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
-	anonymous := true // access will default to anonymous if no user is found associated with the public key
 	for range 2 {
 		resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
 		if err != nil {
@@ -1036,7 +1003,6 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
 			if err != nil {
 				return nil, err
 			}
-			anonymous = getTokenSubject(token) == "anonymous"
 			regOpts.Token = token
 			if body != nil {
 				_, err = body.Seek(0, io.SeekStart)
@@ -1059,16 +1025,6 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
 		}
 	}
 
-	if anonymous {
-		// no user is associated with the public key, and the request requires non-anonymous access
-		pubKey, nestedErr := auth.GetPublicKey()
-		if nestedErr != nil {
-			slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
-			return nil, errUnauthorized
-		}
-		return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
-	}
-	// user is associated with the public key, but is not authorized to make the request
 	return nil, errUnauthorized
 }
@@ -1120,17 +1076,15 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
 		req.ContentLength = contentLength
 	}
 
-	resp, err := (&http.Client{
-		Transport: &http.Transport{
-			DialContext: testMakeRequestDialContext,
-		},
+	c := &http.Client{
 		CheckRedirect: regOpts.CheckRedirect,
-	}).Do(req)
-	if err != nil {
-		return nil, err
 	}
-
-	return resp, nil
+	if testMakeRequestDialContext != nil {
+		tr := http.DefaultTransport.(*http.Transport).Clone()
+		tr.DialContext = testMakeRequestDialContext
+		c.Transport = tr
+	}
+	return c.Do(req)
 }
 
 func getValue(header, key string) string {
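The hunk above stops building a fresh transport on every request: the test dialer is now injected only when set, by cloning `http.DefaultTransport` so proxy and TLS defaults survive. A standalone sketch of the same pattern (the `testDialContext` hook name is illustrative, mirroring the diff's `testMakeRequestDialContext`):

```go
package main

import (
	"context"
	"fmt"
	"net"
	"net/http"
)

// testDialContext stands in for a package-level test hook: nil in
// production, set by tests to redirect outbound connections.
var testDialContext func(ctx context.Context, network, addr string) (net.Conn, error)

// newClient uses the default transport unless a test dialer is installed,
// in which case it clones the transport and overrides only DialContext.
func newClient() *http.Client {
	c := &http.Client{}
	if testDialContext != nil {
		tr := http.DefaultTransport.(*http.Transport).Clone()
		tr.DialContext = testDialContext
		c.Transport = tr
	}
	return c
}

func main() {
	// With no hook installed the client keeps its nil (default) transport.
	fmt.Println(newClient().Transport == nil)
}
```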
@@ -39,6 +39,7 @@ func TestExecuteWithTools(t *testing.T) {
 		{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
 
 The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
+		{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false},
 		{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
 
 [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
@@ -148,10 +148,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	if req.Format != "" && req.Format != "json" {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
-		return
-	} else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
+	if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
 		return
 	}
@@ -251,6 +248,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 	var b bytes.Buffer
 	if req.Context != nil {
+		slog.Warn("the context field is deprecated and will be removed in a future version of Ollama")
 		s, err := r.Detokenize(c.Request.Context(), req.Context)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -1141,7 +1139,7 @@ func (s *Server) GenerateRoutes() http.Handler {
 	config.AllowWildcard = true
 	config.AllowBrowserExtensions = true
 	config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
-	openAIProperties := []string{"lang", "package-version", "os", "arch", "runtime", "runtime-version", "async"}
+	openAIProperties := []string{"lang", "package-version", "os", "arch", "retry-count", "runtime", "runtime-version", "async"}
 	for _, prop := range openAIProperties {
 		config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
 	}
@@ -1458,6 +1456,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 
 	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
 	if err != nil {
+		slog.Error("chat prompt error", "error", err)
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
@@ -1467,6 +1466,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
+		var sb strings.Builder
+		var toolCallIndex int = 0
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt: prompt,
 			Images: images,
@@ -1492,7 +1493,37 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 			}
 
-			ch <- res
+			// TODO: tool call checking and filtering should be moved outside of this callback once streaming
+			// however this was a simple change for now without reworking streaming logic of this (and other)
+			// handlers
+			if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 {
+				ch <- res
+				return
+			}
+
+			// Streaming tool calls:
+			// If tools are recognized, use a flag to track the sending of a tool downstream
+			// This ensures that content is cleared from the message on the last chunk sent
+			sb.WriteString(r.Content)
+			if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
+				res.Message.ToolCalls = toolCalls
+				for i := range toolCalls {
+					toolCalls[i].Function.Index = toolCallIndex
+					toolCallIndex++
+				}
+				res.Message.Content = ""
+				sb.Reset()
+				ch <- res
+				return
+			}
+
+			if r.Done {
+				// Send any remaining content if no tool calls were detected
+				if toolCallIndex == 0 {
+					res.Message.Content = sb.String()
+				}
+				ch <- res
+			}
 		}); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
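The handler change above accumulates streamed content in a `strings.Builder` and re-tries tool-call parsing on every chunk, clearing the message content once a complete call is recognized. A standalone miniature of that accumulate-and-parse loop (the `parseToolCalls` here is a simplified stand-in for the model's template-aware parser):

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

type toolCall struct {
	Name      string         `json:"name"`
	Arguments map[string]any `json:"arguments"`
}

// parseToolCalls succeeds only once the accumulated text forms complete,
// valid JSON; partial chunks fail and the caller keeps buffering.
func parseToolCalls(s string) ([]toolCall, bool) {
	var tc toolCall
	if err := json.Unmarshal([]byte(strings.TrimSpace(s)), &tc); err != nil {
		return nil, false
	}
	return []toolCall{tc}, true
}

func main() {
	// Chunks arrive incrementally, as in the streaming test below.
	chunks := []string{
		`{"name":"get_`,
		`weather","arguments":{"location":"Seattle`,
		`, WA","unit":"celsius"}}`,
	}

	var sb strings.Builder
	for _, c := range chunks {
		sb.WriteString(c)
		if calls, ok := parseToolCalls(sb.String()); ok {
			fmt.Printf("tool call: %+v\n", calls[0]) // fires only on the final chunk
			sb.Reset()
		}
	}
}
```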
@@ -8,6 +8,7 @@ import (
 	"io"
 	"net/http"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 
@@ -25,10 +26,14 @@ type mockRunner struct {
 	// CompletionRequest is only valid until the next call to Completion
 	llm.CompletionRequest
 	llm.CompletionResponse
+	CompletionFn func(context.Context, llm.CompletionRequest, func(llm.CompletionResponse)) error
 }
 
-func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
 	m.CompletionRequest = r
+	if m.CompletionFn != nil {
+		return m.CompletionFn(ctx, r, fn)
+	}
 	fn(m.CompletionResponse)
 	return nil
}
||||
@@ -88,9 +93,14 @@ func TestGenerateChat(t *testing.T) {
|
||||
Model: "test",
|
||||
Modelfile: fmt.Sprintf(`FROM %s
|
||||
TEMPLATE """
|
||||
{{- if .System }}System: {{ .System }} {{ end }}
|
||||
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
|
||||
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
|
||||
{{- if .Tools }}
|
||||
{{ .Tools }}
|
||||
{{ end }}
|
||||
{{- range .Messages }}
|
||||
{{- .Role }}: {{ .Content }}
|
||||
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{- end }}
|
||||
{{ end }}"""
|
||||
`, createBinFile(t, llm.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
@@ -263,7 +273,7 @@ func TestGenerateChat(t *testing.T) {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
@@ -292,7 +302,7 @@ func TestGenerateChat(t *testing.T) {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
@@ -314,7 +324,7 @@ func TestGenerateChat(t *testing.T) {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
@@ -337,12 +347,242 @@ func TestGenerateChat(t *testing.T) {
 			t.Errorf("expected status 200, got %d", w.Code)
 		}
 
-		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 		}
 
 		checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
 	})
 
+	t.Run("messages with tools (non-streaming)", func(t *testing.T) {
+		if w.Code != http.StatusOK {
+			t.Fatalf("failed to create test-system model: %d", w.Code)
+		}
+
+		tools := []api.Tool{
+			{
+				Type: "function",
+				Function: api.ToolFunction{
+					Name:        "get_weather",
+					Description: "Get the current weather",
+					Parameters: struct {
+						Type       string   `json:"type"`
+						Required   []string `json:"required"`
+						Properties map[string]struct {
+							Type        string   `json:"type"`
+							Description string   `json:"description"`
+							Enum        []string `json:"enum,omitempty"`
+						} `json:"properties"`
+					}{
+						Type:     "object",
+						Required: []string{"location"},
+						Properties: map[string]struct {
+							Type        string   `json:"type"`
+							Description string   `json:"description"`
+							Enum        []string `json:"enum,omitempty"`
+						}{
+							"location": {
+								Type:        "string",
+								Description: "The city and state",
+							},
+							"unit": {
+								Type: "string",
+								Enum: []string{"celsius", "fahrenheit"},
+							},
+						},
+					},
+				},
+			},
+		}
+
+		mock.CompletionResponse = llm.CompletionResponse{
+			Content:            `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
+			Done:               true,
+			DoneReason:         "done",
+			PromptEvalCount:    1,
+			PromptEvalDuration: 1,
+			EvalCount:          1,
+			EvalDuration:       1,
+		}
+
+		streamRequest := true
+
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test-system",
+			Messages: []api.Message{
+				{Role: "user", Content: "What's the weather in Seattle?"},
+			},
+			Tools:  tools,
+			Stream: &streamRequest,
+		})
+
+		if w.Code != http.StatusOK {
+			var errResp struct {
+				Error string `json:"error"`
+			}
+			if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil {
+				t.Logf("Failed to decode error response: %v", err)
+			} else {
+				t.Logf("Error response: %s", errResp.Error)
+			}
+		}
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		var resp api.ChatResponse
+		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+			t.Fatal(err)
+		}
+
+		if resp.Message.ToolCalls == nil {
+			t.Error("expected tool calls, got nil")
+		}
+
+		expectedToolCall := api.ToolCall{
+			Function: api.ToolCallFunction{
+				Name: "get_weather",
+				Arguments: api.ToolCallFunctionArguments{
+					"location": "Seattle, WA",
+					"unit":     "celsius",
+				},
+			},
+		}
+
+		if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" {
+			t.Errorf("tool call mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("messages with tools (streaming)", func(t *testing.T) {
+		tools := []api.Tool{
+			{
+				Type: "function",
+				Function: api.ToolFunction{
+					Name:        "get_weather",
+					Description: "Get the current weather",
+					Parameters: struct {
+						Type       string   `json:"type"`
+						Required   []string `json:"required"`
+						Properties map[string]struct {
+							Type        string   `json:"type"`
+							Description string   `json:"description"`
+							Enum        []string `json:"enum,omitempty"`
+						} `json:"properties"`
+					}{
+						Type:     "object",
+						Required: []string{"location"},
+						Properties: map[string]struct {
+							Type        string   `json:"type"`
+							Description string   `json:"description"`
+							Enum        []string `json:"enum,omitempty"`
+						}{
+							"location": {
+								Type:        "string",
+								Description: "The city and state",
+							},
+							"unit": {
+								Type: "string",
+								Enum: []string{"celsius", "fahrenheit"},
+							},
+						},
+					},
+				},
+			},
+		}
+
+		// Simulate streaming response with multiple chunks
+		var wg sync.WaitGroup
+		wg.Add(1)
+
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+			defer wg.Done()
+
+			// Send chunks with small delays to simulate streaming
+			responses := []llm.CompletionResponse{
+				{
+					Content:            `{"name":"get_`,
+					Done:               false,
+					PromptEvalCount:    1,
+					PromptEvalDuration: 1,
+				},
+				{
+					Content:            `weather","arguments":{"location":"Seattle`,
+					Done:               false,
+					PromptEvalCount:    2,
+					PromptEvalDuration: 1,
+				},
+				{
+					Content:            `, WA","unit":"celsius"}}`,
+					Done:               true,
+					DoneReason:         "tool_call",
+					PromptEvalCount:    3,
+					PromptEvalDuration: 1,
+				},
+			}
+
+			for _, resp := range responses {
+				select {
+				case <-ctx.Done():
+					return ctx.Err()
+				default:
+					fn(resp)
+					time.Sleep(10 * time.Millisecond) // Small delay between chunks
+				}
+			}
+			return nil
+		}
+
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test-system",
+			Messages: []api.Message{
+				{Role: "user", Content: "What's the weather in Seattle?"},
+			},
+			Tools:  tools,
+			Stream: &stream,
+		})
+
+		wg.Wait()
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		// Read and validate the streamed responses
+		decoder := json.NewDecoder(w.Body)
+		var finalToolCall api.ToolCall
+
+		for {
+			var resp api.ChatResponse
+			if err := decoder.Decode(&resp); err == io.EOF {
+				break
+			} else if err != nil {
+				t.Fatal(err)
+			}
+
+			if resp.Done {
+				if len(resp.Message.ToolCalls) != 1 {
+					t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls))
+				}
+				finalToolCall = resp.Message.ToolCalls[0]
+			}
+		}
+
+		expectedToolCall := api.ToolCall{
+			Function: api.ToolCallFunction{
+				Name: "get_weather",
+				Arguments: api.ToolCallFunctionArguments{
+					"location": "Seattle, WA",
+					"unit":     "celsius",
+				},
+			},
+		}
+
+		if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
+			t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
+		}
+	})
 }
 
 func TestGenerate(t *testing.T) {